From e2eba4bd0883954c5eba7ace94b0358b80c00aa3 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 12:49:43 +0300 Subject: [PATCH 01/78] add xgboost regressor --- src/operators/ml.cairo | 3 +- src/operators/ml/xgboost_regressor.cairo | 1 + src/operators/ml/xgboost_regressor/core.cairo | 36 +++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 src/operators/ml/xgboost_regressor.cairo create mode 100644 src/operators/ml/xgboost_regressor/core.cairo diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index b3eaef500..143cba273 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -1,6 +1,7 @@ mod tree_regressor; +mod xgboost_regressor; -use orion::operators::ml::tree_regressor::core::TreeRegressorTrait; +use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeNode}; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp16x16::FP16x16TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp8x23::FP8x23TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp32x32::FP32x32TreeRegressor; diff --git a/src/operators/ml/xgboost_regressor.cairo b/src/operators/ml/xgboost_regressor.cairo new file mode 100644 index 000000000..ef33ab296 --- /dev/null +++ b/src/operators/ml/xgboost_regressor.cairo @@ -0,0 +1 @@ +mod core; \ No newline at end of file diff --git a/src/operators/ml/xgboost_regressor/core.cairo b/src/operators/ml/xgboost_regressor/core.cairo new file mode 100644 index 000000000..d1bc87e04 --- /dev/null +++ b/src/operators/ml/xgboost_regressor/core.cairo @@ -0,0 +1,36 @@ +use orion::operators::ml::{TreeNode, TreeRegressorTrait}; +use orion::numbers::FixedTrait; + + +trait XGBoostPredictorTrait { + fn predict(trees: Span>, features: Span, weights: Span) -> T; +} + +fn predict< + T, + MAG, + impl TFixed: FixedTrait, + impl TTreeRegressor: TreeRegressorTrait, + impl TMul: Mul, + impl TAddEq: AddEq, + impl TCopy: Copy, + impl TDrop: Drop, +>( + ref trees: Span>, ref features: Span, ref weights: Span +) -> T { + let mut sum_prediction: T = FixedTrait::ZERO(); + + loop { + match trees.pop_front() { + Option::Some(tree) => { + let mut tree = *tree; + sum_prediction += tree.predict(features) * *weights.pop_front().unwrap() + }, + Option::None(_) => { + break; + } + }; + }; + + sum_prediction +} From 62a4844936be048080db174e808c471edb1be5d6 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 15:04:43 +0300 Subject: [PATCH 02/78] implement tree classifier --- src/operators/ml.cairo | 1 + src/operators/ml/tree_classifier.cairo | 2 + src/operators/ml/tree_classifier/core.cairo | 94 +++++++++++++++++++ .../ml/tree_classifier/implementations.cairo | 4 + .../tree_classifier_fp16x16.cairo | 13 +++ .../tree_classifier_fp32x32.cairo | 13 +++ .../tree_classifier_fp64x64.cairo | 13 +++ .../tree_classifier_fp8x23.cairo | 13 +++ 8 files changed, 153 insertions(+) create mode 100644 src/operators/ml/tree_classifier.cairo create mode 100644 src/operators/ml/tree_classifier/core.cairo create mode 100644 src/operators/ml/tree_classifier/implementations.cairo create mode 100644 src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo create mode 100644 src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo create mode 100644 src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo create mode 100644 
src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 143cba273..1f4e0dad4 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -1,4 +1,5 @@ mod tree_regressor; +mod tree_classifier; mod xgboost_regressor; use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeNode}; diff --git a/src/operators/ml/tree_classifier.cairo b/src/operators/ml/tree_classifier.cairo new file mode 100644 index 000000000..2ab1a62ac --- /dev/null +++ b/src/operators/ml/tree_classifier.cairo @@ -0,0 +1,2 @@ +mod core; +mod implementations; \ No newline at end of file diff --git a/src/operators/ml/tree_classifier/core.cairo b/src/operators/ml/tree_classifier/core.cairo new file mode 100644 index 000000000..c8c4a2c7f --- /dev/null +++ b/src/operators/ml/tree_classifier/core.cairo @@ -0,0 +1,94 @@ +use orion::numbers::{FixedTrait}; + +#[derive(Copy, Drop)] +struct TreeNode { + left: Option>>, + right: Option>>, + split_feature: usize, + split_value: T, + prediction: T, + class_distribution: Span, // assuming class labels of type usize (span index), and probability as T. +} + +/// Trait +/// +/// predict - Given a set of features, predicts the target value using the constructed decision tree. +/// predict_proba - Given a set of features, predicts the probability of each X example being of a given class.. +trait TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> T; + fn predict_proba(ref self: TreeNode, features: Span) -> Span; +} + +fn predict< + T, + MAG, + impl FFixedTrait: FixedTrait, + impl TPartialOrd: PartialOrd, + impl FCopy: Copy, + impl FDrop: Drop, +>( + ref self: TreeNode, features: Span +) -> T { + let mut current_node: TreeNode = self; + + loop { + match current_node.left { + Option::Some(left) => { + match current_node.right { + Option::Some(right) => { + if *features.at(current_node.split_feature) < current_node.split_value { + current_node = left.unbox(); + } else { + current_node = right.unbox(); + } + }, + Option::None(_) => { + break; + } + } + }, + Option::None(_) => { + break; + } + }; + }; + + current_node.prediction +} + +fn predict_proba< + T, + MAG, + impl FFixedTrait: FixedTrait, + impl TPartialOrd: PartialOrd, + impl FCopy: Copy, + impl FDrop: Drop, +>( + ref self: TreeNode, features: Span +) -> Span { + let mut current_node: TreeNode = self; + + loop { + match current_node.left { + Option::Some(left) => { + match current_node.right { + Option::Some(right) => { + if *features.at(current_node.split_feature) < current_node.split_value { + current_node = left.unbox(); + } else { + current_node = right.unbox(); + } + }, + Option::None(_) => { + break; + } + } + }, + Option::None(_) => { + break; + } + }; + }; + + current_node.class_distribution +} diff --git a/src/operators/ml/tree_classifier/implementations.cairo b/src/operators/ml/tree_classifier/implementations.cairo new file mode 100644 index 000000000..2421c7809 --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations.cairo @@ -0,0 +1,4 @@ +mod tree_classifier_fp8x23; +mod tree_classifier_fp16x16; +mod tree_classifier_fp32x32; +mod tree_classifier_fp64x64; diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo new file mode 100644 index 000000000..579c18928 --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo @@ -0,0 +1,13 @@ 
+use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::FP16x16; + +impl FP16x16TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP16x16 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo new file mode 100644 index 000000000..c10a0c82f --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo @@ -0,0 +1,13 @@ +use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::{FP32x32, FP32x32Impl}; + +impl FP32x32TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP32x32 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo new file mode 100644 index 000000000..ce3e6541a --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo @@ -0,0 +1,13 @@ +use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::{FP64x64, FP64x64Impl}; + +impl FP64x64TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP64x64 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo new file mode 100644 index 000000000..88eaf0fc4 --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo @@ -0,0 +1,13 @@ +use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::FP8x23; + +impl FP8x23TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP8x23 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} From 4e8d686d7eb483d9823913f13bdb4b1cafd65505 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 15:14:21 +0300 Subject: [PATCH 03/78] rename trees struct --- src/operators/ml.cairo | 8 +++++- src/operators/ml/tree_classifier/core.cairo | 18 ++++++------- .../tree_classifier_fp16x16.cairo | 6 ++--- .../tree_classifier_fp32x32.cairo | 6 ++--- .../tree_classifier_fp64x64.cairo | 6 ++--- .../tree_classifier_fp8x23.cairo | 6 ++--- src/operators/ml/tree_regressor/core.cairo | 26 +++++++++---------- .../tree_regressor_fp16x16.cairo | 6 ++--- .../tree_regressor_fp32x32.cairo | 6 ++--- .../tree_regressor_fp64x64.cairo | 6 ++--- .../tree_regressor_fp8x23.cairo | 6 ++--- src/operators/ml/xgboost_regressor/core.cairo | 6 ++--- 12 files changed, 56 
insertions(+), 50 deletions(-) diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 1f4e0dad4..0a29e00c2 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -2,8 +2,14 @@ mod tree_regressor; mod tree_classifier; mod xgboost_regressor; -use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeNode}; +use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeRegressor}; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp16x16::FP16x16TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp8x23::FP8x23TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp32x32::FP32x32TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp64x64::FP64x64TreeRegressor; + +use orion::operators::ml::tree_classifier::core::{TreeClassifierTrait, TreeClassifier}; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp16x16::FP16x16TreeClassifier; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp8x23::FP8x23TreeClassifier; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp32x32::FP32x32TreeClassifier; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp64x64::FP64x64TreeClassifier; diff --git a/src/operators/ml/tree_classifier/core.cairo b/src/operators/ml/tree_classifier/core.cairo index c8c4a2c7f..669c23abd 100644 --- a/src/operators/ml/tree_classifier/core.cairo +++ b/src/operators/ml/tree_classifier/core.cairo @@ -1,9 +1,9 @@ use orion::numbers::{FixedTrait}; #[derive(Copy, Drop)] -struct TreeNode { - left: Option>>, - right: Option>>, +struct TreeClassifier { + left: Option>>, + right: Option>>, split_feature: usize, split_value: T, prediction: T, @@ -15,8 +15,8 @@ struct TreeNode { /// predict - Given a set of features, predicts the target value using the constructed decision tree. /// predict_proba - Given a set of features, predicts the probability of each X example being of a given class.. 
trait TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> T; - fn predict_proba(ref self: TreeNode, features: Span) -> Span; + fn predict(ref self: TreeClassifier, features: Span) -> T; + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; } fn predict< @@ -27,9 +27,9 @@ fn predict< impl FCopy: Copy, impl FDrop: Drop, >( - ref self: TreeNode, features: Span + ref self: TreeClassifier, features: Span ) -> T { - let mut current_node: TreeNode = self; + let mut current_node: TreeClassifier = self; loop { match current_node.left { @@ -64,9 +64,9 @@ fn predict_proba< impl FCopy: Copy, impl FDrop: Drop, >( - ref self: TreeNode, features: Span + ref self: TreeClassifier, features: Span ) -> Span { - let mut current_node: TreeNode = self; + let mut current_node: TreeClassifier = self; loop { match current_node.left { diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo index 579c18928..1789c8a64 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use orion::operators::ml::tree_classifier::core; use orion::numbers::FP16x16; impl FP16x16TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP16x16 { + fn predict(ref self: TreeClassifier, features: Span) -> FP16x16 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo index c10a0c82f..442fb100a 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use orion::operators::ml::tree_classifier::core; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP32x32 { + fn predict(ref self: TreeClassifier, features: Span) -> FP32x32 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo index ce3e6541a..61c9415ec 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use 
orion::operators::ml::tree_classifier::core; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP64x64 { + fn predict(ref self: TreeClassifier, features: Span) -> FP64x64 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo index 88eaf0fc4..01f548efe 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use orion::operators::ml::tree_classifier::core; use orion::numbers::FP8x23; impl FP8x23TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP8x23 { + fn predict(ref self: TreeClassifier, features: Span) -> FP8x23 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/core.cairo b/src/operators/ml/tree_regressor/core.cairo index d81557dc2..ea4ac0271 100644 --- a/src/operators/ml/tree_regressor/core.cairo +++ b/src/operators/ml/tree_regressor/core.cairo @@ -3,9 +3,9 @@ use cubit::f64::procgen::rand::u64_between; use orion::numbers::{FixedTrait}; #[derive(Copy, Drop)] -struct TreeNode { - left: Option>>, - right: Option>>, +struct TreeRegressor { + left: Option>>, + right: Option>>, split_feature: usize, split_value: T, prediction: T, @@ -19,7 +19,7 @@ trait TreeRegressorTrait { /// # TreeRegressorTrait::fit /// /// ```rust - /// fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeNode; + /// fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeRegressor; /// ``` /// /// Builds a decision tree based on the provided data and target values up to a specified maximum depth. @@ -33,7 +33,7 @@ trait TreeRegressorTrait { /// /// ## Returns /// - /// A `TreeNode` representing the root of the constructed decision tree. + /// A `TreeRegressor` representing the root of the constructed decision tree. /// /// ## Type Constraints /// @@ -69,11 +69,11 @@ trait TreeRegressorTrait { /// fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode; + ) -> TreeRegressor; /// # TreeRegressorTrait::predict /// /// ```rust - /// fn predict(ref self: TreeNode, features: Span) -> T; + /// fn predict(ref self: TreeRegressor, features: Span) -> T; /// ``` /// /// Predicts the target value for a set of features using the provided decision tree. 
@@ -124,7 +124,7 @@ trait TreeRegressorTrait { /// } /// ``` /// - fn predict(ref self: TreeNode, features: Span) -> T; + fn predict(ref self: TreeRegressor, features: Span) -> T; } fn predict< @@ -135,9 +135,9 @@ fn predict< impl FCopy: Copy, impl FDrop: Drop, >( - ref self: TreeNode, features: Span + ref self: TreeRegressor, features: Span ) -> T { - let mut current_node: TreeNode = self; + let mut current_node: TreeRegressor = self; loop { match current_node.left { @@ -363,7 +363,7 @@ fn fit< impl TDrop: Drop, >( data: Span>, target: Span, depth: usize, max_depth: usize, random_state: usize -) -> TreeNode { +) -> TreeRegressor { if depth == max_depth || data.len() < 2 { let mut total = FixedTrait::ZERO(); let mut target_copy = target; @@ -377,7 +377,7 @@ fn fit< } }; }; - return TreeNode { + return TreeRegressor { left: Option::None(()), right: Option::None(()), split_feature: 0, @@ -413,7 +413,7 @@ fn fit< }; }; - TreeNode { + TreeRegressor { left: Option::Some( BoxTrait::new( fit(left_data.span(), left_target.span(), depth + 1, max_depth, random_state) diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo index 3cb35ab13..7aeb6eb69 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo @@ -1,15 +1,15 @@ -use orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::FP16x16; impl FP16x16TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP16x16 { + fn predict(ref self: TreeRegressor, features: Span) -> FP16x16 { core::predict(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo index d1791a9c9..288d7e15d 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo @@ -1,15 +1,15 @@ -use orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP32x32 { + fn predict(ref self: TreeRegressor, features: Span) -> FP32x32 { core::predict(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo index 54adb6ce4..9102428fc 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo @@ -1,15 +1,15 @@ -use 
orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP64x64 { + fn predict(ref self: TreeRegressor, features: Span) -> FP64x64 { core::predict(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo index baf61096c..54c195704 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo @@ -1,15 +1,15 @@ -use orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::FP8x23; impl FP8x23TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP8x23 { + fn predict(ref self: TreeRegressor, features: Span) -> FP8x23 { core::predict(ref self, features) } } diff --git a/src/operators/ml/xgboost_regressor/core.cairo b/src/operators/ml/xgboost_regressor/core.cairo index d1bc87e04..1fb84c8dc 100644 --- a/src/operators/ml/xgboost_regressor/core.cairo +++ b/src/operators/ml/xgboost_regressor/core.cairo @@ -1,9 +1,9 @@ -use orion::operators::ml::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::{TreeRegressor, TreeRegressorTrait}; use orion::numbers::FixedTrait; trait XGBoostPredictorTrait { - fn predict(trees: Span>, features: Span, weights: Span) -> T; + fn predict(trees: Span>, features: Span, weights: Span) -> T; } fn predict< @@ -16,7 +16,7 @@ fn predict< impl TCopy: Copy, impl TDrop: Drop, >( - ref trees: Span>, ref features: Span, ref weights: Span + ref trees: Span>, ref features: Span, ref weights: Span ) -> T { let mut sum_prediction: T = FixedTrait::ZERO(); From 3de69074da5a97f7b0f524c10e11098ca29db189 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 15:29:33 +0300 Subject: [PATCH 04/78] generate doc --- docgen/src/main.rs | 8 ++ docs/SUMMARY.md | 5 ++ .../tree-classifier/README.md | 23 ++++++ .../tree-classifier/tree.predict.md | 35 +++++++++ .../tree-classifier/tree.predict_proba.md | 38 +++++++++ .../tree-regressor/tree.fit.md | 4 +- .../tree-regressor/tree.predict.md | 4 +- src/operators/ml/tree_classifier/core.cairo | 77 ++++++++++++++++++- src/operators/ml/tree_regressor/core.cairo | 2 +- 9 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 docs/framework/operators/machine-learning/tree-classifier/README.md create mode 100644 docs/framework/operators/machine-learning/tree-classifier/tree.predict.md create mode 100644 docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md diff --git a/docgen/src/main.rs b/docgen/src/main.rs index e628cf980..b29e49d77 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -42,6 +42,14 @@ fn main() { let 
trait_name: &str = "TreeRegressorTrait"; doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + + // TREE ClASSIFIER DOC + let trait_path = "src/operators/ml/tree_classifier/core.cairo"; + let doc_path = "docs/framework/operators/machine-learning/tree-classifier"; + let label = "tree"; + let trait_name: &str = "TreeClassifierTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); } fn doc_trait(trait_path: &str, doc_path: &str, label: &str) { diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 8cdf6ddcb..520e39f9b 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -99,6 +99,11 @@ * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) * [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) + * [Tree Classifier](framework/operators/machine-learning/tree-classifier/README.md) + * [tree.predict](framework/operators/machine-learning/tree-classifier/tree.predict.md) + * [tree.predict_proba](framework/operators/machine-learning/tree-classifier/tree.predict_proba.md) + + ## 🏛 Hub diff --git a/docs/framework/operators/machine-learning/tree-classifier/README.md b/docs/framework/operators/machine-learning/tree-classifier/README.md new file mode 100644 index 000000000..8c371c996 --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-classifier/README.md @@ -0,0 +1,23 @@ +# Tree Classifier + +`TreeClassifierTrait` provides a trait definition for decision tree classifier. This trait offers functionalities to build a decision tree and predict target values based on input features. + +```rust +use orion::operators::ml::TreeClassifierTrait; +``` + +### Data types + +Orion supports currently only fixed point data types for `TreeClassifierTrait`. + +| Data type | dtype | +| -------------------- | ------------------------------------------------------------- | +| Fixed point (signed) | `TreeClassifierTrait` | + +*** + +| function | description | +| --- | --- | +| [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. | +| [`tree.predict_proba`](tree.predict\_proba.md) | Predicts class probabilities based on feature data. | + diff --git a/docs/framework/operators/machine-learning/tree-classifier/tree.predict.md b/docs/framework/operators/machine-learning/tree-classifier/tree.predict.md new file mode 100644 index 000000000..efd46cddb --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-classifier/tree.predict.md @@ -0,0 +1,35 @@ +# TreeClassifierTrait::predict + +```rust + fn predict(ref self: TreeClassifier, features: Span) -> T; +``` + +Predicts the target value for a set of features using the provided decision tree. + +## Args + +* `self`: A reference to the decision tree used for making the prediction. +* `features`: A span representing the features for which the prediction is to be made. + +## Returns + +The predicted target value. + +## Type Constraints + +Constrain input and output types to fixed point. 
+ +## Examples + +```rust +use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; +use orion::numbers::{FP16x16, FixedTrait}; + +fn tree_classifier_example(tree: TreeClassifier) { + + tree.predict( + array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + ); + +} +``` diff --git a/docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md b/docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md new file mode 100644 index 000000000..56afcd4e0 --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md @@ -0,0 +1,38 @@ +# TreeClassifierTrait::predict_proba + +```rust + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; +``` + +Given a set of features, this method traverses the decision tree +represented by `self` and returns the class distribution (probabilities) +found in the leaf node that matches the provided features. The traversal +stops once a leaf node is reached in the decision tree. + +## Args + +* `self`: A reference to the decision tree used for making the prediction. +* `features`: A span representing the features for which the prediction is to be made. + +## Returns + +Returns a `Span` representing the class distribution at the leaf node. + +## Type Constraints + +Constrain input and output types to fixed points. + +## Examples + +```rust +use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; +use orion::numbers::{FP16x16, FixedTrait}; + +fn tree_classifier_example(tree: TreeClassifier) { + + tree.predict_proba( + array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + ); + +} +``` diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md b/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md index e74929547..0ba61814d 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md +++ b/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md @@ -1,7 +1,7 @@ # TreeRegressorTrait::fit ```rust - fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeNode; + fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeRegressor; ``` Builds a decision tree based on the provided data and target values up to a specified maximum depth. @@ -15,7 +15,7 @@ Builds a decision tree based on the provided data and target values up to a spec ## Returns -A `TreeNode` representing the root of the constructed decision tree. +A `TreeRegressor` representing the root of the constructed decision tree. ## Type Constraints diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md index 28d4a027c..c76714d58 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md +++ b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md @@ -1,7 +1,7 @@ # TreeRegressorTrait::predict ```rust - fn predict(ref self: TreeNode, features: Span) -> T; + fn predict(ref self: TreeRegressor, features: Span) -> T; ``` Predicts the target value for a set of features using the provided decision tree. @@ -17,7 +17,7 @@ The predicted target value. ## Type Constraints -Constrain input and output types to fixed point tensors. +Constrain input and output types to fixed point. 
## Examples diff --git a/src/operators/ml/tree_classifier/core.cairo b/src/operators/ml/tree_classifier/core.cairo index 669c23abd..f4c948762 100644 --- a/src/operators/ml/tree_classifier/core.cairo +++ b/src/operators/ml/tree_classifier/core.cairo @@ -13,9 +13,84 @@ struct TreeClassifier { /// Trait /// /// predict - Given a set of features, predicts the target value using the constructed decision tree. -/// predict_proba - Given a set of features, predicts the probability of each X example being of a given class.. +/// predict_proba - Predicts class probabilities based on feature data. trait TreeClassifierTrait { + /// # TreeClassifierTrait::predict + /// + /// ```rust + /// fn predict(ref self: TreeClassifier, features: Span) -> T; + /// ``` + /// + /// Predicts the target value for a set of features using the provided decision tree. + /// + /// ## Args + /// + /// * `self`: A reference to the decision tree used for making the prediction. + /// * `features`: A span representing the features for which the prediction is to be made. + /// + /// ## Returns + /// + /// The predicted target value. + /// + /// ## Type Constraints + /// + /// Constrain input and output types to fixed point. + /// + /// ## Examples + /// + /// ```rust + /// use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; + /// use orion::numbers::{FP16x16, FixedTrait}; + /// + /// fn tree_classifier_example(tree: TreeClassifier) { + /// + /// tree.predict( + /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + /// ); + /// + /// } + /// ``` + /// fn predict(ref self: TreeClassifier, features: Span) -> T; + /// # TreeClassifierTrait::predict_proba + /// + /// ```rust + /// fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; + /// ``` + /// + /// Given a set of features, this method traverses the decision tree + /// represented by `self` and returns the class distribution (probabilities) + /// found in the leaf node that matches the provided features. The traversal + /// stops once a leaf node is reached in the decision tree. + /// + /// ## Args + /// + /// * `self`: A reference to the decision tree used for making the prediction. + /// * `features`: A span representing the features for which the prediction is to be made. + /// + /// ## Returns + /// + /// Returns a `Span` representing the class distribution at the leaf node. + /// + /// ## Type Constraints + /// + /// Constrain input and output types to fixed points. + /// + /// ## Examples + /// + /// ```rust + /// use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; + /// use orion::numbers::{FP16x16, FixedTrait}; + /// + /// fn tree_classifier_example(tree: TreeClassifier) { + /// + /// tree.predict_proba( + /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + /// ); + /// + /// } + /// ``` + /// fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; } diff --git a/src/operators/ml/tree_regressor/core.cairo b/src/operators/ml/tree_regressor/core.cairo index ea4ac0271..cb88de5b4 100644 --- a/src/operators/ml/tree_regressor/core.cairo +++ b/src/operators/ml/tree_regressor/core.cairo @@ -89,7 +89,7 @@ trait TreeRegressorTrait { /// /// ## Type Constraints /// - /// Constrain input and output types to fixed point tensors. + /// Constrain input and output types to fixed point. 
/// /// ## Examples /// From bcc412dda8acf30b6af26a1d939a606be19cf3ea Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 16:15:33 +0300 Subject: [PATCH 05/78] add xgboost implementations + docstrings --- src/operators/ml.cairo | 6 +++ src/operators/ml/xgboost_regressor.cairo | 3 +- src/operators/ml/xgboost_regressor/core.cairo | 53 +++++++++++++++++-- .../xgboost_regressor/implementations.cairo | 4 ++ .../xgboost_regressor_fp16x16.cairo | 12 +++++ .../xgboost_regressor_fp32x32.cairo | 12 +++++ .../xgboost_regressor_fp64x64.cairo | 12 +++++ .../xgboost_regressor_fp8x23.cairo | 12 +++++ 8 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 src/operators/ml/xgboost_regressor/implementations.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 143cba273..47c6cfa32 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -6,3 +6,9 @@ use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp16x1 use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp8x23::FP8x23TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp32x32::FP32x32TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp64x64::FP64x64TreeRegressor; + +use orion::operators::ml::xgboost_regressor::core::{XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp16x16::FP16x16XGBoostRegressor; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp8x23::FP8x23XGBoostRegressor; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp32x32::FP32x32XGBoostRegressor; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp64x64::FP64x64XGBoostRegressor; diff --git a/src/operators/ml/xgboost_regressor.cairo b/src/operators/ml/xgboost_regressor.cairo index ef33ab296..2ab1a62ac 100644 --- a/src/operators/ml/xgboost_regressor.cairo +++ b/src/operators/ml/xgboost_regressor.cairo @@ -1 +1,2 @@ -mod core; \ No newline at end of file +mod core; +mod implementations; \ No newline at end of file diff --git a/src/operators/ml/xgboost_regressor/core.cairo b/src/operators/ml/xgboost_regressor/core.cairo index d1bc87e04..e0534d10b 100644 --- a/src/operators/ml/xgboost_regressor/core.cairo +++ b/src/operators/ml/xgboost_regressor/core.cairo @@ -1,9 +1,56 @@ use orion::operators::ml::{TreeNode, TreeRegressorTrait}; use orion::numbers::FixedTrait; - -trait XGBoostPredictorTrait { - fn predict(trees: Span>, features: Span, weights: Span) -> T; +/// Trait +/// +/// predict - Predicts the target value for a set of features using the provided ensemble of decision trees. +trait XGBoostRegressorTrait { + /// # XGBoostRegressorTrait::predict + /// + /// ```rust + /// fn predict(ref self: Span>, ref features: Span, ref weights: Span) -> T; + /// ``` + /// + /// Predicts the target value for a set of features using the provided ensemble of decision trees + /// and combining their results using given weights. 
+ /// + /// ## Args + /// + /// * `self`: A reference to a span representing a ensemble of decision trees. + /// * `features`: A reference to a span representing the features for which the prediction is to be made. + /// * `weights`: A reference to a span representing the weights applied to the predictions from each tree. + /// + /// ## Returns + /// + /// The predicted target value. + /// + /// ## Type Constraints + /// + /// Constrain input and output types to fixed point. + /// + /// ## Examples + /// + /// ```rust + /// use orion::operators::ml::{FP16x16XGBoostRegressor, TreeRegressorTrait, TreeRegressor}; + /// use orion::numbers::{FP16x16, FixedTrait}; + /// + /// fn xgboost_regressor_example(trees: Span>) { + /// + /// let mut features = array![ + /// FixedTrait::new_unscaled(1, false), + /// FixedTrait::new_unscaled(2, false), + /// ].span(); + /// + /// let mut weights = array![ + /// FixedTrait::new_unscaled(0.5, false), + /// FixedTrait::new_unscaled(0.5, false) + /// ].span(); + /// + /// FP16x16XGBoostRegressor::predict(ref trees, ref features, ref weights); + /// } + /// ``` + /// + fn predict(ref self: Span>, ref features: Span, ref weights: Span) -> T; } fn predict< diff --git a/src/operators/ml/xgboost_regressor/implementations.cairo b/src/operators/ml/xgboost_regressor/implementations.cairo new file mode 100644 index 000000000..cd493cf91 --- /dev/null +++ b/src/operators/ml/xgboost_regressor/implementations.cairo @@ -0,0 +1,4 @@ +mod xgboost_regressor_fp8x23; +mod xgboost_regressor_fp16x16; +mod xgboost_regressor_fp32x32; +mod xgboost_regressor_fp64x64; diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo new file mode 100644 index 000000000..41661711b --- /dev/null +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo @@ -0,0 +1,12 @@ +use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core; +use orion::operators::ml::FP16x16TreeRegressor; +use orion::numbers::FP16x16; + +impl FP16x16XGBoostRegressor of XGBoostRegressorTrait { + fn predict( + ref self: Span>, ref features: Span, ref weights: Span + ) -> FP16x16 { + core::predict(ref self, ref features, ref weights) + } +} diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo new file mode 100644 index 000000000..83eca88ca --- /dev/null +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo @@ -0,0 +1,12 @@ +use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core; +use orion::operators::ml::FP32x32TreeRegressor; +use orion::numbers::{FP32x32, FP32x32Impl}; + +impl FP32x32XGBoostRegressor of XGBoostRegressorTrait { + fn predict( + ref self: Span>, ref features: Span, ref weights: Span + ) -> FP32x32 { + core::predict(ref self, ref features, ref weights) + } +} diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo new file mode 100644 index 000000000..21c967976 --- /dev/null +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo @@ -0,0 +1,12 @@ +use 
orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core; +use orion::operators::ml::FP64x64TreeRegressor; +use orion::numbers::{FP64x64, FP64x64Impl}; + +impl FP64x64XGBoostRegressor of XGBoostRegressorTrait { + fn predict( + ref self: Span>, ref features: Span, ref weights: Span + ) -> FP64x64 { + core::predict(ref self, ref features, ref weights) + } +} diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo new file mode 100644 index 000000000..a011233bd --- /dev/null +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo @@ -0,0 +1,12 @@ +use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core; +use orion::operators::ml::FP8x23TreeRegressor; +use orion::numbers::FP8x23; + +impl FP8x23XGBoostRegressor of XGBoostRegressorTrait { + fn predict( + ref self: Span>, ref features: Span, ref weights: Span + ) -> FP8x23 { + core::predict(ref self, ref features, ref weights) + } +} From 55402b2eac3b4230623e83a26aab165d22e79627 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 16:37:46 +0300 Subject: [PATCH 06/78] generate doc --- docgen/src/main.rs | 8 ++++ docs/SUMMARY.md | 2 + .../xgboost-regressor/README.md | 22 ++++++++++ .../xgboost-regressor/xgboost.predict.md | 44 +++++++++++++++++++ 4 files changed, 76 insertions(+) create mode 100644 docs/framework/operators/machine-learning/xgboost-regressor/README.md create mode 100644 docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md diff --git a/docgen/src/main.rs b/docgen/src/main.rs index e628cf980..15b962082 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -42,6 +42,14 @@ fn main() { let trait_name: &str = "TreeRegressorTrait"; doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + + // XGBOOST REGRESSOR DOC + let trait_path = "src/operators/ml/xgboost_regressor/core.cairo"; + let doc_path = "docs/framework/operators/machine-learning/xgboost-regressor"; + let label = "xgboost"; + let trait_name: &str = "XGBoostRegressorTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); } fn doc_trait(trait_path: &str, doc_path: &str, label: &str) { diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 8cdf6ddcb..473871fe3 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -99,6 +99,8 @@ * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) * [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) + * [XGBoost Regressor](framework/operators/machine-learning/xgboost-regressor/README.md) + * [xgboost.predict](framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md) ## 🏛 Hub diff --git a/docs/framework/operators/machine-learning/xgboost-regressor/README.md b/docs/framework/operators/machine-learning/xgboost-regressor/README.md new file mode 100644 index 000000000..1187916e8 --- /dev/null +++ b/docs/framework/operators/machine-learning/xgboost-regressor/README.md @@ -0,0 +1,22 @@ +# Tree Regressor + +`XGBoostRegressorTrait` provides a trait definition for xgboost regression. This trait offers functionalities to predict target values based on input features. 
+ +```rust +use orion::operators::ml::XGBoostRegressorTrait; +``` + +### Data types + +Orion supports currently only fixed point data types for `XGBoostRegressorTrait`. + +| Data type | dtype | +| -------------------- | ------------------------------------------------------------- | +| Fixed point (signed) | `TreeRegressorTrait` | + +*** + +| function | description | +| --- | --- | +| [`xgboost.predict`](xgboost.predict.md) | Predicts the target value for a set of features using the provided ensemble of decision trees. | + diff --git a/docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md b/docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md new file mode 100644 index 000000000..ed7d7a31d --- /dev/null +++ b/docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md @@ -0,0 +1,44 @@ +# XGBoostRegressorTrait::predict + +```rust + fn predict(ref self: Span>, ref features: Span, ref weights: Span) -> T; +``` + +Predicts the target value for a set of features using the provided ensemble of decision trees +and combining their results using given weights. + +## Args + +* `self`: A reference to a span representing a ensemble of decision trees. +* `features`: A reference to a span representing the features for which the prediction is to be made. +* `weights`: A reference to a span representing the weights applied to the predictions from each tree. + +## Returns + +The predicted target value. + +## Type Constraints + +Constrain input and output types to fixed point. + +## Examples + +```rust +use orion::operators::ml::{FP16x16XGBoostRegressor, TreeRegressorTrait, TreeRegressor}; +use orion::numbers::{FP16x16, FixedTrait}; + +fn xgboost_regressor_example(trees: Span>) { + + let mut features = array![ + FixedTrait::new_unscaled(1, false), + FixedTrait::new_unscaled(2, false), + ].span(); + + let mut weights = array![ + FixedTrait::new_unscaled(0.5, false), + FixedTrait::new_unscaled(0.5, false) + ].span(); + + FP16x16XGBoostRegressor::predict(ref trees, ref features, ref weights); +} +``` From 5dfd5c573448d513f9e8585d2ed15ac4234a06ee Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 17:02:08 +0300 Subject: [PATCH 07/78] rename trees --- .../implementations/xgboost_regressor_fp16x16.cairo | 4 ++-- .../implementations/xgboost_regressor_fp32x32.cairo | 4 ++-- .../implementations/xgboost_regressor_fp64x64.cairo | 4 ++-- .../implementations/xgboost_regressor_fp8x23.cairo | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo index 41661711b..e8202a8d1 100644 --- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo @@ -1,11 +1,11 @@ -use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait}; use orion::operators::ml::xgboost_regressor::core; use orion::operators::ml::FP16x16TreeRegressor; use orion::numbers::FP16x16; impl FP16x16XGBoostRegressor of XGBoostRegressorTrait { fn predict( - ref self: Span>, ref features: Span, ref weights: Span + ref self: Span>, ref features: Span, ref weights: Span ) -> FP16x16 { core::predict(ref self, ref features, ref weights) } diff --git 
a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo index 83eca88ca..6d266fce4 100644 --- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo @@ -1,11 +1,11 @@ -use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait}; use orion::operators::ml::xgboost_regressor::core; use orion::operators::ml::FP32x32TreeRegressor; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32XGBoostRegressor of XGBoostRegressorTrait { fn predict( - ref self: Span>, ref features: Span, ref weights: Span + ref self: Span>, ref features: Span, ref weights: Span ) -> FP32x32 { core::predict(ref self, ref features, ref weights) } diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo index 21c967976..ff21c9860 100644 --- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo @@ -1,11 +1,11 @@ -use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait}; use orion::operators::ml::xgboost_regressor::core; use orion::operators::ml::FP64x64TreeRegressor; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64XGBoostRegressor of XGBoostRegressorTrait { fn predict( - ref self: Span>, ref features: Span, ref weights: Span + ref self: Span>, ref features: Span, ref weights: Span ) -> FP64x64 { core::predict(ref self, ref features, ref weights) } diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo index a011233bd..ac2f1d3b5 100644 --- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo @@ -1,11 +1,11 @@ -use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait}; use orion::operators::ml::xgboost_regressor::core; use orion::operators::ml::FP8x23TreeRegressor; use orion::numbers::FP8x23; impl FP8x23XGBoostRegressor of XGBoostRegressorTrait { fn predict( - ref self: Span>, ref features: Span, ref weights: Span + ref self: Span>, ref features: Span, ref weights: Span ) -> FP8x23 { core::predict(ref self, ref features, ref weights) } From cb32695371cd5f6f7880907502f3c8da29adff1f Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 17:15:15 +0300 Subject: [PATCH 08/78] remove fit from TreeRegressor --- .../machine-learning/tree-regressor/README.md | 1 - .../tree-regressor/tree.fit.md | 50 --- .../tree-regressor/tree.predict.md | 26 +- src/operators/ml/tree_regressor/core.cairo | 347 +----------------- .../tree_regressor_fp16x16.cairo | 6 - .../tree_regressor_fp32x32.cairo | 6 - .../tree_regressor_fp64x64.cairo | 6 - .../tree_regressor_fp8x23.cairo | 6 - tests/src/ml/tree_regressor.cairo | 72 +--- 9 files changed, 9 insertions(+), 511 deletions(-) 
delete mode 100644 docs/framework/operators/machine-learning/tree-regressor/tree.fit.md diff --git a/docs/framework/operators/machine-learning/tree-regressor/README.md b/docs/framework/operators/machine-learning/tree-regressor/README.md index 7df2112c4..286587884 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/README.md +++ b/docs/framework/operators/machine-learning/tree-regressor/README.md @@ -18,6 +18,5 @@ Orion supports currently only fixed point data types for `TreeRegressorTrait`. | function | description | | --- | --- | -| [`tree.fit`](tree.fit.md) | Constructs a decision tree regressor based on the provided data and target values. | | [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. | diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md b/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md deleted file mode 100644 index 0ba61814d..000000000 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md +++ /dev/null @@ -1,50 +0,0 @@ -# TreeRegressorTrait::fit - -```rust - fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeRegressor; -``` - -Builds a decision tree based on the provided data and target values up to a specified maximum depth. - -## Args - -* `data`: A span of spans representing rows of features in the dataset. -* `target`: A span representing the target values corresponding to each row in the dataset. -* `max_depth`: The maximum depth of the decision tree. The tree stops growing once this depth is reached. -* `random_state`: It ensures that the tie-breaking is consistent across multiple runs, leading to reproducible results. - -## Returns - -A `TreeRegressor` representing the root of the constructed decision tree. - -## Type Constraints - -Constrain input and output types to fixed point tensors. - -## Examples - -```rust -use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; -use orion::numbers::{FP16x16, FixedTrait}; - -fn tree_regressor_example() { - - let data = array![ - array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - ] - .span(); - - let target = array![ - FixedTrait::new_unscaled(2, false), - FixedTrait::new_unscaled(4, false), - FixedTrait::new_unscaled(6, false), - FixedTrait::new_unscaled(8, false) - ] - .span(); - - TreeRegressorTrait::fit(data, target, 3, 42); -} -``` diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md index c76714d58..6281af625 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md +++ b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md @@ -22,32 +22,14 @@ Constrain input and output types to fixed point. 
## Examples

 ```rust
-use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait};
+use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait, TreeRegressor};
 use orion::numbers::{FP16x16, FixedTrait};
 
-fn tree_regressor_example() {
+fn tree_regressor_example(tree: TreeRegressor<FP16x16>) {
 
-    let data = array![
-        array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(),
-        array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(),
-        array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(),
-        array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(),
-    ]
-        .span();
-
-    let target = array![
-        FixedTrait::new_unscaled(2, false),
-        FixedTrait::new_unscaled(4, false),
-        FixedTrait::new_unscaled(6, false),
-        FixedTrait::new_unscaled(8, false)
-    ]
-        .span();
-
-    let mut tree = TreeRegressorTrait::fit(data, target, 3);
-
-    let prediction_1 = tree
-        .predict(
+    tree.predict(
         array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span()
         );
+
 }
 ```
diff --git a/src/operators/ml/tree_regressor/core.cairo b/src/operators/ml/tree_regressor/core.cairo
index cb88de5b4..1206d9094 100644
--- a/src/operators/ml/tree_regressor/core.cairo
+++ b/src/operators/ml/tree_regressor/core.cairo
@@ -13,63 +13,8 @@ struct TreeRegressor<T> {
 
 /// Trait
 ///
-/// fit - Constructs a decision tree regressor based on the provided data and target values.
 /// predict - Given a set of features, predicts the target value using the constructed decision tree.
 trait TreeRegressorTrait<T> {
-    /// # TreeRegressorTrait::fit
-    ///
-    /// ```rust
-    ///    fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize, random_state: usize) -> TreeRegressor<T>;
-    /// ```
-    ///
-    /// Builds a decision tree based on the provided data and target values up to a specified maximum depth.
-    ///
-    /// ## Args
-    ///
-    /// * `data`: A span of spans representing rows of features in the dataset.
-    /// * `target`: A span representing the target values corresponding to each row in the dataset.
-    /// * `max_depth`: The maximum depth of the decision tree. The tree stops growing once this depth is reached.
-    /// * `random_state`: It ensures that the tie-breaking is consistent across multiple runs, leading to reproducible results.
-    ///
-    /// ## Returns
-    ///
-    /// A `TreeRegressor` representing the root of the constructed decision tree.
-    ///
-    /// ## Type Constraints
-    ///
-    /// Constrain input and output types to fixed point tensors.
- /// - /// ## Examples - /// - /// ```rust - /// use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; - /// use orion::numbers::{FP16x16, FixedTrait}; - /// - /// fn tree_regressor_example() { - /// - /// let data = array![ - /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - /// array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - /// array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - /// array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - /// ] - /// .span(); - /// - /// let target = array![ - /// FixedTrait::new_unscaled(2, false), - /// FixedTrait::new_unscaled(4, false), - /// FixedTrait::new_unscaled(6, false), - /// FixedTrait::new_unscaled(8, false) - /// ] - /// .span(); - /// - /// TreeRegressorTrait::fit(data, target, 3, 42); - /// } - /// ``` - /// - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor; /// # TreeRegressorTrait::predict /// /// ```rust @@ -94,33 +39,15 @@ trait TreeRegressorTrait { /// ## Examples /// /// ```rust - /// use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; + /// use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait, TreeRegressor}; /// use orion::numbers::{FP16x16, FixedTrait}; /// - /// fn tree_regressor_example() { + /// fn tree_regressor_example(tree: TreeRegressor) { /// - /// let data = array![ - /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - /// array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - /// array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - /// array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - /// ] - /// .span(); - /// - /// let target = array![ - /// FixedTrait::new_unscaled(2, false), - /// FixedTrait::new_unscaled(4, false), - /// FixedTrait::new_unscaled(6, false), - /// FixedTrait::new_unscaled(8, false) - /// ] - /// .span(); - /// - /// let mut tree = TreeRegressorTrait::fit(data, target, 3); - /// - /// let prediction_1 = tree - /// .predict( + /// tree.predict( /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() /// ); + /// /// } /// ``` /// @@ -163,269 +90,3 @@ fn predict< current_node.prediction } - -fn mse< - T, - MAG, - impl FFixedTrait: FixedTrait, - impl TSub: Sub, - impl TAddEq: AddEq, - impl TDiv: Div, - impl U32IntoMAG: Into, - impl FeltTryIntoMAG: TryInto, - impl TCopy: Copy, - impl TDrop: Drop, ->( - y: Span, prediction: T -) -> T { - let mut sum_squared_error: T = FixedTrait::ZERO(); - - let mut y_copy = y; - loop { - match y_copy.pop_front() { - Option::Some(yi) => { - let error = *yi - prediction; - sum_squared_error += error - .pow(FixedTrait::new_unscaled(2.try_into().unwrap(), false)); - }, - Option::None(_) => { - break; - } - }; - }; - - sum_squared_error / FixedTrait::new_unscaled(y.len().into(), false) -} - -fn best_split< - T, - MAG, - impl FFixedTrait: FixedTrait, - impl TPartialOrd: PartialOrd, - impl TPartialEq: PartialEq, - impl TAddEq: AddEq, - impl TAdd: Add, - impl TSub: Sub, - impl TDiv: Div, - impl TMul: Mul, - impl U32IntoMAG: Into, - impl FeltTryIntoMAG: TryInto, - impl TCopy: Copy, - impl TDrop: Drop, ->( - data: Span>, target: Span, random_state: usize -) -> (usize, T, T) { - let mut best_mse = FixedTrait::MAX(); - let mut best_split_feature = 0; - 
let mut best_splits: Array<(usize, T, T)> = ArrayTrait::new(); - - let n_features: u32 = (*data[0]).len(); - - let mut feature = 0; - loop { - if feature == n_features { - break; - }; - - let mut unique_values = ArrayTrait::new(); - let mut data_copy = data; - loop { - match data_copy.pop_front() { - Option::Some(row) => { - unique_values.append(*row[feature]) - }, - Option::None(_) => { - break; - } - }; - }; - - let mut unique_values = unique_values.span(); - loop { - match unique_values.pop_front() { - Option::Some(value) => { - let mut left_target = ArrayTrait::new(); - let mut right_target = ArrayTrait::new(); - - let mut i = 0; - let mut target_copy = target; - loop { - match target_copy.pop_front() { - Option::Some(t) => { - if *(*data.at(i))[feature] < *value { - left_target.append(*t); - } else { - right_target.append(*t); - } - i += 1; - }, - Option::None(_) => { - break; - } - }; - }; - - if !left_target.is_empty() && !right_target.is_empty() { - let mut left_sum = FixedTrait::ZERO(); - let mut left_target_copy = left_target.span(); - loop { - match left_target_copy.pop_front() { - Option::Some(val) => { - left_sum += *val; - }, - Option::None(_) => { - break; - } - }; - }; - let left_target_as_fp: T = FixedTrait::new_unscaled( - left_target.len().into(), false - ); - let left_pred = left_sum / left_target_as_fp; - - let mut right_sum = FixedTrait::ZERO(); - let mut right_target_copy = right_target.span(); - loop { - match right_target_copy.pop_front() { - Option::Some(val) => { - right_sum += *val; - }, - Option::None(_) => { - break; - } - }; - }; - let right_target_as_fp: T = FixedTrait::new_unscaled( - right_target.len().into(), false - ); - let right_pred = right_sum / right_target_as_fp; - - let current_mse = (left_target_as_fp * mse(left_target.span(), left_pred)) - + (right_target_as_fp * mse(right_target.span(), right_pred)); - - if !(current_mse > best_mse) { - if current_mse < best_mse { - best_mse = current_mse; - best_splits = array![]; - } - - let mut total_sum = FixedTrait::ZERO(); - let mut target_copy = target; - loop { - match target_copy.pop_front() { - Option::Some(t) => { - total_sum += *t; - }, - Option::None(_) => { - break; - } - }; - }; - - let prediction = total_sum - / FixedTrait::new_unscaled(target.len().into(), false); - - best_splits.append((feature, *value, prediction)); - } - } - }, - Option::None(_) => { - break; - } - }; - }; - - feature += 1; - }; - - let random_idx: usize = u64_between(random_state.into(), 0, best_splits.len().into()) - .try_into() - .unwrap(); - let (best_split_feature, best_split_value, best_prediction) = *best_splits.at(random_idx); - - (best_split_feature, best_split_value, best_prediction) -} - -fn fit< - T, - MAG, - impl FFixedTrait: FixedTrait, - impl TPartialOrd: PartialOrd, - impl TPartialEq: PartialEq, - impl TAddEq: AddEq, - impl TAdd: Add, - impl TSub: Sub, - impl TDiv: Div, - impl TMul: Mul, - impl U32IntoMAG: Into, - impl FeltTryIntoMAG: TryInto, - impl TCopy: Copy, - impl TDrop: Drop, ->( - data: Span>, target: Span, depth: usize, max_depth: usize, random_state: usize -) -> TreeRegressor { - if depth == max_depth || data.len() < 2 { - let mut total = FixedTrait::ZERO(); - let mut target_copy = target; - loop { - match target_copy.pop_front() { - Option::Some(val) => { - total += *val; - }, - Option::None(_) => { - break; - } - }; - }; - return TreeRegressor { - left: Option::None(()), - right: Option::None(()), - split_feature: 0, - split_value: FixedTrait::ZERO(), - prediction: total / 
FixedTrait::new_unscaled(target.len().into(), false), - }; - } - - let (split_feature, split_value, prediction) = best_split(data, target, random_state); - let mut left_data = ArrayTrait::new(); - let mut left_target = ArrayTrait::new(); - - let mut right_data = ArrayTrait::new(); - let mut right_target = ArrayTrait::new(); - - let mut data_copy = data; - let mut i: usize = 0; - loop { - match data_copy.pop_front() { - Option::Some(row) => { - if *(*row).at(split_feature) < split_value { - left_data.append(row.clone()); - left_target.append(*target[i]) - } else { - right_data.append(row.clone()); - right_target.append(*target[i]) - } - i += 1 - }, - Option::None(_) => { - break; - } - }; - }; - - TreeRegressor { - left: Option::Some( - BoxTrait::new( - fit(left_data.span(), left_target.span(), depth + 1, max_depth, random_state) - ) - ), - right: Option::Some( - BoxTrait::new( - fit(right_data.span(), right_target.span(), depth + 1, max_depth, random_state) - ) - ), - split_feature, - split_value, - prediction, - } -} diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo index 7aeb6eb69..10fa1aa53 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::FP16x16; impl FP16x16TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP16x16 { core::predict(ref self, features) } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo index 288d7e15d..72c5033c2 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP32x32 { core::predict(ref self, features) } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo index 9102428fc..4450f630c 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP64x64 { core::predict(ref self, features) } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo 
index 54c195704..f6b1361be 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::FP8x23; impl FP8x23TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP8x23 { core::predict(ref self, features) } diff --git a/tests/src/ml/tree_regressor.cairo b/tests/src/ml/tree_regressor.cairo index 057fa58b0..98a9a3265 100644 --- a/tests/src/ml/tree_regressor.cairo +++ b/tests/src/ml/tree_regressor.cairo @@ -1,71 +1 @@ -use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; -use orion::operators::ml::tree_regressor::core::mse; -use orion::numbers::{FP16x16, FixedTrait}; - -#[test] -#[available_gas(2000000000000)] -fn test_mse() { - let mut y = array![ - FixedTrait::new_unscaled(2, false), - FixedTrait::new_unscaled(4, false), - FixedTrait::new_unscaled(6, false), - FixedTrait::new_unscaled(8, false) - ] - .span(); - - let prediction = FixedTrait::::new_unscaled(5, false); - let expected_mse = FixedTrait::::new_unscaled( - 5, false - ); // MSE = [(2-5)^2 + (4-5)^2 + (6-5)^2 + (8-5)^2] / 4 = 5 - - let computed_mse = mse(y, prediction); - assert(computed_mse == expected_mse, 'Failed mse'); -} - - -#[test] -#[available_gas(2000000000000)] -fn test_tree() { - let data = array![ - array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - ] - .span(); - - let target = array![ - FixedTrait::new_unscaled(2, false), - FixedTrait::new_unscaled(4, false), - FixedTrait::new_unscaled(6, false), - FixedTrait::new_unscaled(8, false) - ] - .span(); - - let mut tree = TreeRegressorTrait::fit(data, target, 3, 42); - - let prediction_1 = tree - .predict( - array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() - ); - - let prediction_2 = tree - .predict( - array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span() - ); - - let prediction_3 = tree - .predict( - array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span() - ); - - let prediction_4 = tree - .predict( - array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span() - ); - - assert(prediction_1 == FixedTrait::::new_unscaled(2, false), 'should predict 2'); - assert(prediction_2 == FixedTrait::::new_unscaled(4, false), 'should predict 4'); - assert(prediction_3 == FixedTrait::::new_unscaled(6, false), 'should predict 6'); - assert(prediction_4 == FixedTrait::::new_unscaled(8, false), 'should predict 8'); -} +// TODO: make test once Tree transpilation implemented \ No newline at end of file From 7ab41bcf630ca0695c6d96f35cec2b8e8aeb3d8c Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 18:29:03 +0300 Subject: [PATCH 09/78] remove fit from doc --- docs/SUMMARY.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 66647a7a3..07de3f249 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -97,7 +97,6 @@ * 
[nn.linear](framework/operators/neural-network/nn.linear.md) * [Machine Learning](framework/operators/machine-learning/README.md) * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) - * [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) * [Tree Classifier](framework/operators/machine-learning/tree-classifier/README.md) * [tree.predict](framework/operators/machine-learning/tree-classifier/tree.predict.md) From 5c98b64f45fabef22a167f562f9004b6f9032752 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 10:38:09 +0300 Subject: [PATCH 10/78] implement fp16x16wide --- src/numbers/fixed_point/implementations.cairo | 1 + .../implementations/fp16x16wide.cairo | 3 + .../implementations/fp16x16wide/core.cairo | 390 ++++++ .../implementations/fp16x16wide/helpers.cairo | 41 + .../implementations/fp16x16wide/math.cairo | 5 + .../fp16x16wide/math/comp.cairo | 76 + .../fp16x16wide/math/core.cairo | 659 +++++++++ .../fp16x16wide/math/hyp.cairo | 159 +++ .../fp16x16wide/math/lut.cairo | 1235 +++++++++++++++++ .../fp16x16wide/math/trig.cairo | 450 ++++++ 10 files changed, 3019 insertions(+) create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo diff --git a/src/numbers/fixed_point/implementations.cairo b/src/numbers/fixed_point/implementations.cairo index 8b010e349..e6152e25a 100644 --- a/src/numbers/fixed_point/implementations.cairo +++ b/src/numbers/fixed_point/implementations.cairo @@ -2,3 +2,4 @@ mod fp8x23; mod fp16x16; mod fp64x64; mod fp32x32; +mod fp16x16wide; \ No newline at end of file diff --git a/src/numbers/fixed_point/implementations/fp16x16wide.cairo b/src/numbers/fixed_point/implementations/fp16x16wide.cairo new file mode 100644 index 000000000..e9acee340 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide.cairo @@ -0,0 +1,3 @@ +mod core; +mod math; +mod helpers; diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo new file mode 100644 index 000000000..01a1d8b8d --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo @@ -0,0 +1,390 @@ +use debug::PrintTrait; + +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{TryInto, Into}; + +use orion::numbers::signed_integer::{i32::i32, i8::i8}; +use orion::numbers::fixed_point::core::FixedTrait; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::{core, trig, hyp}; +use orion::numbers::fixed_point::utils; + +/// A struct representing a fixed point number. 
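+/// Values are sign-magnitude: `mag` is the absolute value scaled by 2^16 (`ONE`
+/// below), and `sign` marks negatives. A minimal sketch of the encoding, assuming
+/// only the constants defined below (1.5 * 65536 = 98304):
+/// ```rust
+/// let x = FP16x16W { mag: 98304, sign: false }; // represents 1.5
+/// let y = FP16x16W { mag: 98304, sign: true }; // represents -1.5
+/// ```
+/// The "wide" variant keeps `mag` in a `u64`, which presumably gives
+/// intermediate products more headroom than a 32-bit magnitude would.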
+#[derive(Serde, Copy, Drop)] +struct FP16x16W { + mag: u64, + sign: bool +} + +// CONSTANTS + +const TWO: u64 = 131072; // 2 ** 17 +const ONE: u64 = 65536; // 2 ** 16 +const HALF: u64 = 32768; // 2 ** 15 +const MAX: u64 = 2147483648; // 2 ** 31 + + +impl FP16x16WImpl of FixedTrait { + fn ZERO() -> FP16x16W { + return FP16x16W { mag: 0, sign: false }; + } + + fn ONE() -> FP16x16W { + return FP16x16W { mag: ONE, sign: false }; + } + + fn MAX() -> FP16x16W { + return FP16x16W { mag: MAX, sign: false }; + } + + fn new(mag: u64, sign: bool) -> FP16x16W { + return FP16x16W { mag: mag, sign: sign }; + } + + fn new_unscaled(mag: u64, sign: bool) -> FP16x16W { + return FP16x16W { mag: mag * ONE, sign: sign }; + } + + fn from_felt(val: felt252) -> FP16x16W { + let mag = integer::u64_try_from_felt252(utils::felt_abs(val)).unwrap(); + return FixedTrait::new(mag, utils::felt_sign(val)); + } + + fn abs(self: FP16x16W) -> FP16x16W { + return core::abs(self); + } + + fn acos(self: FP16x16W) -> FP16x16W { + return trig::acos_fast(self); + } + + fn acos_fast(self: FP16x16W) -> FP16x16W { + return trig::acos_fast(self); + } + + fn acosh(self: FP16x16W) -> FP16x16W { + return hyp::acosh(self); + } + + fn asin(self: FP16x16W) -> FP16x16W { + return trig::asin_fast(self); + } + + fn asin_fast(self: FP16x16W) -> FP16x16W { + return trig::asin_fast(self); + } + + fn asinh(self: FP16x16W) -> FP16x16W { + return hyp::asinh(self); + } + + fn atan(self: FP16x16W) -> FP16x16W { + return trig::atan_fast(self); + } + + fn atan_fast(self: FP16x16W) -> FP16x16W { + return trig::atan_fast(self); + } + + fn atanh(self: FP16x16W) -> FP16x16W { + return hyp::atanh(self); + } + + fn ceil(self: FP16x16W) -> FP16x16W { + return core::ceil(self); + } + + fn cos(self: FP16x16W) -> FP16x16W { + return trig::cos_fast(self); + } + + fn cos_fast(self: FP16x16W) -> FP16x16W { + return trig::cos_fast(self); + } + + fn cosh(self: FP16x16W) -> FP16x16W { + return hyp::cosh(self); + } + + fn floor(self: FP16x16W) -> FP16x16W { + return core::floor(self); + } + + // Calculates the natural exponent of x: e^x + fn exp(self: FP16x16W) -> FP16x16W { + return core::exp(self); + } + + // Calculates the binary exponent of x: 2^x + fn exp2(self: FP16x16W) -> FP16x16W { + return core::exp2(self); + } + + // Calculates the natural logarithm of x: ln(x) + // self must be greater than zero + fn ln(self: FP16x16W) -> FP16x16W { + return core::ln(self); + } + + // Calculates the binary logarithm of x: log2(x) + // self must be greather than zero + fn log2(self: FP16x16W) -> FP16x16W { + return core::log2(self); + } + + // Calculates the base 10 log of x: log10(x) + // self must be greater than zero + fn log10(self: FP16x16W) -> FP16x16W { + return core::log10(self); + } + + // Calclates the value of x^y and checks for overflow before returning + // self is a fixed point value + // b is a fixed point value + fn pow(self: FP16x16W, b: FP16x16W) -> FP16x16W { + return core::pow(self, b); + } + + fn round(self: FP16x16W) -> FP16x16W { + return core::round(self); + } + + fn sin(self: FP16x16W) -> FP16x16W { + return trig::sin_fast(self); + } + + fn sin_fast(self: FP16x16W) -> FP16x16W { + return trig::sin_fast(self); + } + + fn sinh(self: FP16x16W) -> FP16x16W { + return hyp::sinh(self); + } + + // Calculates the square root of a fixed point value + // x must be positive + fn sqrt(self: FP16x16W) -> FP16x16W { + return core::sqrt(self); + } + + fn tan(self: FP16x16W) -> FP16x16W { + return trig::tan_fast(self); + } + + fn tan_fast(self: FP16x16W) -> FP16x16W 
{ + return trig::tan_fast(self); + } + + fn tanh(self: FP16x16W) -> FP16x16W { + return hyp::tanh(self); + } + + fn sign(self: FP16x16W) -> FP16x16W { + return core::sign(self); + } +} + + +impl FP16x16WPrint of PrintTrait { + fn print(self: FP16x16W) { + self.sign.print(); + self.mag.print(); + } +} + +// Into a raw felt without unscaling +impl FP16x16WIntoFelt252 of Into { + fn into(self: FP16x16W) -> felt252 { + let mag_felt = self.mag.into(); + + if self.sign { + return mag_felt * -1; + } else { + return mag_felt * 1; + } + } +} + +impl FP16x16WIntoI32 of Into { + fn into(self: FP16x16W) -> i32 { + _i32_into_fp(self) + } +} + +impl FP16x16WTryIntoI8 of TryInto { + fn try_into(self: FP16x16W) -> Option { + _i8_try_from_fp(self) + } +} + + +impl FP16x16WTryIntoU128 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP16x16WTryIntoU64 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP16x16WTryIntoU32 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some(self.mag / ONE); + } + } +} + +impl FP16x16WTryIntoU16 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP16x16WTryIntoU8 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP16x16WPartialEq of PartialEq { + #[inline(always)] + fn eq(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { + return core::eq(lhs, rhs); + } + + #[inline(always)] + fn ne(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { + return core::ne(lhs, rhs); + } +} + +impl FP16x16WAdd of Add { + fn add(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::add(lhs, rhs); + } +} + +impl FP16x16WAddEq of AddEq { + #[inline(always)] + fn add_eq(ref self: FP16x16W, other: FP16x16W) { + self = Add::add(self, other); + } +} + +impl FP16x16WSub of Sub { + fn sub(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::sub(lhs, rhs); + } +} + +impl FP16x16WSubEq of SubEq { + #[inline(always)] + fn sub_eq(ref self: FP16x16W, other: FP16x16W) { + self = Sub::sub(self, other); + } +} + +impl FP16x16WMul of Mul { + fn mul(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::mul(lhs, rhs); + } +} + +impl FP16x16WMulEq of MulEq { + #[inline(always)] + fn mul_eq(ref self: FP16x16W, other: FP16x16W) { + self = Mul::mul(self, other); + } +} + +impl FP16x16WDiv of Div { + fn div(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::div(lhs, rhs); + } +} + +impl FP16x16WDivEq of DivEq { + #[inline(always)] + fn div_eq(ref self: FP16x16W, other: FP16x16W) { + self = Div::div(self, other); + } +} + +impl FP16x16WPartialOrd of PartialOrd { + #[inline(always)] + fn ge(lhs: FP16x16W, rhs: FP16x16W) -> bool { + return core::ge(lhs, rhs); + } + + #[inline(always)] + fn gt(lhs: FP16x16W, rhs: FP16x16W) -> bool { + return core::gt(lhs, rhs); + } + + #[inline(always)] + fn le(lhs: FP16x16W, rhs: FP16x16W) -> bool { + return core::le(lhs, rhs); + } + + 
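+    // NOTE: lt (like le/gt/ge above) delegates to math::core, which compares
+    // signs first and, when signs match, compares magnitudes with the result
+    // flipped for negative operands (the `^ sign` trick).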
#[inline(always)]
+    fn lt(lhs: FP16x16W, rhs: FP16x16W) -> bool {
+        return core::lt(lhs, rhs);
+    }
+}
+
+impl FP16x16WNeg of Neg<FP16x16W> {
+    #[inline(always)]
+    fn neg(a: FP16x16W) -> FP16x16W {
+        return core::neg(a);
+    }
+}
+
+impl FP16x16WRem of Rem<FP16x16W> {
+    #[inline(always)]
+    fn rem(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W {
+        return core::rem(lhs, rhs);
+    }
+}
+
+
+/// INTERNAL
+
+fn _i32_into_fp(x: FP16x16W) -> i32 {
+    i32 { mag: (x.mag / ONE).try_into().unwrap(), sign: x.sign }
+}
+
+fn _i8_try_from_fp(x: FP16x16W) -> Option<i8> {
+    let unscaled_mag: Option<u8> = (x.mag / ONE).try_into();
+
+    match unscaled_mag {
+        Option::Some(val) => Option::Some(i8 { mag: unscaled_mag.unwrap(), sign: x.sign }),
+        Option::None(_) => Option::None(())
+    }
+}
diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo
new file mode 100644
index 000000000..c2a65e156
--- /dev/null
+++ b/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo
@@ -0,0 +1,41 @@
+use debug::PrintTrait;
+use traits::Into;
+
+use orion::numbers::fixed_point::implementations::fp16x16wide::core::{
+    HALF, ONE, TWO, FP16x16W, FP16x16WImpl, FP16x16WSub, FP16x16WDiv, FixedTrait, FP16x16WPrint
+};
+
+const DEFAULT_PRECISION: u64 = 7; // 1e-4
+
+// To use `DEFAULT_PRECISION`, final arg is: `Option::None(())`.
+// To use `custom_precision` of 430_u64: `Option::Some(430_u64)`.
+fn assert_precise(result: FP16x16W, expected: felt252, msg: felt252, custom_precision: Option<u64>) {
+    let precision = match custom_precision {
+        Option::Some(val) => val,
+        Option::None(_) => DEFAULT_PRECISION,
+    };
+
+    let diff = (result - FixedTrait::from_felt(expected)).mag;
+
+    if (diff > precision) {
+        result.print();
+        assert(diff <= precision, msg);
+    }
+}
+
+fn assert_relative(
+    result: FP16x16W, expected: felt252, msg: felt252, custom_precision: Option<u64>
+) {
+    let precision = match custom_precision {
+        Option::Some(val) => val,
+        Option::None(_) => DEFAULT_PRECISION,
+    };
+
+    let diff = result - FixedTrait::from_felt(expected);
+    let rel_diff = (diff / result).mag;
+
+    if (rel_diff > precision) {
+        result.print();
+        assert(rel_diff <= precision, msg);
+    }
+}
diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo
new file mode 100644
index 000000000..970c65f30
--- /dev/null
+++ b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo
@@ -0,0 +1,5 @@
+mod core;
+mod comp;
+mod lut;
+mod trig;
+mod hyp;
diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo
new file mode 100644
index 000000000..63a3e4855
--- /dev/null
+++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo
@@ -0,0 +1,76 @@
+use orion::numbers::fixed_point::implementations::fp16x16wide::core::{
+    FP16x16W, FixedTrait, FP16x16WImpl, FP16x16WPartialOrd, FP16x16WPartialEq
+};
+
+fn max(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    if (a >= b) {
+        return a;
+    } else {
+        return b;
+    }
+}
+
+fn min(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    if (a <= b) {
+        return a;
+    } else {
+        return b;
+    }
+}
+
+fn xor(a: FP16x16W, b: FP16x16W) -> bool {
+    if (a == FixedTrait::new(0, false) || b == FixedTrait::new(0, false)) && (a != b) {
+        return true;
+    } else {
+        return false;
+    }
+}
+
+fn or(a: FP16x16W, b: FP16x16W) -> bool {
+    let zero = FixedTrait::new(0, false);
+    if a == zero && b == zero {
+        return false;
+    } else {
+        return true;
+    }
+}
+
+// Tests --------------------------------------------------------------------------------------------------------------
+
+#[test]
+fn test_max() {
+    let a = FixedTrait::new_unscaled(1, false);
+    let b = FixedTrait::new_unscaled(0, false);
+    let c = FixedTrait::new_unscaled(1, true);
+
+    assert(max(a, a) == a, 'max(a, a)');
+    assert(max(a, b) == a, 'max(a, b)');
+    assert(max(a, c) == a, 'max(a, c)');
+
+    assert(max(b, a) == a, 'max(b, a)');
+    assert(max(b, b) == b, 'max(b, b)');
+    assert(max(b, c) == b, 'max(b, c)');
+
+    assert(max(c, a) == a, 'max(c, a)');
+    assert(max(c, b) == b, 'max(c, b)');
+    assert(max(c, c) == c, 'max(c, c)');
+}
+
+#[test]
+fn test_min() {
+    let a = FixedTrait::new_unscaled(1, false);
+    let b = FixedTrait::new_unscaled(0, false);
+    let c = FixedTrait::new_unscaled(1, true);
+
+    assert(min(a, a) == a, 'min(a, a)');
+    assert(min(a, b) == b, 'min(a, b)');
+    assert(min(a, c) == c, 'min(a, c)');
+
+    assert(min(b, a) == b, 'min(b, a)');
+    assert(min(b, b) == b, 'min(b, b)');
+    assert(min(b, c) == c, 'min(b, c)');
+
+    assert(min(c, a) == c, 'min(c, a)');
+    assert(min(c, b) == c, 'min(c, b)');
+    assert(min(c, c) == c, 'min(c, c)');
+}
diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo
new file mode 100644
index 000000000..33c1c6d85
--- /dev/null
+++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo
@@ -0,0 +1,659 @@
+use core::debug::PrintTrait;
+use option::OptionTrait;
+use result::{ResultTrait, ResultTraitImpl};
+use traits::{Into, TryInto};
+use integer::{u64_safe_divmod, u64_as_non_zero, u64_wide_mul};
+
+use orion::numbers::fixed_point::implementations::fp16x16wide::core::{
+    HALF, ONE, MAX, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WAddEq, FP16x16WSub, FP16x16WMul,
+    FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, FP16x16WNeg,
+    FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait
+};
+use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut;
+
+// PUBLIC
+
+fn abs(a: FP16x16W) -> FP16x16W {
+    return FixedTrait::new(a.mag, false);
+}
+
+fn add(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    if a.sign == b.sign {
+        return FixedTrait::new(a.mag + b.mag, a.sign);
+    }
+
+    if a.mag == b.mag {
+        return FixedTrait::ZERO();
+    }
+
+    if (a.mag > b.mag) {
+        return FixedTrait::new(a.mag - b.mag, a.sign);
+    } else {
+        return FixedTrait::new(b.mag - a.mag, b.sign);
+    }
+}
+
+fn ceil(a: FP16x16W) -> FP16x16W {
+    let (div, rem) = u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+
+    if rem == 0 {
+        return a;
+    } else if !a.sign {
+        return FixedTrait::new_unscaled(div + 1, false);
+    } else if div == 0 {
+        return FixedTrait::new_unscaled(0, false);
+    } else {
+        return FixedTrait::new_unscaled(div, true);
+    }
+}
+
+fn div(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    let a_u64 = integer::u64_wide_mul(a.mag, ONE);
+    let res_u64 = a_u64 / b.mag.into();
+
+    // Re-apply sign
+    return FixedTrait::new(res_u64.try_into().unwrap(), a.sign ^ b.sign);
+}
+
+fn eq(a: @FP16x16W, b: @FP16x16W) -> bool {
+    return (*a.mag == *b.mag) && (*a.sign == *b.sign);
+}
+
+// Calculates the natural exponent of x: e^x
+fn exp(a: FP16x16W) -> FP16x16W {
+    return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^16 ≈ 94548
+}
+
+// Calculates the binary exponent of x: 2^x
+fn exp2(a: FP16x16W) -> FP16x16W {
+    if (a.mag == 0) {
+        return FixedTrait::ONE();
+    }
+
+    let (int_part, frac_part) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+    let int_res = FixedTrait::new_unscaled(lut::exp2(int_part), false);
+    let mut res_u = int_res;
+
+    if frac_part != 0 {
+        let frac = FixedTrait::new(frac_part, false);
+        let r7 = FixedTrait::new(1, false) * frac;
+        let r6 = (r7 + FixedTrait::new(10, false)) * frac;
+        let r5 = (r6 + FixedTrait::new(87, false)) * frac;
+        let r4 = (r5 + FixedTrait::new(630, false)) * frac;
+        let r3 = (r4 + FixedTrait::new(3638, false)) * frac;
+        let r2 = (r3 + FixedTrait::new(15743, false)) * frac;
+        let r1 = (r2 + FixedTrait::new(45426, false)) * frac;
+        res_u = res_u * (r1 + FixedTrait::ONE());
+    }
+
+    if (a.sign == true) {
+        return FixedTrait::ONE() / res_u;
+    } else {
+        return res_u;
+    }
+}
+
+fn exp2_int(exp: u64) -> FP16x16W {
+    return FixedTrait::new_unscaled(lut::exp2(exp), false);
+}
+
+fn floor(a: FP16x16W) -> FP16x16W {
+    let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+
+    if rem == 0 {
+        return a;
+    } else if !a.sign {
+        return FixedTrait::new_unscaled(div, false);
+    } else {
+        return FixedTrait::new_unscaled(div + 1, true);
+    }
+}
+
+fn ge(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return !a.sign;
+    } else {
+        return (a.mag == b.mag) || ((a.mag > b.mag) ^ a.sign);
+    }
+}
+
+fn gt(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return !a.sign;
+    } else {
+        return (a.mag != b.mag) && ((a.mag > b.mag) ^ a.sign);
+    }
+}
+
+fn le(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return a.sign;
+    } else {
+        return (a.mag == b.mag) || ((a.mag < b.mag) ^ a.sign);
+    }
+}
+
+// Calculates the natural logarithm of x: ln(x)
+// self must be greater than zero
+fn ln(a: FP16x16W) -> FP16x16W {
+    return FixedTrait::new(45426, false) * log2(a); // ln(2) = 0.693...
+}
+
+// Calculates the binary logarithm of x: log2(x)
+// self must be greater than zero
+fn log2(a: FP16x16W) -> FP16x16W {
+    assert(a.sign == false, 'must be positive');
+
+    if (a.mag == ONE) {
+        return FixedTrait::ZERO();
+    } else if (a.mag < ONE) {
+        // Compute true inverse binary log if 0 < x < 1
+        let div = FixedTrait::ONE() / a;
+        return -log2(div);
+    }
+
+    let whole = a.mag / ONE;
+    let (msb, div) = lut::msb(whole);
+
+    if a.mag == div * ONE {
+        return FixedTrait::new_unscaled(msb, false);
+    } else {
+        let norm = a / FixedTrait::new_unscaled(div, false);
+        let r8 = FixedTrait::new(596, true) * norm;
+        let r7 = (r8 + FixedTrait::new(8116, false)) * norm;
+        let r6 = (r7 + FixedTrait::new(49044, true)) * norm;
+        let r5 = (r6 + FixedTrait::new(172935, false)) * norm;
+        let r4 = (r5 + FixedTrait::new(394096, true)) * norm;
+        let r3 = (r4 + FixedTrait::new(608566, false)) * norm;
+        let r2 = (r3 + FixedTrait::new(655828, true)) * norm;
+        let r1 = (r2 + FixedTrait::new(534433, false)) * norm;
+        return r1 + FixedTrait::new(224487, true) + FixedTrait::new_unscaled(msb, false);
+    }
+}
+
+// Calculates the base 10 log of x: log10(x)
+// self must be greater than zero
+fn log10(a: FP16x16W) -> FP16x16W {
+    return FixedTrait::new(19728, false) * log2(a); // log10(2) = 0.301...
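+    // i.e. log10(x) = log10(2) * log2(x), with 19728 = round(0.30103 * 2^16)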
+}
+
+fn lt(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return a.sign;
+    } else {
+        return (a.mag != b.mag) && ((a.mag < b.mag) ^ a.sign);
+    }
+}
+
+fn mul(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    let prod_u128 = integer::u64_wide_mul(a.mag, b.mag);
+
+    // Re-apply sign
+    return FixedTrait::new((prod_u128 / ONE.into()).try_into().unwrap(), a.sign ^ b.sign);
+}
+
+fn ne(a: @FP16x16W, b: @FP16x16W) -> bool {
+    return (*a.mag != *b.mag) || (*a.sign != *b.sign);
+}
+
+fn neg(a: FP16x16W) -> FP16x16W {
+    if a.mag == 0 {
+        return a;
+    } else if !a.sign {
+        return FixedTrait::new(a.mag, !a.sign);
+    } else {
+        return FixedTrait::new(a.mag, false);
+    }
+}
+
+// Calculates the value of x^y and checks for overflow before returning
+// self is a FP16x16W fixed point value
+// b is a FP16x16W fixed point value
+fn pow(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    let (div, rem) = integer::u64_safe_divmod(b.mag, u64_as_non_zero(ONE));
+
+    // use the more performant integer pow when y is an int
+    if (rem == 0) {
+        return pow_int(a, b.mag / ONE, b.sign);
+    }
+
+    // x^y = exp(y*ln(x)) for x > 0 will error for x < 0
+    return exp(b * ln(a));
+}
+
+// Calculates the value of a^b and checks for overflow before returning
+fn pow_int(a: FP16x16W, b: u64, sign: bool) -> FP16x16W {
+    let mut x = a;
+    let mut n = b;
+
+    if sign == true {
+        x = FixedTrait::ONE() / x;
+    }
+
+    if n == 0 {
+        return FixedTrait::ONE();
+    }
+
+    let mut y = FixedTrait::ONE();
+    let two = integer::u64_as_non_zero(2);
+
+    loop {
+        if n <= 1 {
+            break;
+        }
+
+        let (div, rem) = integer::u64_safe_divmod(n, two);
+
+        if rem == 1 {
+            y = x * y;
+        }
+
+        x = x * x;
+        n = div;
+    };
+
+    return x * y;
+}
+
+fn rem(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    return a - floor(a / b) * b;
+}
+
+fn round(a: FP16x16W) -> FP16x16W {
+    let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+
+    if (HALF <= rem) {
+        return FixedTrait::new_unscaled(div + 1, a.sign);
+    } else {
+        return FixedTrait::new_unscaled(div, a.sign);
+    }
+}
+
+// Calculates the square root of a FP16x16W fixed point value
+// x must be positive
+fn sqrt(a: FP16x16W) -> FP16x16W {
+    assert(a.sign == false, 'must be positive');
+
+    let root = integer::u64_sqrt(a.mag.into() * ONE.into());
+    return FixedTrait::new(root.into(), false);
+}
+
+fn sub(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    return add(a, -b);
+}
+
+fn sign(a: FP16x16W) -> FP16x16W {
+    if a.mag == 0 {
+        FixedTrait::new(0, false)
+    } else {
+        FixedTrait::new(ONE, a.sign)
+    }
+}
+
+// Tests --------------------------------------------------------------------------------------------------------------
+
+use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::{
+    assert_precise, assert_relative
+};
+use orion::numbers::fixed_point::implementations::fp16x16wide::math::trig::{PI, HALF_PI};
+
+#[test]
+fn test_into() {
+    let a = FixedTrait::<FP16x16W>::new_unscaled(5, false);
+    assert(a.mag == 5 * ONE, 'invalid result');
+}
+
+#[test]
+fn test_try_into_u128() {
+    // Positive unscaled
+    let a = FixedTrait::<FP16x16W>::new_unscaled(5, false);
+    assert(a.try_into().unwrap() == 5_u128, 'invalid result');
+
+    // Positive scaled
+    let b = FixedTrait::<FP16x16W>::new(5 * ONE, false);
+    assert(b.try_into().unwrap() == 5_u128, 'invalid result');
+
+    // Zero
+    let d = FixedTrait::<FP16x16W>::new_unscaled(0, false);
+    assert(d.try_into().unwrap() == 0_u128, 'invalid result');
+}
+
+#[test]
+#[should_panic]
+fn test_negative_try_into_u128() {
+    let a = FixedTrait::<FP16x16W>::new_unscaled(1, true);
+    let a: u128 = a.try_into().unwrap();
+}
+
+#[test]
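+// NOTE: the trig assertions below exercise the *_fast LUT-backed variants,
+// since FP16x16WImpl maps acos/asin/atan (and sin/cos/tan) to them.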
+#[available_gas(1000000)] +fn test_acos() { + let a = FixedTrait::::ONE(); + assert(a.acos().into() == 0, 'invalid one'); +} + +#[test] +#[available_gas(1000000)] +fn test_asin() { + let a = FixedTrait::ONE(); + assert_precise(a.asin(), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 +} + +#[test] +#[available_gas(2000000)] +fn test_atan() { + let a = FixedTrait::new(2 * ONE, false); + assert_relative(a.atan(), 72558, 'invalid two', Option::None(())); +} + +#[test] +fn test_ceil() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(ceil(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_floor() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(floor(a).mag == 2 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_round() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(round(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +#[should_panic] +fn test_sqrt_fail() { + let a = FixedTrait::new_unscaled(25, true); + sqrt(a); +} + +#[test] +fn test_sqrt() { + let mut a = FixedTrait::new_unscaled(0, false); + assert(sqrt(a).mag == 0, 'invalid zero root'); + a = FixedTrait::new_unscaled(25, false); + assert(sqrt(a).mag == 5 * ONE, 'invalid pos root'); +} + + +#[test] +#[available_gas(100000)] +fn test_msb() { + let a = FixedTrait::::new_unscaled(100, false); + let (msb, div) = lut::msb(a.mag / ONE); + assert(msb == 6, 'invalid msb'); + assert(div == 64, 'invalid msb ceil'); +} + +#[test] +#[available_gas(600000)] +fn test_pow() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new_unscaled(4, false); + assert(pow(a, b).mag == 81 * ONE, 'invalid pos base power'); +} + +#[test] +#[available_gas(900000)] +fn test_pow_frac() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new(32768, false); // 0.5 + assert_relative( + pow(a, b), 113512, 'invalid pos base power', Option::None(()) + ); // 1.7320508075688772 +} + +#[test] +#[available_gas(1000000)] +fn test_exp() { + let a = FixedTrait::new_unscaled(2, false); + assert_relative(exp(a), 484249, 'invalid exp of 2', Option::None(())); // 7.389056098793725 +} + +#[test] +#[available_gas(400000)] +fn test_exp2() { + let a = FixedTrait::new_unscaled(5, false); + assert(exp2(a).mag == 2097152, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(20000)] +fn test_exp2_int() { + assert(exp2_int(5).into() == 2097152, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(1000000)] +fn test_ln() { + let mut a = FixedTrait::new_unscaled(1, false); + assert(ln(a).mag == 0, 'invalid ln of 1'); + + a = FixedTrait::new(178145, false); + assert_relative(ln(a), ONE.into(), 'invalid ln of 2.7...', Option::None(())); +} + +#[test] +#[available_gas(1000000)] +fn test_log2() { + let mut a = FixedTrait::new_unscaled(32, false); + assert(log2(a) == FixedTrait::new_unscaled(5, false), 'invalid log2 32'); + + a = FixedTrait::new_unscaled(10, false); + assert_relative(log2(a), 217706, 'invalid log2 10', Option::None(())); // 3.321928094887362 +} + +#[test] +#[available_gas(1000000)] +fn test_log10() { + let a = FixedTrait::new_unscaled(100, false); + assert_relative(log10(a), 2 * ONE.into(), 'invalid log10', Option::None(())); +} + +#[test] +fn test_eq() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = eq(@a, @b); + assert(c == true, 'invalid result'); +} + +#[test] +fn test_ne() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = ne(@a, @b); + assert(c == false, 
'invalid result'); +} + +#[test] +fn test_add() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + assert(add(a, b) == FixedTrait::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_add_eq() { + let mut a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + a += b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_sub() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + let c = a - b; + assert(c == FixedTrait::::new_unscaled(3, false), 'false result invalid'); +} + +#[test] +fn test_sub_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + a -= b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +#[available_gas(100000)] +fn test_mul_pos() { + let a = FP16x16W { mag: 190054, sign: false }; + let b = FP16x16W { mag: 190054, sign: false }; + let c = a * b; + assert(c.mag == 551155, 'invalid result'); +} + +#[test] +fn test_mul_neg() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + let c = a * b; + assert(c == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_mul_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + a *= b; + assert(a == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_div() { + let a = FixedTrait::new_unscaled(10, false); + let b = FixedTrait::::new(190054, false); // 2.9 + let c = a / b; + assert(c.mag == 225986, 'invalid pos decimal'); // 3.4482758620689653 +} + +#[test] +fn test_le() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a <= a, 'a <= a'); + assert(a <= b == false, 'a <= b'); + assert(a <= c == false, 'a <= c'); + + assert(b <= a, 'b <= a'); + assert(b <= b, 'b <= b'); + assert(b <= c == false, 'b <= c'); + + assert(c <= a, 'c <= a'); + assert(c <= b, 'c <= b'); + assert(c <= c, 'c <= c'); +} + +#[test] +fn test_lt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a < a == false, 'a < a'); + assert(a < b == false, 'a < b'); + assert(a < c == false, 'a < c'); + + assert(b < a, 'b < a'); + assert(b < b == false, 'b < b'); + assert(b < c == false, 'b < c'); + + assert(c < a, 'c < a'); + assert(c < b, 'c < b'); + assert(c < c == false, 'c < c'); +} + +#[test] +fn test_ge() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a >= a, 'a >= a'); + assert(a >= b, 'a >= b'); + assert(a >= c, 'a >= c'); + + assert(b >= a == false, 'b >= a'); + assert(b >= b, 'b >= b'); + assert(b >= c, 'b >= c'); + + assert(c >= a == false, 'c >= a'); + assert(c >= b == false, 'c >= b'); + assert(c >= c, 'c >= c'); +} + +#[test] +fn test_gt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a > a == false, 'a > a'); + assert(a > b, 'a > b'); + assert(a > c, 'a > c'); + + assert(b > a == false, 'b > a'); + assert(b > b == false, 'b > b'); + assert(b > c, 'b > c'); + + assert(c > a == false, 'c > a'); + assert(c > b == false, 'c > b'); + assert(c > c == false, 'c > c'); 
+} + +#[test] +#[available_gas(1000000)] +fn test_cos() { + let a = FixedTrait::::new(HALF_PI, false); + assert(a.cos().into() == 0, 'invalid half pi'); +} + +#[test] +#[available_gas(1000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(a.sin(), ONE.into(), 'invalid half pi', Option::None(())); +} + +#[test] +#[available_gas(2000000)] +fn test_tan() { + let a = FixedTrait::::new(HALF_PI / 2, false); + assert(a.tan().mag == 65536, 'invalid quarter pi'); +} + +#[test] +#[available_gas(2000000)] +fn test_sign() { + let a = FixedTrait::::new(0, false); + assert(a.sign().mag == 0 && !a.sign().sign, 'invalid sign (0, true)'); + + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (HALF, true)'); + + let a = FixedTrait::::new(HALF, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (HALF, false)'); + + let a = FixedTrait::::new(ONE, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (ONE, true)'); + + let a = FixedTrait::::new(ONE, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (ONE, false)'); +} + +#[test] +#[should_panic] +#[available_gas(2000000)] +fn test_sign_fail() { + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag != ONE && !a.sign().sign, 'invalid sign (HALF, true)'); +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo new file mode 100644 index 000000000..3286b6345 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo @@ -0,0 +1,159 @@ +use core::debug::PrintTrait; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WAddEq, FP16x16WSub, FP16x16WMul, + FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, FP16x16WNeg, + FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait +}; + +// Calculates hyperbolic cosine of a (fixed point) +fn cosh(a: FP16x16W) -> FP16x16W { + let ea = a.exp(); + return (ea + (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic sine of a (fixed point) +fn sinh(a: FP16x16W) -> FP16x16W { + let ea = a.exp(); + return (ea - (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic tangent of a (fixed point) +fn tanh(a: FP16x16W) -> FP16x16W { + let ea = a.exp(); + let ea_i = FixedTrait::ONE() / ea; + return (ea - ea_i) / (ea + ea_i); +} + +// Calculates inverse hyperbolic cosine of a (fixed point) +fn acosh(a: FP16x16W) -> FP16x16W { + let root = (a * a - FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic sine of a (fixed point) +fn asinh(a: FP16x16W) -> FP16x16W { + let root = (a * a + FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic tangent of a (fixed point) +fn atanh(a: FP16x16W) -> FP16x16W { + let one = FixedTrait::ONE(); + let ln_arg = (one + a) / (one - a); + return ln_arg.ln() / FixedTrait::new(TWO, false); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use option::OptionTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::assert_precise; + +#[test] +#[available_gas(10000000)] +fn test_cosh() { + let a = FixedTrait::new(TWO, false); + assert_precise(cosh(a), 246550, 'invalid two', 
Option::None(())); // 3.5954653836066 + + let a = FixedTrait::ONE(); + assert_precise(cosh(a), 101127, 'invalid one', Option::None(())); // 1.42428174592510 + + let a = FixedTrait::ZERO(); + assert_precise(cosh(a), ONE.into(), 'invalid zero', Option::None(())); + + let a = FixedTrait::ONE(); + assert_precise(cosh(a), 101127, 'invalid neg one', Option::None(())); // 1.42428174592510 + + let a = FixedTrait::new(TWO, true); + assert_precise(cosh(a), 246568, 'invalid neg two', Option::None(())); // 3.5954653836066 +} + +#[test] +#[available_gas(10000000)] +fn test_sinh() { + let a = FixedTrait::new(TWO, false); + assert_precise(sinh(a), 237681, 'invalid two', Option::None(())); // 3.48973469357602 + + let a = FixedTrait::ONE(); + assert_precise(sinh(a), 77018, 'invalid one', Option::None(())); // 1.13687593250230 + + let a = FixedTrait::ZERO(); + assert(sinh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE, true); + assert_precise(sinh(a), -77018, 'invalid neg one', Option::None(())); // -1.13687593250230 + + let a = FixedTrait::new(TWO, true); + assert_precise(sinh(a), -237699, 'invalid neg two', Option::None(())); // -3.48973469357602 +} + +#[test] +#[available_gas(10000000)] +fn test_tanh() { + let a = FixedTrait::new(TWO, false); + assert_precise(tanh(a), 63179, 'invalid two', Option::None(())); // 0.75314654693321 + + let a = FixedTrait::ONE(); + assert_precise(tanh(a), 49912, 'invalid one', Option::None(())); // 0.59499543433175 + + let a = FixedTrait::ZERO(); + assert(tanh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE, true); + assert_precise(tanh(a), -49912, 'invalid neg one', Option::None(())); // -0.59499543433175 + + let a = FixedTrait::new(TWO, true); + assert_precise(tanh(a), -63179, 'invalid neg two', Option::None(())); // 0.75314654693321 +} + +#[test] +#[available_gas(10000000)] +fn test_acosh() { + let a = FixedTrait::new(246559, false); // 3.5954653836066 + assert_precise(acosh(a), 131072, 'invalid two', Option::None(())); + + let a = FixedTrait::new(101127, false); // 1.42428174592510 + assert_precise(acosh(a), ONE.into(), 'invalid one', Option::None(())); + + let a = FixedTrait::ONE(); // 1 + assert(acosh(a).into() == 0, 'invalid zero'); +} + +#[test] +#[available_gas(10000000)] +fn test_asinh() { + let a = FixedTrait::new(237690, false); // 3.48973469357602 + assert_precise(asinh(a), 131072, 'invalid two', Option::None(())); + + let a = FixedTrait::new(77018, false); // 1.13687593250230 + assert_precise(asinh(a), ONE.into(), 'invalid one', Option::None(())); + + let a = FixedTrait::ZERO(); + assert(asinh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(77018, true); // -1.13687593250230 + assert_precise(asinh(a), -ONE.into(), 'invalid neg one', Option::None(())); + + let a = FixedTrait::new(237690, true); // -3.48973469357602 + assert_precise(asinh(a), -131017, 'invalid neg two', Option::None(())); +} + +#[test] +#[available_gas(10000000)] +fn test_atanh() { + let a = FixedTrait::new(58982, false); // 0.9 + assert_precise(atanh(a), 96483, 'invalid 0.9', Option::None(())); // 1.36892147623689 + + let a = FixedTrait::new(HALF, false); // 0.5 + assert_precise(atanh(a), 35999, 'invalid half', Option::None(())); // 0.42914542526098 + + let a = FixedTrait::ZERO(); + assert(atanh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(HALF, true); // 0.5 + assert_precise(atanh(a), -35999, 'invalid neg half', Option::None(())); // 0.42914542526098 + + let a = FixedTrait::new(58982, true); // 0.9 + assert_precise(atanh(a), -96483, 
'invalid -0.9', Option::None(())); // 1.36892147623689 +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo new file mode 100644 index 000000000..e96b0d389 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo @@ -0,0 +1,1235 @@ +// Calculates the most significant bit +fn msb(whole: u64) -> (u64, u64) { + if whole < 256 { + if whole < 2 { + return (0, 1); + } + if whole < 4 { + return (1, 2); + } + if whole < 8 { + return (2, 4); + } + if whole < 16 { + return (3, 8); + } + if whole < 32 { + return (4, 16); + } + if whole < 64 { + return (5, 32); + } + if whole < 128 { + return (6, 64); + } + if whole < 256 { + return (7, 128); + } + } else if whole < 65536 { + if whole < 512 { + return (8, 256); + } + if whole < 1024 { + return (9, 512); + } + if whole < 2048 { + return (10, 1024); + } + if whole < 4096 { + return (11, 2048); + } + if whole < 8192 { + return (12, 4096); + } + if whole < 16384 { + return (13, 8192); + } + if whole < 32768 { + return (14, 16384); + } + if whole < 65536 { + return (15, 32768); + } + } + + return (16, 65536); +} + +fn exp2(exp: u64) -> u64 { + if exp <= 16 { + if exp == 0 { + return 1; + } + if exp == 1 { + return 2; + } + if exp == 2 { + return 4; + } + if exp == 3 { + return 8; + } + if exp == 4 { + return 16; + } + if exp == 5 { + return 32; + } + if exp == 6 { + return 64; + } + if exp == 7 { + return 128; + } + if exp == 8 { + return 256; + } + if exp == 9 { + return 512; + } + if exp == 10 { + return 1024; + } + if exp == 11 { + return 2048; + } + if exp == 12 { + return 4096; + } + if exp == 13 { + return 8192; + } + if exp == 14 { + return 16384; + } + if exp == 15 { + return 32768; + } + if exp == 16 { + return 65536; + } + } + + return 65536; +} + +fn sin(a: u64) -> (u64, u64, u64) { + let slot = a / 402; + + if slot < 128 { + if slot < 64 { + if slot < 32 { + if slot < 16 { + if slot == 0 { + return (0, 0, 402); + } + if slot == 1 { + return (402, 402, 804); + } + if slot == 2 { + return (804, 804, 1206); + } + if slot == 3 { + return (1206, 1206, 1608); + } + if slot == 4 { + return (1608, 1608, 2010); + } + if slot == 5 { + return (2011, 2010, 2412); + } + if slot == 6 { + return (2413, 2412, 2814); + } + if slot == 7 { + return (2815, 2814, 3216); + } + if slot == 8 { + return (3217, 3216, 3617); + } + if slot == 9 { + return (3619, 3617, 4019); + } + if slot == 10 { + return (4023, 4019, 4420); + } + if slot == 11 { + return (4423, 4420, 4821); + } + if slot == 12 { + return (4825, 4821, 5222); + } + if slot == 13 { + return (5228, 5222, 5623); + } + if slot == 14 { + return (5630, 5623, 6023); + } + if slot == 15 { + return (6032, 6023, 6424); + } + } else { + if slot == 16 { + return (6434, 6424, 6824); + } + if slot == 17 { + return (6836, 6824, 7224); + } + if slot == 18 { + return (7238, 7224, 7623); + } + if slot == 19 { + return (7640, 7623, 8022); + } + if slot == 20 { + return (8042, 8022, 8421); + } + if slot == 21 { + return (8445, 8421, 8820); + } + if slot == 22 { + return (8847, 8820, 9218); + } + if slot == 23 { + return (9249, 9218, 9616); + } + if slot == 24 { + return (9651, 9616, 10014); + } + if slot == 25 { + return (10053, 10014, 10411); + } + if slot == 26 { + return (10455, 10411, 10808); + } + if slot == 27 { + return (10857, 10808, 11204); + } + if slot == 28 { + return (11259, 11204, 11600); + } + if slot == 29 { + return (11662, 11600, 11996); + } + if slot == 30 { + return (12064, 
11996, 12391); + } + if slot == 31 { + return (12466, 12391, 12785); + } + } + } else { + if slot < 48 { + if slot == 32 { + return (12868, 12785, 13180); + } + if slot == 33 { + return (13270, 13180, 13573); + } + if slot == 34 { + return (13672, 13573, 13966); + } + if slot == 35 { + return (14074, 13966, 14359); + } + if slot == 36 { + return (14476, 14359, 14751); + } + if slot == 37 { + return (14879, 14751, 15143); + } + if slot == 38 { + return (15281, 15143, 15534); + } + if slot == 39 { + return (15683, 15534, 15924); + } + if slot == 40 { + return (16081, 15924, 16314); + } + if slot == 41 { + return (16487, 16314, 16703); + } + if slot == 42 { + return (16889, 16703, 17091); + } + if slot == 43 { + return (17291, 17091, 17479); + } + if slot == 44 { + return (17693, 17479, 17867); + } + if slot == 45 { + return (18096, 17867, 18253); + } + if slot == 46 { + return (18498, 18253, 18639); + } + if slot == 47 { + return (18900, 18639, 19024); + } + } else { + if slot == 48 { + return (19302, 19024, 19409); + } + if slot == 49 { + return (19704, 19409, 19792); + } + if slot == 50 { + return (20113, 19792, 20175); + } + if slot == 51 { + return (20508, 20175, 20557); + } + if slot == 52 { + return (20910, 20557, 20939); + } + if slot == 53 { + return (21313, 20939, 21320); + } + if slot == 54 { + return (21715, 21320, 21699); + } + if slot == 55 { + return (22117, 21699, 22078); + } + if slot == 56 { + return (22519, 22078, 22457); + } + if slot == 57 { + return (22921, 22457, 22834); + } + if slot == 58 { + return (23323, 22834, 23210); + } + if slot == 59 { + return (23725, 23210, 23586); + } + if slot == 60 { + return (24127, 23586, 23961); + } + if slot == 61 { + return (24530, 23961, 24335); + } + if slot == 62 { + return (24932, 24335, 24708); + } + if slot == 63 { + return (25334, 24708, 25080); + } + } + } + } else { + if slot < 96 { + if slot < 80 { + if slot == 64 { + return (25736, 25080, 25451); + } + if slot == 65 { + return (26138, 25451, 25821); + } + if slot == 66 { + return (26540, 25821, 26190); + } + if slot == 67 { + return (26942, 26190, 26558); + } + if slot == 68 { + return (27344, 26558, 26925); + } + if slot == 69 { + return (27747, 26925, 27291); + } + if slot == 70 { + return (28149, 27291, 27656); + } + if slot == 71 { + return (28551, 27656, 28020); + } + if slot == 72 { + return (28953, 28020, 28383); + } + if slot == 73 { + return (29355, 28383, 28745); + } + if slot == 74 { + return (29757, 28745, 29106); + } + if slot == 75 { + return (30159, 29106, 29466); + } + if slot == 76 { + return (30561, 29466, 29824); + } + if slot == 77 { + return (30964, 29824, 30182); + } + if slot == 78 { + return (31366, 30182, 30538); + } + if slot == 79 { + return (31768, 30538, 30893); + } + } else { + if slot == 80 { + return (32171, 30893, 31248); + } + if slot == 81 { + return (32572, 31248, 31600); + } + if slot == 82 { + return (32974, 31600, 31952); + } + if slot == 83 { + return (33376, 31952, 32303); + } + if slot == 84 { + return (33778, 32303, 32652); + } + if slot == 85 { + return (34181, 32652, 33000); + } + if slot == 86 { + return (34583, 33000, 33347); + } + if slot == 87 { + return (34985, 33347, 33692); + } + if slot == 88 { + return (35387, 33692, 34037); + } + if slot == 89 { + return (35789, 34037, 34380); + } + if slot == 90 { + return (36194, 34380, 34721); + } + if slot == 91 { + return (36593, 34721, 35062); + } + if slot == 92 { + return (36995, 35062, 35401); + } + if slot == 93 { + return (37398, 35401, 35738); + } + if slot == 94 { + return 
(37800, 35738, 36075); + } + if slot == 95 { + return (38202, 36075, 36410); + } + } + } else { + if slot < 112 { + if slot == 96 { + return (38604, 36410, 36744); + } + if slot == 97 { + return (39006, 36744, 37076); + } + if slot == 98 { + return (39408, 37076, 37407); + } + if slot == 99 { + return (39810, 37407, 37736); + } + if slot == 100 { + return (40227, 37736, 38064); + } + if slot == 101 { + return (40615, 38064, 38391); + } + if slot == 102 { + return (41017, 38391, 38716); + } + if slot == 103 { + return (41419, 38716, 39040); + } + if slot == 104 { + return (41821, 39040, 39362); + } + if slot == 105 { + return (42223, 39362, 39683); + } + if slot == 106 { + return (42625, 39683, 40002); + } + if slot == 107 { + return (43027, 40002, 40320); + } + if slot == 108 { + return (43429, 40320, 40636); + } + if slot == 109 { + return (43832, 40636, 40951); + } + if slot == 110 { + return (44234, 40951, 41264); + } + if slot == 111 { + return (44636, 41264, 41576); + } + } else { + if slot == 112 { + return (45038, 41576, 41886); + } + if slot == 113 { + return (45440, 41886, 42194); + } + if slot == 114 { + return (45842, 42194, 42501); + } + if slot == 115 { + return (46244, 42501, 42806); + } + if slot == 116 { + return (46646, 42806, 43110); + } + if slot == 117 { + return (47048, 43110, 43412); + } + if slot == 118 { + return (47451, 43412, 43713); + } + if slot == 119 { + return (47853, 43713, 44011); + } + if slot == 120 { + return (48252, 44011, 44308); + } + if slot == 121 { + return (48657, 44308, 44604); + } + if slot == 122 { + return (49059, 44604, 44898); + } + if slot == 123 { + return (49461, 44898, 45190); + } + if slot == 124 { + return (49863, 45190, 45480); + } + if slot == 125 { + return (50265, 45480, 45769); + } + if slot == 126 { + return (50668, 45769, 46056); + } + if slot == 127 { + return (51070, 46056, 46341); + } + } + } + } + } else { + if slot < 192 { + if slot < 160 { + if slot < 144 { + if slot == 128 { + return (51472, 46341, 46624); + } + if slot == 129 { + return (51874, 46624, 46906); + } + if slot == 130 { + return (52285, 46906, 47186); + } + if slot == 131 { + return (52678, 47186, 47464); + } + if slot == 132 { + return (53080, 47464, 47741); + } + if slot == 133 { + return (53482, 47741, 48015); + } + if slot == 134 { + return (53885, 48015, 48288); + } + if slot == 135 { + return (54287, 48288, 48559); + } + if slot == 136 { + return (54689, 48559, 48828); + } + if slot == 137 { + return (55091, 48828, 49095); + } + if slot == 138 { + return (55493, 49095, 49361); + } + if slot == 139 { + return (55895, 49361, 49624); + } + if slot == 140 { + return (56297, 49624, 49886); + } + if slot == 141 { + return (56699, 49886, 50146); + } + if slot == 142 { + return (57102, 50146, 50404); + } + if slot == 143 { + return (57504, 50404, 50660); + } + } else { + if slot == 144 { + return (57906, 50660, 50914); + } + if slot == 145 { + return (58308, 50914, 51166); + } + if slot == 146 { + return (58710, 51166, 51417); + } + if slot == 147 { + return (59112, 51417, 51665); + } + if slot == 148 { + return (59514, 51665, 51911); + } + if slot == 149 { + return (59916, 51911, 52156); + } + if slot == 150 { + return (60320, 52156, 52398); + } + if slot == 151 { + return (60721, 52398, 52639); + } + if slot == 152 { + return (61123, 52639, 52878); + } + if slot == 153 { + return (61525, 52878, 53114); + } + if slot == 154 { + return (61927, 53114, 53349); + } + if slot == 155 { + return (62329, 53349, 53581); + } + if slot == 156 { + return (62731, 53581, 
53812); + } + if slot == 157 { + return (63133, 53812, 54040); + } + if slot == 158 { + return (63536, 54040, 54267); + } + if slot == 159 { + return (63938, 54267, 54491); + } + if slot == 160 { + return (64343, 54491, 54714); + } + } + } else { + if slot < 176 { + if slot == 161 { + return (64742, 54714, 54934); + } + if slot == 162 { + return (65144, 54934, 55152); + } + if slot == 163 { + return (65546, 55152, 55368); + } + if slot == 164 { + return (65948, 55368, 55582); + } + if slot == 165 { + return (66350, 55582, 55794); + } + if slot == 166 { + return (66753, 55794, 56004); + } + if slot == 167 { + return (67155, 56004, 56212); + } + if slot == 168 { + return (67557, 56212, 56418); + } + if slot == 169 { + return (67959, 56418, 56621); + } + if slot == 170 { + return (68361, 56621, 56823); + } + if slot == 171 { + return (68763, 56823, 57022); + } + if slot == 172 { + return (69165, 57022, 57219); + } + if slot == 173 { + return (69567, 57219, 57414); + } + if slot == 174 { + return (69970, 57414, 57607); + } + if slot == 175 { + return (70372, 57607, 57798); + } + } else { + if slot == 176 { + return (70774, 57798, 57986); + } + if slot == 177 { + return (71176, 57986, 58172); + } + if slot == 178 { + return (71578, 58172, 58356); + } + if slot == 179 { + return (71980, 58356, 58538); + } + if slot == 180 { + return (72382, 58538, 58718); + } + if slot == 181 { + return (72784, 58718, 58896); + } + if slot == 182 { + return (73187, 58896, 59071); + } + if slot == 183 { + return (73589, 59071, 59244); + } + if slot == 184 { + return (73991, 59244, 59415); + } + if slot == 185 { + return (74393, 59415, 59583); + } + if slot == 186 { + return (74795, 59583, 59750); + } + if slot == 187 { + return (75197, 59750, 59914); + } + if slot == 188 { + return (75599, 59914, 60075); + } + if slot == 189 { + return (76001, 60075, 60235); + } + if slot == 190 { + return (76401, 60235, 60392); + } + if slot == 191 { + return (76806, 60392, 60547); + } + } + } + } else { + if slot < 224 { + if slot < 208 { + if slot == 192 { + return (77208, 60547, 60700); + } + if slot == 193 { + return (77610, 60700, 60851); + } + if slot == 194 { + return (78012, 60851, 60999); + } + if slot == 195 { + return (78414, 60999, 61145); + } + if slot == 196 { + return (78816, 61145, 61288); + } + if slot == 197 { + return (79218, 61288, 61429); + } + if slot == 198 { + return (79621, 61429, 61568); + } + if slot == 199 { + return (80023, 61568, 61705); + } + if slot == 200 { + return (80423, 61705, 61839); + } + if slot == 201 { + return (80827, 61839, 61971); + } + if slot == 202 { + return (81229, 61971, 62101); + } + if slot == 203 { + return (81631, 62101, 62228); + } + if slot == 204 { + return (82033, 62228, 62353); + } + if slot == 205 { + return (82435, 62353, 62476); + } + if slot == 206 { + return (82838, 62476, 62596); + } + if slot == 207 { + return (83240, 62596, 62714); + } + } else { + if slot == 208 { + return (83642, 62714, 62830); + } + if slot == 209 { + return (84044, 62830, 62943); + } + if slot == 210 { + return (84446, 62943, 63054); + } + if slot == 211 { + return (84848, 63054, 63162); + } + if slot == 212 { + return (85250, 63162, 63268); + } + if slot == 213 { + return (85652, 63268, 63372); + } + if slot == 214 { + return (86055, 63372, 63473); + } + if slot == 215 { + return (86457, 63473, 63572); + } + if slot == 216 { + return (86859, 63572, 63668); + } + if slot == 217 { + return (87261, 63668, 63763); + } + if slot == 218 { + return (87663, 63763, 63854); + } + if slot == 219 { + 
return (88065, 63854, 63944); + } + if slot == 220 { + return (88467, 63944, 64031); + } + if slot == 221 { + return (88869, 64031, 64115); + } + if slot == 222 { + return (89271, 64115, 64197); + } + if slot == 223 { + return (89674, 64197, 64277); + } + } + } else { + if slot < 240 { + if slot == 224 { + return (90076, 64277, 64354); + } + if slot == 225 { + return (90478, 64354, 64429); + } + if slot == 226 { + return (90880, 64429, 64501); + } + if slot == 227 { + return (91282, 64501, 64571); + } + if slot == 228 { + return (91684, 64571, 64639); + } + if slot == 229 { + return (92086, 64639, 64704); + } + if slot == 230 { + return (92491, 64704, 64766); + } + if slot == 231 { + return (92891, 64766, 64827); + } + if slot == 232 { + return (93293, 64827, 64884); + } + if slot == 233 { + return (93695, 64884, 64940); + } + if slot == 234 { + return (94097, 64940, 64993); + } + if slot == 235 { + return (94499, 64993, 65043); + } + if slot == 236 { + return (94901, 65043, 65091); + } + if slot == 237 { + return (95303, 65091, 65137); + } + if slot == 238 { + return (95705, 65137, 65180); + } + if slot == 239 { + return (96108, 65180, 65220); + } + } else { + if slot == 240 { + return (96514, 65220, 65259); + } + if slot == 241 { + return (96912, 65259, 65294); + } + if slot == 242 { + return (97314, 65294, 65328); + } + if slot == 243 { + return (97716, 65328, 65358); + } + if slot == 244 { + return (98118, 65358, 65387); + } + if slot == 245 { + return (98520, 65387, 65413); + } + if slot == 246 { + return (98922, 65413, 65436); + } + if slot == 247 { + return (99325, 65436, 65457); + } + if slot == 248 { + return (99727, 65457, 65476); + } + if slot == 249 { + return (100129, 65476, 65492); + } + if slot == 250 { + return (100531, 65492, 65505); + } + if slot == 251 { + return (100933, 65505, 65516); + } + if slot == 252 { + return (101335, 65516, 65525); + } + if slot == 253 { + return (101737, 65525, 65531); + } + if slot == 254 { + return (102139, 65531, 65535); + } + } + } + } + } + + return (102542, 65535, 65536); +} + +fn atan(a: u64) -> (u64, u64, u64) { + let slot = a / 459; + + if slot == 0 { + return (0, 0, 459); + } + if slot == 1 { + return (459, 459, 917); + } + if slot == 2 { + return (918, 917, 1376); + } + if slot == 3 { + return (1376, 1376, 1835); + } + if slot == 4 { + return (1835, 1835, 2293); + } + if slot == 5 { + return (2294, 2293, 2751); + } + if slot == 6 { + return (2753, 2751, 3209); + } + if slot == 7 { + return (3211, 3209, 3666); + } + if slot == 8 { + return (3670, 3666, 4123); + } + if slot == 9 { + return (4129, 4123, 4580); + } + if slot == 10 { + return (4591, 4580, 5036); + } + if slot == 11 { + return (5046, 5036, 5492); + } + if slot == 12 { + return (5505, 5492, 5947); + } + if slot == 13 { + return (5964, 5947, 6402); + } + if slot == 14 { + return (6423, 6402, 6856); + } + if slot == 15 { + return (6881, 6856, 7310); + } + if slot == 16 { + return (7340, 7310, 7762); + } + if slot == 17 { + return (7799, 7762, 8214); + } + if slot == 18 { + return (8258, 8214, 8665); + } + if slot == 19 { + return (8716, 8665, 9116); + } + if slot == 20 { + return (9181, 9116, 9565); + } + if slot == 21 { + return (9634, 9565, 10014); + } + if slot == 22 { + return (10093, 10014, 10462); + } + if slot == 23 { + return (10551, 10462, 10908); + } + if slot == 24 { + return (11010, 10908, 11354); + } + if slot == 25 { + return (11469, 11354, 11798); + } + if slot == 26 { + return (11928, 11798, 12242); + } + if slot == 27 { + return (12386, 12242, 12684); + } + 
if slot == 28 { + return (12845, 12684, 13125); + } + if slot == 29 { + return (13304, 13125, 13565); + } + if slot == 30 { + return (13762, 13565, 14004); + } + if slot == 31 { + return (14221, 14004, 14442); + } + if slot == 32 { + return (14680, 14442, 14878); + } + if slot == 33 { + return (15139, 14878, 15313); + } + if slot == 34 { + return (15598, 15313, 15746); + } + if slot == 35 { + return (16056, 15746, 16178); + } + if slot == 36 { + return (16515, 16178, 16609); + } + if slot == 37 { + return (16974, 16609, 17038); + } + if slot == 38 { + return (17433, 17038, 17466); + } + if slot == 39 { + return (17891, 17466, 17892); + } + if slot == 40 { + return (18353, 17892, 18317); + } + if slot == 41 { + return (18809, 18317, 18740); + } + if slot == 42 { + return (19268, 18740, 19161); + } + if slot == 43 { + return (19726, 19161, 19581); + } + if slot == 44 { + return (20185, 19581, 19999); + } + if slot == 45 { + return (20644, 19999, 20416); + } + if slot == 46 { + return (21103, 20416, 20830); + } + if slot == 47 { + return (21561, 20830, 21243); + } + if slot == 48 { + return (22020, 21243, 21655); + } + if slot == 49 { + return (22479, 21655, 22064); + } + if slot == 50 { + return (22944, 22064, 22472); + } + if slot == 51 { + return (23396, 22472, 22878); + } + if slot == 52 { + return (23855, 22878, 23282); + } + if slot == 53 { + return (24314, 23282, 23685); + } + if slot == 54 { + return (24773, 23685, 24085); + } + if slot == 55 { + return (25231, 24085, 24484); + } + if slot == 56 { + return (25690, 24484, 24880); + } + if slot == 57 { + return (26149, 24880, 25275); + } + if slot == 58 { + return (26608, 25275, 25668); + } + if slot == 59 { + return (27066, 25668, 26059); + } + if slot == 60 { + return (27534, 26059, 26448); + } + if slot == 61 { + return (27984, 26448, 26835); + } + if slot == 62 { + return (28443, 26835, 27220); + } + if slot == 63 { + return (28901, 27220, 27603); + } + if slot == 64 { + return (29360, 27603, 27984); + } + if slot == 65 { + return (29819, 27984, 28363); + } + if slot == 66 { + return (30278, 28363, 28740); + } + if slot == 67 { + return (30736, 28740, 29115); + } + if slot == 68 { + return (31195, 29115, 29488); + } + if slot == 69 { + return (31654, 29488, 29859); + } + if slot == 70 { + return (32113, 29859, 30228); + } + if slot == 71 { + return (32571, 30228, 30595); + } + if slot == 72 { + return (33030, 30595, 30960); + } + if slot == 73 { + return (33489, 30960, 31323); + } + if slot == 74 { + return (33948, 31323, 31683); + } + if slot == 75 { + return (34406, 31683, 32042); + } + if slot == 76 { + return (34865, 32042, 32398); + } + if slot == 77 { + return (35324, 32398, 32753); + } + if slot == 78 { + return (35783, 32753, 33105); + } + if slot == 79 { + return (36241, 33105, 33455); + } + if slot == 80 { + return (36700, 33455, 33804); + } + if slot == 81 { + return (37159, 33804, 34150); + } + if slot == 82 { + return (37618, 34150, 34494); + } + if slot == 83 { + return (38076, 34494, 34836); + } + if slot == 84 { + return (38535, 34836, 35175); + } + if slot == 85 { + return (38994, 35175, 35513); + } + if slot == 86 { + return (39453, 35513, 35849); + } + if slot == 87 { + return (39911, 35849, 36183); + } + if slot == 88 { + return (40370, 36183, 36514); + } + if slot == 89 { + return (40829, 36514, 36843); + } + if slot == 90 { + return (41288, 36843, 37171); + } + if slot == 91 { + return (41746, 37171, 37496); + } + if slot == 92 { + return (42205, 37496, 37819); + } + if slot == 93 { + return (42664, 37819, 
38141); + } + if slot == 94 { + return (43123, 38141, 38460); + } + if slot == 95 { + return (43581, 38460, 38777); + } + if slot == 96 { + return (44040, 38777, 39092); + } + if slot == 97 { + return (44499, 39092, 39405); + } + if slot == 98 { + return (44958, 39405, 39716); + } + + return (45416, 39716, 40025); +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo new file mode 100644 index 000000000..4c47eca5e --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo @@ -0,0 +1,450 @@ +use debug::PrintTrait; +use integer::{u64_safe_divmod, u64_as_non_zero}; +use option::OptionTrait; + +use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WSub, FP16x16WMul, FP16x16WDiv, + FP16x16WIntoFelt252, FixedTrait +}; + +// CONSTANTS + +const TWO_PI: u64 = 411775; +const PI: u64 = 205887; +const HALF_PI: u64 = 102944; + +// PUBLIC + +// Calculates arccos(a) for -1 <= a <= 1 (fixed point) +// arccos(a) = arcsin(sqrt(1 - a^2)) - arctan identity has discontinuity at zero +fn acos(a: FP16x16W) -> FP16x16W { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin(asin_arg); + + if (a.sign) { + return FixedTrait::new(PI, false) - asin_res; + } else { + return asin_res; + } +} + +fn acos_fast(a: FP16x16W) -> FP16x16W { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin_fast(asin_arg); + + if (a.sign) { + return FixedTrait::new(PI, false) - asin_res; + } else { + return asin_res; + } +} + +// Calculates arcsin(a) for -1 <= a <= 1 (fixed point) +// arcsin(a) = arctan(a / sqrt(1 - a^2)) +fn asin(a: FP16x16W) -> FP16x16W { + if (a.mag == ONE) { + return FixedTrait::new(HALF_PI, a.sign); + } + + let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + return atan(a / div); +} + +fn asin_fast(a: FP16x16W) -> FP16x16W { + if (a.mag == ONE) { + return FixedTrait::new(HALF_PI, a.sign); + } + + let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + return atan_fast(a / div); +} + +// Calculates arctan(a) (fixed point) +// See https://stackoverflow.com/a/50894477 for range adjustments +fn atan(a: FP16x16W) -> FP16x16W { + let mut at = a.abs(); + let mut shift = false; + let mut invert = false; + + // Invert value when a > 1 + if (at.mag > ONE) { + at = FixedTrait::ONE() / at; + invert = true; + } + + // Account for lack of precision in polynomaial when a > 0.7 + if (at.mag > 45875) { + let sqrt3_3 = FixedTrait::new(37837, false); // sqrt(3) / 3 + at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3); + shift = true; + } + + let r10 = FixedTrait::new(120, true) * at; + let r9 = (r10 + FixedTrait::new(3066, true)) * at; + let r8 = (r9 + FixedTrait::new(12727, false)) * at; + let r7 = (r8 + FixedTrait::new(17170, true)) * at; + let r6 = (r7 + FixedTrait::new(2865, false)) * at; + let r5 = (r6 + FixedTrait::new(12456, false)) * at; + let r4 = (r5 + FixedTrait::new(90, false)) * at; + let r3 = (r4 + FixedTrait::new(21852, true)) * at; + let r2 = r3 * at; + let mut res = (r2 + FixedTrait::new(65536, false)) * at; + + // Adjust for sign change, inversion, and shift + if (shift) { + res = res + FixedTrait::new(34315, false); // pi / 6 + } + + if (invert) { + res = res - FixedTrait::new(HALF_PI, false); + } + + return 
FixedTrait::new(res.mag, a.sign); +} + + +fn atan_fast(a: FP16x16W) -> FP16x16W { + let mut at = a.abs(); + let mut shift = false; + let mut invert = false; + + // Invert value when a > 1 + if (at.mag > ONE) { + at = FixedTrait::ONE() / at; + invert = true; + } + + // Account for lack of precision in polynomaial when a > 0.7 + if (at.mag > 45875) { + let sqrt3_3 = FixedTrait::new(37837, false); // sqrt(3) / 3 + at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3); + shift = true; + } + + let (start, low, high) = lut::atan(at.mag); + let partial_step = FixedTrait::new(at.mag - start, false) / FixedTrait::new(459, false); + let mut res = partial_step * FixedTrait::new(high - low, false) + FixedTrait::new(low, false); + + // Adjust for sign change, inversion, and shift + if (shift) { + res = res + FixedTrait::new(34315, false); // pi / 6 + } + + if (invert) { + res = res - FixedTrait::::new(HALF_PI, false); + } + + return FixedTrait::new(res.mag, a.sign); +} + +// Calculates cos(a) with a in radians (fixed point) +fn cos(a: FP16x16W) -> FP16x16W { + return sin(FixedTrait::new(HALF_PI, false) - a); +} + +fn cos_fast(a: FP16x16W) -> FP16x16W { + return sin_fast(FixedTrait::new(HALF_PI, false) - a); +} + +fn sin(a: FP16x16W) -> FP16x16W { + let a1 = a.mag % TWO_PI; + let (whole_rem, partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let a2 = FixedTrait::new(partial_rem, false); + let partial_sign = whole_rem == 1; + + let loop_res = a2 * _sin_loop(a2, 7, FixedTrait::ONE()); + return FixedTrait::new(loop_res.mag, a.sign ^ partial_sign && loop_res.mag != 0); +} + +fn sin_fast(a: FP16x16W) -> FP16x16W { + let a1 = a.mag % TWO_PI; + let (whole_rem, mut partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let partial_sign = whole_rem == 1; + + if partial_rem >= HALF_PI { + partial_rem = PI - partial_rem; + } + + let (start, low, high) = lut::sin(partial_rem); + let partial_step = FixedTrait::new(partial_rem - start, false) / FixedTrait::new(402, false); + let res = partial_step * (FixedTrait::new(high, false) - FixedTrait::new(low, false)) + + FixedTrait::::new(low, false); + + return FixedTrait::new(res.mag, a.sign ^ partial_sign && res.mag != 0); +} + +// Calculates tan(a) with a in radians (fixed point) +fn tan(a: FP16x16W) -> FP16x16W { + let sinx = sin(a); + let cosx = cos(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +fn tan_fast(a: FP16x16W) -> FP16x16W { + let sinx = sin_fast(a); + let cosx = cos_fast(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +// Helper function to calculate Taylor series for sin +fn _sin_loop(a: FP16x16W, i: u64, acc: FP16x16W) -> FP16x16W { + let div = (2 * i + 2) * (2 * i + 3); + let term = a * a * acc / FixedTrait::new_unscaled(div, false); + let new_acc = FixedTrait::ONE() - term; + + if (i == 0) { + return new_acc; + } + + return _sin_loop(a, i - 1, new_acc); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::{ + assert_precise, assert_relative +}; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{FP16x16WPartialEq, FP16x16WPrint}; + +#[test] +#[available_gas(8000000)] +fn test_acos() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos(a), 68629, 'invalid half', error); // 
1.04719755119660 (PI / 3)
+
+    let a = FixedTrait::ZERO();
+    assert_relative(acos(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(acos(a), 137258, 'invalid neg half', error); // 2.09439510239320 (2 * PI / 3)
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(acos(a), PI.into(), 'invalid neg one', Option::None(())); // PI
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_acos_fast() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::ONE();
+    assert(acos_fast(a).into() == 0, 'invalid one');
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(acos_fast(a), 68629, 'invalid half', error); // 1.04719755119660
+
+    let a = FixedTrait::ZERO();
+    assert_relative(acos_fast(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(acos_fast(a), 137258, 'invalid neg half', error); // 2.09439510239320
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(acos_fast(a), PI.into(), 'invalid neg one', Option::None(())); // PI
+}
+
+#[test]
+#[should_panic]
+#[available_gas(8000000)]
+fn test_acos_fail() {
+    let a = FixedTrait::new(2 * ONE, true);
+    acos(a);
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_atan_fast() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::new(2 * ONE, false);
+    assert_relative(atan_fast(a), 72558, 'invalid two', error); // 1.10714871779409
+
+    let a = FixedTrait::ONE();
+    assert_relative(atan_fast(a), 51472, 'invalid one', error); // 0.78539816339745 (PI / 4)
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(atan_fast(a), 30386, 'invalid half', error); // 0.46364760900081
+
+    let a = FixedTrait::ZERO();
+    assert(atan_fast(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(atan_fast(a), -30386, 'invalid neg half', error);
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(atan_fast(a), -51472, 'invalid neg one', error);
+
+    let a = FixedTrait::new(2 * ONE, true);
+    assert_relative(atan_fast(a), -72558, 'invalid neg two', error);
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_atan() {
+    let a = FixedTrait::new(2 * ONE, false);
+    assert_relative(atan(a), 72558, 'invalid two', Option::None(()));
+
+    let a = FixedTrait::ONE();
+    assert_relative(atan(a), 51472, 'invalid one', Option::None(()));
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(atan(a), 30386, 'invalid half', Option::None(()));
+
+    let a = FixedTrait::ZERO();
+    assert(atan(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(atan(a), -30386, 'invalid neg half', Option::None(()));
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(atan(a), -51472, 'invalid neg one', Option::None(()));
+
+    let a = FixedTrait::new(2 * ONE, true);
+    assert_relative(atan(a), -72558, 'invalid neg two', Option::None(()));
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_asin() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::ONE();
+    assert_relative(asin(a), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(asin(a), 34315, 'invalid half', error); // 0.52359877559830 (PI / 6)
+
+    let a = FixedTrait::ZERO();
+    assert_precise(asin(a), 0, 'invalid zero', Option::None(()));
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(asin(a), -34315, 'invalid neg half', error);
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(asin(a), -HALF_PI.into(), 'invalid neg one', Option::None(())); // -PI / 2
+}
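+// Editor's note — illustrative sketch, not part of the original patch: the
+// expected magnitudes in these tests are just real values scaled by 2^16
+// (ONE = 65536). For instance atan(1) = PI / 4 ≈ 0.78539816, and
+// 0.78539816 * 65536 ≈ 51472 = HALF_PI / 2, the constant asserted above.
+// A hypothetical test making that derivation explicit could look like:
+//
+//     #[test]
+//     #[available_gas(8000000)]
+//     fn test_atan_one_is_quarter_pi() {
+//         let a = FixedTrait::ONE();
+//         assert_relative(atan(a), (HALF_PI / 2).into(), 'atan(1) != PI/4', Option::None(()));
+//     }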
+#[test]
+#[should_panic]
+#[available_gas(8000000)]
+fn test_asin_fail() {
+    let a = FixedTrait::new(2 * ONE, false);
+    asin(a);
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_cos() {
+    let a = FixedTrait::new(HALF_PI, false);
+    assert(cos(a).into() == 0, 'invalid half pi');
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_relative(cos(a), 46341, 'invalid quarter pi', Option::None(())); // 0.70710678118655
+
+    let a = FixedTrait::new(PI, false);
+    assert_relative(cos(a), -1 * ONE.into(), 'invalid pi', Option::None(()));
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(cos(a), 0, 'invalid neg half pi', Option::None(()));
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_relative(cos(a), -18033, 'invalid 17', Option::None(())); // -0.27516334
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_relative(cos(a), -18033, 'invalid -17', Option::None(())); // -0.27516334
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_cos_fast() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::new(HALF_PI, false);
+    assert(cos_fast(a).into() == 0, 'invalid half pi');
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(cos_fast(a), 46341, 'invalid quarter pi', error); // 0.70710678118655
+
+    let a = FixedTrait::new(PI, false);
+    assert_precise(cos_fast(a), -1 * ONE.into(), 'invalid pi', error);
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(cos_fast(a), 0, 'invalid neg half pi', Option::None(()));
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(cos_fast(a), -18033, 'invalid 17', error); // -0.27516334
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_sin() {
+    let a = FixedTrait::new(HALF_PI, false);
+    assert_precise(sin(a), ONE.into(), 'invalid half pi', Option::None(()));
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(sin(a), 46341, 'invalid quarter pi', Option::None(())); // 0.70710678118655
+
+    let a = FixedTrait::new(PI, false);
+    assert(sin(a).into() == 0, 'invalid pi');
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(sin(a), -ONE.into(), 'invalid neg half pi', Option::None(())); // -1
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(sin(a), -63006, 'invalid 17', Option::None(())); // -0.96139749
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_precise(sin(a), 63006, 'invalid -17', Option::None(())); // 0.96139749
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_sin_fast() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::new(HALF_PI, false);
+    assert_precise(sin_fast(a), ONE.into(), 'invalid half pi', error);
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(sin_fast(a), 46341, 'invalid quarter pi', error); // 0.70710678118655
+
+    let a = FixedTrait::new(PI, false);
+    assert(sin_fast(a).into() == 0, 'invalid pi');
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(sin_fast(a), -ONE.into(), 'invalid neg half pi', error); // -1
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(sin_fast(a), -63006, 'invalid 17', error); // -0.96139749
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_precise(sin_fast(a), 63006, 'invalid -17', error); // 0.96139749
+}
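+// Editor's note — illustrative worked example, not part of the original patch:
+// sin_fast interpolates linearly between lut::sin entries. For
+// a = HALF_PI / 2 = 51472, the slot is 51472 / 402 = 128, whose table entry is
+// (start, low, high) = (51472, 46341, 46624), so
+//     res = (51472 - 51472) / 402 * (46624 - 46341) + 46341 = 46341,
+// i.e. sin(PI / 4) ≈ 46341 / 65536 ≈ 0.70711, matching the asserts above.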
+#[test]
+#[available_gas(8000000)]
+fn test_tan() {
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(tan(a), ONE.into(), 'invalid quarter pi', Option::None(()));
+
+    let a = FixedTrait::new(PI, false);
+    assert_precise(tan(a), 0, 'invalid pi', Option::None(()));
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(tan(a), 228990, 'invalid 17', Option::None(())); // 3.49391565
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_precise(tan(a), -228952, 'invalid -17', Option::None(())); // -3.49391565
+}

From 27f3b6b5f3ab60225329a88411310900450d8535 Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Sun, 22 Oct 2023 10:51:37 +0300
Subject: [PATCH 11/78] add convertors

---
 .../implementations/fp16x16wide/core.cairo | 21 ++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo
index 01a1d8b8d..f12b96d9b 100644
--- a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo
+++ b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo
@@ -5,7 +5,7 @@ use result::{ResultTrait, ResultTraitImpl};
 use traits::{TryInto, Into};

 use orion::numbers::signed_integer::{i32::i32, i8::i8};
-use orion::numbers::fixed_point::core::FixedTrait;
+use orion::numbers::{fixed_point::core::FixedTrait, FP16x16};
 use orion::numbers::fixed_point::implementations::fp16x16wide::math::{core, trig, hyp};
 use orion::numbers::fixed_point::utils;
@@ -211,6 +211,25 @@ impl FP16x16WIntoI32 of Into<FP16x16W, i32> {
     }
 }

+impl FP16x16IntoFP16x16W of Into<FP16x16, FP16x16W> {
+    fn into(self: FP16x16) -> FP16x16W {
+        FP16x16W { mag: self.mag.into(), sign: self.sign }
+    }
+}
+
+impl FP16x16WTryIntoFP16x16 of TryInto<FP16x16W, FP16x16> {
+    fn try_into(self: FP16x16W) -> Option<FP16x16> {
+        match self.mag.try_into() {
+            Option::Some(val) => {
+                Option::Some(FP16x16 { mag: val, sign: self.sign })
+            },
+            Option::None(_) => {
+                Option::None(())
+            }
+        }
+    }
+}
+
 impl FP16x16WTryIntoI8 of TryInto<FP16x16W, i8> {
     fn try_into(self: FP16x16W) -> Option<i8> {
         _i8_try_from_fp(self)

From 53181cfad4a0250e8e97582c91a55972a96489fd Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Sun, 22 Oct 2023 11:03:50 +0300
Subject: [PATCH 12/78] implement FP16x16WTensor

---
 src/numbers.cairo | 165 ++++
 src/operators/tensor/implementations.cairo | 1 +
 .../implementations/tensor_fp16x16wide.cairo | 361 ++++++
 3 files changed, 527 insertions(+)
 create mode 100644 src/operators/tensor/implementations/tensor_fp16x16wide.cairo

diff --git a/src/numbers.cairo b/src/numbers.cairo
index 04ad6efa5..02dd5b344 100644
--- a/src/numbers.cairo
+++ b/src/numbers.cairo
@@ -378,6 +378,171 @@ impl FP16x16Number of NumberTrait<FP16x16, u32> {
     }
 }

+use orion::numbers::fixed_point::implementations::fp16x16wide::core::{FP16x16WImpl, FP16x16W};
+use orion::numbers::fixed_point::implementations::fp16x16wide::math::core as core_fp16x16wide;
+use orion::numbers::fixed_point::implementations::fp16x16wide::math::comp as comp_fp16x16wide;
+
+impl FP16x16WNumber of NumberTrait<FP16x16W, u64> {
+    fn new(mag: u64, sign: bool) -> FP16x16W {
+        FP16x16WImpl::new(mag, sign)
+    }
+
+    fn new_unscaled(mag: u64, sign: bool) -> FP16x16W {
+        FP16x16WImpl::new_unscaled(mag, sign)
+    }
+
+    fn from_felt(val: felt252) -> FP16x16W {
+        FP16x16WImpl::from_felt(val)
+    }
+
+    fn ceil(self: FP16x16W) -> FP16x16W {
+        FP16x16WImpl::ceil(self)
+    }
+
+    fn exp(self: FP16x16W) -> FP16x16W {
+        FP16x16WImpl::exp(self)
+    }
+
+    fn exp2(self: FP16x16W) -> FP16x16W {
+        FP16x16WImpl::exp2(self)
+    }
+
+    fn floor(self: FP16x16W) -> FP16x16W {
+        FP16x16WImpl::floor(self)
+    }
+
+    fn ln(self: FP16x16W) -> FP16x16W {
+        FP16x16WImpl::ln(self)
+    }
+
+    fn log2(self: FP16x16W) -> FP16x16W {
+        FP16x16WImpl::log2(self)
+    }
+
+    fn log10(self: FP16x16W) -> FP16x16W {
+        FP16x16WImpl::log10(self)
+    }
+
+    fn pow(self:
FP16x16W, b: FP16x16W) -> FP16x16W { + FP16x16WImpl::pow(self, b) + } + + fn round(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::round(self) + } + + fn sqrt(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::sqrt(self) + } + + fn acos(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::acos(self) + } + + fn asin(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::asin(self) + } + + fn atan(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::atan(self) + } + + fn cos(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::cos(self) + } + + fn sin(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::sin(self) + } + + fn tan(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::tan(self) + } + + fn acosh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::acosh(self) + } + + fn asinh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::asinh(self) + } + + fn atanh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::atanh(self) + } + + fn cosh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::cosh(self) + } + + fn sinh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::sinh(self) + } + + fn tanh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::tanh(self) + } + + fn zero() -> FP16x16W { + FP16x16WImpl::ZERO() + } + fn is_zero(self: FP16x16W) -> bool { + core_fp16x16wide::eq(@self, @FP16x16WImpl::ZERO()) + } + + fn one() -> FP16x16W { + FP16x16WImpl::ONE() + } + + fn neg_one() -> FP16x16W { + FP16x16W { mag: core_fp16x16wide::ONE, sign: true } + } + + fn is_one(self: FP16x16W) -> bool { + core_fp16x16wide::eq(@self, @FP16x16WImpl::ONE()) + } + + fn abs(self: FP16x16W) -> FP16x16W { + core_fp16x16wide::abs(self) + } + + fn min_value() -> FP16x16W { + FP16x16W { mag: core_fp16x16wide::MAX, sign: true } + } + + fn max_value() -> FP16x16W { + FP16x16W { mag: core_fp16x16wide::MAX, sign: false } + } + + fn min(self: FP16x16W, other: FP16x16W) -> FP16x16W { + comp_fp16x16wide::min(self, other) + } + + fn max(self: FP16x16W, other: FP16x16W) -> FP16x16W { + comp_fp16x16wide::max(self, other) + } + + fn mag(self: FP16x16W) -> u64 { + self.mag + } + + fn is_neg(self: FP16x16W) -> bool { + self.sign + } + + fn xor(lhs: FP16x16W, rhs: FP16x16W) -> bool { + comp_fp16x16wide::xor(lhs, rhs) + } + + fn or(lhs: FP16x16W, rhs: FP16x16W) -> bool { + comp_fp16x16wide::or(lhs, rhs) + } + + fn sign(self: FP16x16W) -> FP16x16W { + core_fp16x16wide::sign(self) + } +} + use orion::numbers::fixed_point::implementations::fp64x64::core::{FP64x64Impl, FP64x64}; use orion::numbers::fixed_point::implementations::fp64x64::core as core_fp64x64; use orion::numbers::fixed_point::implementations::fp64x64::comp as comp_fp64x64; diff --git a/src/operators/tensor/implementations.cairo b/src/operators/tensor/implementations.cairo index a585b88a7..0df3dcdec 100644 --- a/src/operators/tensor/implementations.cairo +++ b/src/operators/tensor/implementations.cairo @@ -5,3 +5,4 @@ mod tensor_fp8x23; mod tensor_fp16x16; mod tensor_fp64x64; mod tensor_fp32x32; +mod tensor_fp16x16wide; \ No newline at end of file diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo new file mode 100644 index 000000000..0a89fe72d --- /dev/null +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -0,0 +1,361 @@ +use array::ArrayTrait; +use array::SpanTrait; +use option::OptionTrait; +use traits::{TryInto, Into}; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::core::{ + new_tensor, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, +}; +use orion::operators::tensor::{math, 
linalg, quantization, core}; +use orion::numbers::{i8, i32, NumberTrait, FP16x16W}; +use orion::operators::tensor::implementations::{tensor_i8::I8Tensor, tensor_u32::U32Tensor}; + +impl FP16x16WTensor of TensorTrait { + fn new(shape: Span, data: Span) -> Tensor { + new_tensor(shape, data) + } + + fn at(self: @Tensor, indices: Span) -> FP16x16W { + *at_tensor(self, indices) + } + + fn min(self: @Tensor) -> FP16x16W { + math::min::min_in_tensor::(*self.data) + } + + fn max(self: @Tensor) -> FP16x16W { + math::max::max_in_tensor(*self.data) + } + + fn stride(self: @Tensor) -> Span { + stride(*self.shape) + } + + fn ravel_index(self: @Tensor, indices: Span) -> usize { + ravel_index(*self.shape, indices) + } + + fn unravel_index(self: @Tensor, index: usize) -> Span { + unravel_index(index, *self.shape) + } + + fn reshape(self: @Tensor, target_shape: Span) -> Tensor { + reshape(self, target_shape) + } + + fn reduce_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_sum::reduce_sum(self, axis, keepdims) + } + + fn argmax( + self: @Tensor, + axis: usize, + keepdims: Option, + select_last_index: Option + ) -> Tensor { + math::argmax::argmax(self, axis, keepdims, select_last_index) + } + + fn argmin( + self: @Tensor, + axis: usize, + keepdims: Option, + select_last_index: Option + ) -> Tensor { + math::argmin::argmin(self, axis, keepdims, select_last_index) + } + + fn transpose(self: @Tensor, axes: Span) -> Tensor { + linalg::transpose::transpose(self, axes) + } + + fn matmul(self: @Tensor, other: @Tensor) -> Tensor { + linalg::matmul::matmul(self, other) + } + + fn exp(self: @Tensor) -> Tensor { + math::exp::exp(*self) + } + + fn log(self: @Tensor) -> Tensor { + math::log::log(*self) + } + + fn equal(self: @Tensor, other: @Tensor) -> Tensor { + math::equal::equal(self, other) + } + + fn greater(self: @Tensor, other: @Tensor) -> Tensor { + math::greater::greater(self, other) + } + + fn greater_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::greater_equal::greater_equal(self, other) + } + + fn less(self: @Tensor, other: @Tensor) -> Tensor { + math::less::less(self, other) + } + + fn less_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::less_equal::less_equal(self, other) + } + + fn abs(self: @Tensor) -> Tensor { + math::abs::abs(*self) + } + + fn ceil(self: @Tensor) -> Tensor { + math::ceil::ceil(*self) + } + + fn sin(self: @Tensor) -> Tensor { + math::sin::sin(*self) + } + + fn cos(self: @Tensor) -> Tensor { + math::cos::cos(*self) + } + + fn asin(self: @Tensor) -> Tensor { + math::asin::asin(*self) + } + + fn cumsum( + self: @Tensor, axis: usize, exclusive: Option, reverse: Option + ) -> Tensor { + math::cumsum::cumsum(self, axis, exclusive, reverse) + } + + fn flatten(self: @Tensor, axis: usize) -> Tensor { + math::flatten::flatten(self, axis) + } + + fn sinh(self: @Tensor) -> Tensor { + math::sinh::sinh(*self) + } + + fn tanh(self: @Tensor) -> Tensor { + math::tanh::tanh(*self) + } + + fn cosh(self: @Tensor) -> Tensor { + math::cosh::cosh(*self) + } + + fn acosh(self: @Tensor) -> Tensor { + math::acosh::acosh(*self) + } + + fn asinh(self: @Tensor) -> Tensor { + math::asinh::asinh(*self) + } + + fn atan(self: @Tensor) -> Tensor { + math::atan::atan(*self) + } + + fn xor(self: @Tensor, other: @Tensor) -> Tensor { + math::xor::xor(self, other) + } + + fn or(self: @Tensor, other: @Tensor) -> Tensor { + math::or::or(self, other) + } + + fn acos(self: @Tensor) -> Tensor { + math::acos::acos(*self) + } + + fn onehot( + self: @Tensor, depth: usize, axis: Option, 
values: Span + ) -> Tensor { + panic(array!['not supported!']) + } + + fn sqrt(self: @Tensor) -> Tensor { + math::sqrt::sqrt(*self) + } + + fn concat(tensors: Span>, axis: usize,) -> Tensor { + math::concat::concat(tensors, axis) + } + + fn quantize_linear( + self: @Tensor, y_scale: @Tensor, y_zero_point: @Tensor + ) -> Tensor:: { + quantization::quantize_linear::quantize_linear( + self, + y_scale, + y_zero_point, + NumberTrait::new_unscaled(128, true), + NumberTrait::new_unscaled(127, false) + ) + } + + fn dequantize_linear( + self: @Tensor, x_scale: @Tensor, x_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn slice( + self: @Tensor, + starts: Span, + ends: Span, + axes: Option>, + steps: Option> + ) -> Tensor { + core::slice::(self, starts, ends, axes, steps) + } + + fn gather( + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather::gather(self, indices, axis) + } + + fn nonzero(self: @Tensor) -> Tensor { + core::nonzero(self) + } + + fn squeeze(self: @Tensor, axes: Option>) -> Tensor { + core::squeeze(self, axes) + } + + fn unsqueeze(self: @Tensor, axes: Span) -> Tensor { + core::unsqueeze(self, axes) + } + + fn sign(self: @Tensor) -> Tensor { + math::sign::sign(*self) + } + + fn clip( + self: @Tensor, min: Option, max: Option + ) -> Tensor { + core::clip(self, min, max) + } +} + +/// Implements addition for `Tensor` using the `Add` trait. +impl FP16x16WTensorAdd of Add> { + /// Adds two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise addition. + fn add(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::add(@lhs, @rhs) + } +} + +/// Implements subtraction for `Tensor` using the `Sub` trait. +impl FP16x16WTensorSub of Sub> { + /// Subtracts two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise subtraction. + fn sub(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::sub(@lhs, @rhs) + } +} + +/// Implements multiplication for `Tensor` using the `Mul` trait. +impl FP16x16WTensorMul of Mul> { + /// Multiplies two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise multiplication. + fn mul(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::mul(@lhs, @rhs) + } +} + +/// Implements division for `Tensor` using the `Div` trait. +impl FP16x16WTensorDiv of Div> { + /// Divides two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise division. + fn div(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::div(@lhs, @rhs) + } +} + +/// Implements partial equal for two `Tensor` using the `PartialEq` trait. 
+impl FP16x16WTensorPartialEq of PartialEq> { + fn eq(lhs: @Tensor, rhs: @Tensor) -> bool { + tensor_eq(*lhs, *rhs) + } + + fn ne(lhs: @Tensor, rhs: @Tensor) -> bool { + !tensor_eq(*lhs, *rhs) + } +} + +impl U32TryIntoU32 of TryInto { + fn try_into(self: u32) -> Option { + Option::Some(self) + } +} + + +// Internals +const PRECISION: u64 = 589; // 0.009 + +fn relative_eq(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { + let diff = *lhs - *rhs; + + let rel_diff = if *lhs.mag != 0 { + (diff / *lhs).mag + } else { + diff.mag + }; + + rel_diff <= PRECISION +} + + +fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { + let mut is_eq = true; + + loop { + if lhs.shape.len() == 0 || !is_eq { + break; + } + + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; + + if !is_eq { + return false; + } + + loop { + if lhs.data.len() == 0 || !is_eq { + break; + } + + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; + + return is_eq; +} + From fc98e829f7f3c31e1e9e19b7b95d97a74eea68af Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:36:03 +0300 Subject: [PATCH 13/78] implement softmaxWide --- src/operators/nn/functional/softmax.cairo | 30 +++++++++++++- .../nn/implementations/nn_fp16x16.cairo | 8 +++- src/operators/tensor/math/arithmetic.cairo | 41 +++++++++++++++++++ src/operators/tensor/math/exp.cairo | 34 +++++++++++++++ 4 files changed, 111 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 528856265..fdbb7054f 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -1,5 +1,6 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; - +use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast}; +use orion::numbers::fixed_point::core::FixedTrait; /// Cf: NNTrait::softmax docstring fn softmax< @@ -19,3 +20,30 @@ fn softmax< return softmax; } +/// Cf: NNTrait::softmax docstring +fn softmaxWide< + T, + TMAG, + W, + WMAG, + impl TTensor: TensorTrait, + impl WTensor: TensorTrait, + impl TDiv: Div, + impl TIntoW: Into, + impl WTryIntoT: TryInto, + impl TCopy: Copy, + impl TDrop: Drop, + impl WCopy: Copy, + impl WDrop: Drop, + impl TFixed: FixedTrait, + impl WFixed: FixedTrait, +>( + z: @Tensor, axis: usize +) -> Tensor { + let exp_tensor: Tensor = exp_upcast(*z); + let sum = exp_tensor.reduce_sum(axis, true); + let softmax: Tensor = div_downcast(@exp_tensor, @sum); + + return softmax; +} + diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index a0094de29..b940d8742 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -7,6 +7,12 @@ use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16; use orion::operators::tensor::implementations::tensor_fp16x16::{ FP16x16Tensor, FP16x16TensorDiv, FP16x16TensorAdd }; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + FP16x16WImpl, FP16x16WTryIntoFP16x16, FP16x16W, FP16x16IntoFP16x16W +}; +use orion::operators::tensor::implementations::tensor_fp16x16wide::{ + FP16x16WTensor, FP16x16WTensorDiv, FP16x16WTensorAdd +}; impl FP16x16NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -18,7 +24,7 @@ impl FP16x16NN of NNTrait { } fn softmax(tensor: @Tensor, axis: usize) -> Tensor { - functional::softmax::softmax(tensor, axis) + functional::softmax::softmaxWide::(tensor, 
axis) } fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { diff --git a/src/operators/tensor/math/arithmetic.cairo b/src/operators/tensor/math/arithmetic.cairo index 075f565b8..06879a4af 100644 --- a/src/operators/tensor/math/arithmetic.cairo +++ b/src/operators/tensor/math/arithmetic.cairo @@ -304,3 +304,44 @@ fn saturated_div< return TensorTrait::::new(broadcasted_shape, result.span()); } + +fn div_downcast< + T, + D, + impl TTensor: TensorTrait, + impl DTensor: TensorTrait, + impl DDiv: Div, + impl TTryIntoD: TryInto, + impl TCopy: Copy, + impl TDrop: Drop, + impl DCopy: Copy, + impl DDrop: Drop +>( + self: @Tensor, other: @Tensor +) -> Tensor { + let broadcasted_shape = broadcast_shape(*self.shape, *other.shape); + let mut result = ArrayTrait::new(); + + let num_elements = len_from_shape(broadcasted_shape); + + let mut n: usize = 0; + loop { + let indices_broadcasted = unravel_index(n, broadcasted_shape); + + let indices_self = broadcast_index_mapping(*self.shape, indices_broadcasted); + let indices_other = broadcast_index_mapping(*other.shape, indices_broadcasted); + + result + .append( + (*(*self.data)[indices_self]).try_into().unwrap() + / (*(*other.data)[indices_other]).try_into().unwrap() + ); + + n += 1; + if n == num_elements { + break (); + }; + }; + + return TensorTrait::::new(broadcasted_shape, result.span()); +} diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 3ba1e97d7..5ba161030 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -34,3 +34,37 @@ fn exp< return TensorTrait::new(self.shape, result.span()); } + +/// Cf: TensorTrait::exp docstring +fn exp_upcast< + T, + MAG, + W, + WMAG, + impl TFixedTrait: FixedTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, + impl WFixedTrait: FixedTrait, + impl WTensor: TensorTrait, + impl WCopy: Copy, + impl WDrop: Drop, + impl TIntoW: Into, +>( + mut self: Tensor +) -> Tensor { + let mut result = ArrayTrait::new(); + + loop { + match self.data.pop_front() { + Option::Some(item) => { + result.append((TIntoW::into(*item)).exp()); + }, + Option::None(_) => { + break; + } + }; + }; + + return TensorTrait::new(self.shape, result.span()); +} From eb09f55a1ed5c06eac5a2df224f028f7be6c7b3b Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:46:59 +0300 Subject: [PATCH 14/78] implement softmaxWide2 --- src/operators/nn/functional/softmax.cairo | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index fdbb7054f..1d3c59090 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -47,3 +47,15 @@ fn softmaxWide< return softmax; } +use orion::numbers::{FP16x16, FP16x16W}; +use orion::operators::tensor::{ + implementations::tensor_fp16x16wide::{FP16x16WTensor, FP16x16WTensorDiv}, FP16x16Tensor +}; + +/// Cf: NNTrait::softmax docstring +fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { + let exp_tensor: Tensor = exp_upcast(*z); + let sum = exp_tensor.reduce_sum(axis, true); + let softmax = exp_tensor / sum; + return softmax; +} From 73e49ab080a350d6f041cc5c9aaa6ff5df9cf0e1 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:48:44 +0300 Subject: [PATCH 15/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo 
b/src/operators/nn/functional/softmax.cairo index 1d3c59090..81cfd6c4e 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,6 +56,6 @@ use orion::operators::tensor::{ fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - let softmax = exp_tensor / sum; - return softmax; + // let softmax = exp_tensor / sum; + return sum; } From a835f7375071d5a011b49918f19ff617d3c20734 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:51:50 +0300 Subject: [PATCH 16/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 81cfd6c4e..870db1611 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -55,7 +55,7 @@ use orion::operators::tensor::{ /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - let sum = exp_tensor.reduce_sum(axis, true); + // let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; - return sum; + return exp_tensor; } From 541f9c450e213e36296adac28e94071f052809b5 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:58:23 +0300 Subject: [PATCH 17/78] Update core.cairo --- src/numbers/fixed_point/implementations/fp16x16/math/core.cairo | 1 + 1 file changed, 1 insertion(+) diff --git a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo index e113b97c7..fc05cd941 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16, b: @FP16x16) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16) -> FP16x16 { + a.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } From 5f7e22e3f11e733e533d0ac506339f0b72f989d6 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:59:59 +0300 Subject: [PATCH 18/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 870db1611..a26288fc9 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -53,9 +53,10 @@ use orion::operators::tensor::{ }; /// Cf: NNTrait::softmax docstring -fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { - let exp_tensor: Tensor = exp_upcast(*z); +fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { + // let exp_tensor: Tensor = exp_upcast(*z); // let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; - return exp_tensor; + // return exp_tensor; + *z } From 1b1064e406979a2ebfdfba7d19d57833df69ef14 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:06:20 +0300 Subject: [PATCH 19/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index a26288fc9..f8d4f234c 100644 --- a/src/operators/nn/functional/softmax.cairo +++ 
b/src/operators/nn/functional/softmax.cairo @@ -54,7 +54,7 @@ use orion::operators::tensor::{ /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { - // let exp_tensor: Tensor = exp_upcast(*z); + let exp_tensor: Tensor = exp_upcast(*z); // let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; // return exp_tensor; From 397da5f139db5f08e043965f8a529f53782a79af Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:07:20 +0300 Subject: [PATCH 20/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index f8d4f234c..7088c5f40 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -55,7 +55,7 @@ use orion::operators::tensor::{ /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - // let sum = exp_tensor.reduce_sum(axis, true); + let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; // return exp_tensor; *z From 09a39488abc489cbda17b382abe8372b3c0683f9 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:10:08 +0300 Subject: [PATCH 21/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 7088c5f40..bc5dd4025 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,7 +56,7 @@ use orion::operators::tensor::{ fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - // let softmax = exp_tensor / sum; + let softmax = exp_tensor / sum; // return exp_tensor; *z } From 3bb5a5d30bf2bf10fdcc8228939731ca1e345b51 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:13:14 +0300 Subject: [PATCH 22/78] add print --- .../fixed_point/implementations/fp16x16/math/core.cairo | 1 - src/operators/nn/functional/softmax.cairo | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo index fc05cd941..e113b97c7 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo @@ -61,7 +61,6 @@ fn eq(a: @FP16x16, b: @FP16x16) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16) -> FP16x16 { - a.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index bc5dd4025..247f8af5f 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -51,12 +51,14 @@ use orion::numbers::{FP16x16, FP16x16W}; use orion::operators::tensor::{ implementations::tensor_fp16x16wide::{FP16x16WTensor, FP16x16WTensorDiv}, FP16x16Tensor }; +use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - let softmax = exp_tensor / sum; + (sum.data.len()).print(); + // let 
softmax = exp_tensor / sum; // return exp_tensor; *z } From e881ab2899cab807b9c727ad964a0eeebd8501f1 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:14:46 +0300 Subject: [PATCH 23/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 247f8af5f..e9d968c8a 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -57,7 +57,7 @@ use debug::PrintTrait; fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - (sum.data.len()).print(); + (*sum.data.at(0)).print(); // let softmax = exp_tensor / sum; // return exp_tensor; *z From 803fbcecdbfcc3477ea5ada53c7b5dea1ba00d3f Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:17:11 +0300 Subject: [PATCH 24/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index e9d968c8a..6fdfd747b 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,8 +56,10 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - let sum = exp_tensor.reduce_sum(axis, true); - (*sum.data.at(0)).print(); + (*exp_tensor.data.at(0)).print(); + + // let sum = exp_tensor.reduce_sum(axis, true); + // (*sum.data.at(0)).print(); // let softmax = exp_tensor / sum; // return exp_tensor; *z From 162f5288b7b8bbcbc16fa8039ab82d78126bc0e9 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:17:26 +0300 Subject: [PATCH 25/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 6fdfd747b..bbe9335ec 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,7 +56,7 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - (*exp_tensor.data.at(0)).print(); + (exp_tensor.data.len()).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); From febcae739c83581bff427efb7c6ffd5b30f1d6fb Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:20:34 +0300 Subject: [PATCH 26/78] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index bbe9335ec..dc7e0636e 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,7 +56,9 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - (exp_tensor.data.len()).print(); + (*exp_tensor.data.at(0)).print(); + (*exp_tensor.data.at(1)).print(); + (*exp_tensor.data.at(2)).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); From 05dbea58a527cebe2447ca8eaf4b010579637cb1 Mon Sep 17 
00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:23:08 +0300 Subject: [PATCH 27/78] fix exp --- src/operators/nn/functional/softmax.cairo | 7 ++++--- src/operators/tensor/math/exp.cairo | 5 +++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index dc7e0636e..2f251f1b0 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -37,6 +37,7 @@ fn softmaxWide< impl WDrop: Drop, impl TFixed: FixedTrait, impl WFixed: FixedTrait, + impl TPrint: PrintTrait >( z: @Tensor, axis: usize ) -> Tensor { @@ -56,9 +57,9 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - (*exp_tensor.data.at(0)).print(); - (*exp_tensor.data.at(1)).print(); - (*exp_tensor.data.at(2)).print(); + // (*exp_tensor.data.at(0)).print(); + // (*exp_tensor.data.at(1)).print(); + // (*exp_tensor.data.at(2)).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 5ba161030..aeb620208 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -35,6 +35,8 @@ fn exp< return TensorTrait::new(self.shape, result.span()); } +use debug::PrintTrait; + /// Cf: TensorTrait::exp docstring fn exp_upcast< T, @@ -50,6 +52,7 @@ fn exp_upcast< impl WCopy: Copy, impl WDrop: Drop, impl TIntoW: Into, + impl TPrint: PrintTrait >( mut self: Tensor ) -> Tensor { @@ -58,6 +61,8 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { + (*item).print(); + result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 62edd56d83118f5edbd552ddd057d072ada335a3 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:26:05 +0300 Subject: [PATCH 28/78] fix exp --- src/operators/nn/functional/softmax.cairo | 3 ++- src/operators/tensor/math/exp.cairo | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 2f251f1b0..168fab483 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -37,7 +37,8 @@ fn softmaxWide< impl WDrop: Drop, impl TFixed: FixedTrait, impl WFixed: FixedTrait, - impl TPrint: PrintTrait + impl TPrint: PrintTrait, + impl WPrint: PrintTrait >( z: @Tensor, axis: usize ) -> Tensor { diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index aeb620208..0b3889511 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -52,7 +52,8 @@ fn exp_upcast< impl WCopy: Copy, impl WDrop: Drop, impl TIntoW: Into, - impl TPrint: PrintTrait + impl TPrint: PrintTrait, + impl WPrint: PrintTrait >( mut self: Tensor ) -> Tensor { @@ -61,7 +62,7 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (*item).print(); + (TIntoW::into(*item)).print(); result.append((TIntoW::into(*item)).exp()); }, From d6c86316c4d72deecc6a718c285a40a052308576 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:27:14 +0300 Subject: [PATCH 29/78] Update exp.cairo --- src/operators/tensor/math/exp.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 0b3889511..18ee0ccde 
100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,7 +62,7 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (TIntoW::into(*item)).print(); + (TIntoW::into(*item)).exp().print(); result.append((TIntoW::into(*item)).exp()); }, From dc2019f5f22210ae404b36004ab75f0cce58eccd Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:28:57 +0300 Subject: [PATCH 30/78] debugin' --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 1 + src/operators/tensor/math/exp.cairo | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 33c1c6d85..a3ed48b4c 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { + a.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 18ee0ccde..92b60e2ac 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,8 +62,6 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (TIntoW::into(*item)).exp().print(); - result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 12ac782e299fc52c4769918f52fc306d93fb45c4 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:30:36 +0300 Subject: [PATCH 31/78] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index a3ed48b4c..6b6a74e07 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - a.print(); + (FixedTrait::new(94548, false) * a).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } From ce0c01e4722b4fc37c2c85d243012763d88dfa82 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:31:54 +0300 Subject: [PATCH 32/78] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 6b6a74e07..87a7c8706 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,6 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - (FixedTrait::new(94548, false) * a).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -87,6 +86,8 @@ fn exp2(a: FP16x16W) -> FP16x16W { res_u = res_u * (r1 + FixedTrait::ONE()); } + res_u.print(); + if (a.sign == true) { return FixedTrait::ONE() / res_u; 
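        // res_u holds 2^|x| at this point, so a negative exponent falls back
        // on the identity 2^(-x) = 1 / 2^x, hence the reciprocal above.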
} else { From acd104362cd362a3a84624abaea00d94bf0656a7 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:34:46 +0300 Subject: [PATCH 33/78] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 87a7c8706..bcd7fc8af 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { + (exp2(FixedTrait::new(94548, false) * a)).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -86,8 +87,6 @@ fn exp2(a: FP16x16W) -> FP16x16W { res_u = res_u * (r1 + FixedTrait::ONE()); } - res_u.print(); - if (a.sign == true) { return FixedTrait::ONE() / res_u; } else { From 715a497d7825e6163c449b39b153a18705faa23a Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:36:14 +0300 Subject: [PATCH 34/78] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index bcd7fc8af..2395e4c4a 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -6,8 +6,8 @@ use integer::{u64_safe_divmod, u64_as_non_zero, u64_wide_mul}; use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ HALF, ONE, MAX, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WAddEq, FP16x16WSub, FP16x16WMul, - FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, FP16x16WNeg, - FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait + FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, + FP16x16WNeg, FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait }; use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut; @@ -61,7 +61,6 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - (exp2(FixedTrait::new(94548, false) * a)).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -88,6 +87,7 @@ fn exp2(a: FP16x16W) -> FP16x16W { } if (a.sign == true) { + (FixedTrait::ONE() / res_u).print(); return FixedTrait::ONE() / res_u; } else { return res_u; From 214a95f4800fdce9fcc5d224616cff4f5e77dcf7 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:37:16 +0300 Subject: [PATCH 35/78] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 2395e4c4a..38db5b84d 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { + a.sign.print(); 
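    // e^x is evaluated as 2^(x * log2(e)); in this Q16.16 wide format the
    // scaled constant is log2(e) * 2^16 ≈ 94548 (12102203 is the 2^23-scaled
    // analogue used by the FP8x23 types).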
return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -87,7 +88,6 @@ fn exp2(a: FP16x16W) -> FP16x16W { } if (a.sign == true) { - (FixedTrait::ONE() / res_u).print(); return FixedTrait::ONE() / res_u; } else { return res_u; From 6a82ce9338f31a62d6ab59b96eec5867ba69d647 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:39:15 +0300 Subject: [PATCH 36/78] debbug --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 2 +- src/operators/tensor/math/exp.cairo | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 38db5b84d..9ab7fcc4c 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - a.sign.print(); + // a.sign.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 92b60e2ac..0b3889511 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,6 +62,8 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { + (TIntoW::into(*item)).print(); + result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 3b56c695818b5c23e042b8199dc0eee12f7d03be Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:39:51 +0300 Subject: [PATCH 37/78] Update exp.cairo --- src/operators/tensor/math/exp.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 0b3889511..c4a9903c0 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,7 +62,7 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (TIntoW::into(*item)).print(); + ((*item)).print(); result.append((TIntoW::into(*item)).exp()); }, From 5267e358e77911e1c0bedae7887a5e7eab0e1387 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:41:41 +0300 Subject: [PATCH 38/78] debug --- src/operators/nn/functional/softmax.cairo | 7 ++++--- src/operators/tensor/math/exp.cairo | 2 -- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 168fab483..59c66c3d2 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -58,9 +58,10 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - // (*exp_tensor.data.at(0)).print(); - // (*exp_tensor.data.at(1)).print(); - // (*exp_tensor.data.at(2)).print(); + + (*(*z).data.at(0)).print(); + (*(*z).data.at(1)).print(); + (*(*z).data.at(2)).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index c4a9903c0..92b60e2ac 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,8 +62,6 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - ((*item)).print(); - 
result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 9ab6af62c325c55533190754a4d061c4183d75c6 Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Sun, 22 Oct 2023 23:05:34 -0400 Subject: [PATCH 39/78] feat: neg operator --- src/numbers.cairo | 41 +++++++++++++++++++ src/operators/tensor/core.cairo | 41 +++++++++++++++++++ .../implementations/tensor_fp16x16.cairo | 4 ++ .../implementations/tensor_fp32x32.cairo | 4 ++ .../implementations/tensor_fp64x64.cairo | 4 ++ .../implementations/tensor_fp8x23.cairo | 4 ++ .../tensor/implementations/tensor_i32.cairo | 4 ++ .../tensor/implementations/tensor_i8.cairo | 4 ++ .../tensor/implementations/tensor_u32.cairo | 4 ++ src/operators/tensor/math.cairo | 1 + src/operators/tensor/math/neg.cairo | 32 +++++++++++++++ 11 files changed, 143 insertions(+) create mode 100644 src/operators/tensor/math/neg.cairo diff --git a/src/numbers.cairo b/src/numbers.cairo index 04ad6efa5..3a6ab7807 100644 --- a/src/numbers.cairo +++ b/src/numbers.cairo @@ -10,6 +10,7 @@ trait NumberTrait { fn new_unscaled(mag: MAG, sign: bool) -> T; fn from_felt(val: felt252) -> T; fn abs(self: T) -> T; + fn neg(self: T) -> T; fn ceil(self: T) -> T; fn exp(self: T) -> T; fn exp2(self: T) -> T; @@ -176,6 +177,10 @@ impl FP8x23Number of NumberTrait { core_fp8x23::abs(self) } + fn neg(self: FP8x23) -> FP8x23 { + core_fp8x23::neg(self) + } + fn min_value() -> FP8x23 { FP8x23 { mag: core_fp8x23::MAX, sign: true } } @@ -341,6 +346,10 @@ impl FP16x16Number of NumberTrait { core_fp16x16::abs(self) } + fn neg(self: FP16x16) -> FP16x16 { + core_fp16x16::neg(self) + } + fn min_value() -> FP16x16 { FP16x16 { mag: core_fp16x16::MAX, sign: true } } @@ -507,6 +516,10 @@ impl FP64x64Number of NumberTrait { fp64x64::core::abs(self) } + fn neg(self: FP64x64) -> FP64x64 { + fp64x64::core::neg(self) + } + fn min_value() -> FP64x64 { FP64x64 { mag: core_fp64x64::MAX, sign: true } } @@ -673,6 +686,10 @@ impl FP32x32Number of NumberTrait { fp32x32::core::abs(self) } + fn neg(self: FP32x32) -> FP32x32 { + fp32x32::core::neg(self) + } + fn min_value() -> FP32x32 { FP32x32 { mag: core_fp32x32::MAX, sign: true } } @@ -837,6 +854,10 @@ impl I8Number of NumberTrait { i8_core::i8_abs(self) } + fn neg(self: i8) -> i8 { + i8_core::i8_neg(self) + } + fn min_value() -> i8 { i8 { mag: 128, sign: true } } @@ -1009,6 +1030,10 @@ impl i16Number of NumberTrait { i16_core::i16_abs(self) } + fn neg(self: i16) -> i16 { + i16_core::i16_neg(self) + } + fn min_value() -> i16 { i16 { mag: 32768, sign: true } } @@ -1181,6 +1206,10 @@ impl i32Number of NumberTrait { i32_core::i32_abs(self) } + fn neg(self: i32) -> i32 { + i32_core::i32_neg(self) + } + fn min_value() -> i32 { i32 { mag: 2147483648, sign: true } } @@ -1353,6 +1382,10 @@ impl i64Number of NumberTrait { i64_core::i64_abs(self) } + fn neg(self: i64) -> i64 { + i64_core::i64_neg(self) + } + fn min_value() -> i64 { i64 { mag: 9223372036854775808, sign: true } } @@ -1526,6 +1559,10 @@ impl i128Number of NumberTrait { i128_core::i128_abs(self) } + fn neg(self: i128) -> i128 { + i128_core::i128_neg(self) + } + fn min_value() -> i128 { i128 { mag: 170141183460469231731687303715884105728, sign: true } } @@ -1696,6 +1733,10 @@ impl u32Number of NumberTrait { self } + fn neg(self: u32) -> u32 { + panic(array!['not supported']) + } + fn min_value() -> u32 { 0 } diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 535ed699c..4b9c9ed52 100644 --- a/src/operators/tensor/core.cairo +++ 
b/src/operators/tensor/core.cairo @@ -1234,6 +1234,47 @@ trait TensorTrait { /// ``` /// fn abs(self: @Tensor) -> Tensor; + /// #tensor.neg + /// + /// ```rust + /// fn neg(self: @Tensor) -> Tensor; + /// ``` + /// + /// Computes the negation of all elements in the input tensor. + /// + /// ## Args + /// + /// * `self`(`@Tensor`) - The input tensor. + /// + /// + /// ## Returns + /// + /// A new `Tensor` of the same shape as the input tensor with + /// the negation of all elements in the input tensor. + /// + /// ## Example + /// + /// ```rust + /// use array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, I32Tensor}; + /// use orion::numbers::{i32, IntegerTrait}; + /// + /// fn neg_example() -> Tensor { + /// let tensor = TensorTrait::new( + /// shape: array![3].span(), + /// data: array![ + /// IntegerTrait::new(1, true), IntegerTrait::new(2, true), IntegerTrait::new(3, false) + /// ] + /// .span(), + /// ); + /// + /// return tensor.neg(); + /// } + /// >>> [1, 2, -3] + /// ``` + /// + fn neg(self: @Tensor) -> Tensor; /// #tensor.ceil /// /// ```rust diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ac803904e..1ad60086f 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -100,6 +100,10 @@ impl FP16x16Tensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { math::ceil::ceil(*self) } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 19557eba1..f421b8113 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -101,6 +101,10 @@ impl FP32x32Tensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { math::ceil::ceil(*self) } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 686d319b1..c00da33fe 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -101,6 +101,10 @@ impl FP64x64Tensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { math::ceil::ceil(*self) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index f46e96fb5..956e3b2ef 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -100,6 +100,10 @@ impl FP8x23Tensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { math::ceil::ceil(*self) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index dd6ae6f59..ae631b90a 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -101,6 +101,10 @@ impl I32Tensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) 
-> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index a4518cefb..583871bd6 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -100,6 +100,10 @@ impl I8Tensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 4d26a8634..315aff0c5 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -100,6 +100,10 @@ impl U32Tensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index 5ac834280..2874b7e6e 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -32,3 +32,4 @@ mod sqrt; mod concat; mod gather; mod sign; +mod neg; diff --git a/src/operators/tensor/math/neg.cairo b/src/operators/tensor/math/neg.cairo new file mode 100644 index 000000000..b34ea16f0 --- /dev/null +++ b/src/operators/tensor/math/neg.cairo @@ -0,0 +1,32 @@ +use array::ArrayTrait; +use option::OptionTrait; +use array::SpanTrait; + +use orion::operators::tensor::core::{Tensor, TensorTrait}; +use orion::numbers::NumberTrait; + +/// Cf: TensorTrait::neg docstring +fn neg< + T, + MAG, + impl TTensor: TensorTrait, + impl TNumberTrait: NumberTrait, + impl TCopy: Copy, + impl TDrop: Drop +>( + mut z: Tensor +) -> Tensor { + let mut data_result = ArrayTrait::::new(); + loop { + match z.data.pop_front() { + Option::Some(item) => { + data_result.append((*item).neg()); + }, + Option::None(_) => { + break; + } + }; + }; + + return TensorTrait::::new(z.shape, data_result.span()); +} From 247e817e46180e605caa6e62fbce1ec708c97083 Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Sun, 22 Oct 2023 23:43:19 -0400 Subject: [PATCH 40/78] add unit test --- nodegen/node/neg.py | 55 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 nodegen/node/neg.py diff --git a/nodegen/node/neg.py b/nodegen/node/neg.py new file mode 100644 index 000000000..a78456428 --- /dev/null +++ b/nodegen/node/neg.py @@ -0,0 +1,55 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl + + +class Neg(RunAll): + @staticmethod + def neg_i32(): + x = np.random.randint(-127, 127, (2, 2)).astype(np.int32) + y = np.negative(x) + + x = Tensor(Dtype.I32, x.shape, x.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "neg_i32" + make_node([x], [y], name) + make_test([x], y, "input_0.neg()", name) + + @staticmethod + def neg_i8(): + x = np.random.randint(-127, 127, (2, 2)).astype(np.int8) + y = np.negative(x) + + x = Tensor(Dtype.I8, x.shape, x.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "neg_i8" + make_node([x], [y], name) + make_test([x], y, "input_0.neg()", name) + + @staticmethod + def neg_fp8x23(): + x = to_fp(np.random.randint(-127, 127, (2, 2) + ).astype(np.int64), FixedImpl.FP8x23) + y = np.negative(x) + + x = Tensor(Dtype.FP8x23, 
x.shape, x.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, y.flatten()) + + name = "neg_fp8x23" + make_node([x], [y], name) + make_test([x], y, "input_0.neg()", name) + + @staticmethod + def neg_fp16x16(): + x = to_fp(np.random.randint(-127, 127, (2, 2) + ).astype(np.int64), FixedImpl.FP16x16) + y = np.negative(x) + + x = Tensor(Dtype.FP16x16, x.shape, x.flatten()) + y = Tensor(Dtype.FP16x16, y.shape, y.flatten()) + + name = "neg_fp16x16" + make_node([x], [y], name) + make_test([x], y, "input_0.neg()", name) From 8b2276bab3f82b257959a3fededab5e46714a4a6 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 23 Oct 2023 10:25:23 +0300 Subject: [PATCH 41/78] clean --- .../fp16x16wide/math/core.cairo | 1 - src/operators/nn/functional/softmax.cairo | 30 ++----------------- src/operators/tensor/math/exp.cairo | 2 -- 3 files changed, 2 insertions(+), 31 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 9ab7fcc4c..4654cd6ba 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,6 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - // a.sign.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 59c66c3d2..81696ef22 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -15,9 +15,7 @@ fn softmax< ) -> Tensor { let exp_tensor = z.exp(); let sum = exp_tensor.reduce_sum(axis, true); - let softmax = exp_tensor / sum; - - return softmax; + exp_tensor / sum } /// Cf: NNTrait::softmax docstring @@ -37,35 +35,11 @@ fn softmaxWide< impl WDrop: Drop, impl TFixed: FixedTrait, impl WFixed: FixedTrait, - impl TPrint: PrintTrait, - impl WPrint: PrintTrait >( z: @Tensor, axis: usize ) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - let softmax: Tensor = div_downcast(@exp_tensor, @sum); - - return softmax; + div_downcast(@exp_tensor, @sum) } -use orion::numbers::{FP16x16, FP16x16W}; -use orion::operators::tensor::{ - implementations::tensor_fp16x16wide::{FP16x16WTensor, FP16x16WTensorDiv}, FP16x16Tensor -}; -use debug::PrintTrait; - -/// Cf: NNTrait::softmax docstring -fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { - let exp_tensor: Tensor = exp_upcast(*z); - - (*(*z).data.at(0)).print(); - (*(*z).data.at(1)).print(); - (*(*z).data.at(2)).print(); - - // let sum = exp_tensor.reduce_sum(axis, true); - // (*sum.data.at(0)).print(); - // let softmax = exp_tensor / sum; - // return exp_tensor; - *z -} diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 92b60e2ac..83f79eac7 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -52,8 +52,6 @@ fn exp_upcast< impl WCopy: Copy, impl WDrop: Drop, impl TIntoW: Into, - impl TPrint: PrintTrait, - impl WPrint: PrintTrait >( mut self: Tensor ) -> Tensor { From f79e0f5efff681eca6e23d5313f23c3a5d92ffdb Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 23 Oct 2023 10:35:23 +0300 Subject: [PATCH 42/78] implement fp8x23wide --- src/numbers/fixed_point/implementations.cairo | 3 +- .../implementations/fp8x23wide.cairo | 4 + 
.../implementations/fp8x23wide/core.cairo | 378 +++++ .../implementations/fp8x23wide/helpers.cairo | 39 + .../implementations/fp8x23wide/math.cairo | 5 + .../fp8x23wide/math/comp.cairo | 76 + .../fp8x23wide/math/core.cairo | 660 +++++++++ .../implementations/fp8x23wide/math/hyp.cairo | 159 +++ .../implementations/fp8x23wide/math/lut.cairo | 1229 +++++++++++++++++ .../fp8x23wide/math/trig.cairo | 448 ++++++ 10 files changed, 3000 insertions(+), 1 deletion(-) create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo diff --git a/src/numbers/fixed_point/implementations.cairo b/src/numbers/fixed_point/implementations.cairo index e6152e25a..d7617f9c8 100644 --- a/src/numbers/fixed_point/implementations.cairo +++ b/src/numbers/fixed_point/implementations.cairo @@ -2,4 +2,5 @@ mod fp8x23; mod fp16x16; mod fp64x64; mod fp32x32; -mod fp16x16wide; \ No newline at end of file +mod fp16x16wide; +mod fp8x23wide; \ No newline at end of file diff --git a/src/numbers/fixed_point/implementations/fp8x23wide.cairo b/src/numbers/fixed_point/implementations/fp8x23wide.cairo new file mode 100644 index 000000000..2cc1d5085 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide.cairo @@ -0,0 +1,4 @@ +mod core; +mod math; +mod helpers; + diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo new file mode 100644 index 000000000..36b64ce5e --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo @@ -0,0 +1,378 @@ +use debug::PrintTrait; + +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{TryInto, Into}; + +use orion::numbers::signed_integer::{i32::i32, i8::i8}; +use orion::numbers::fixed_point::core::{FixedTrait}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::{core, trig, hyp}; +use orion::numbers::fixed_point::utils; + +/// A struct representing a fixed point number. 
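+/// Values are stored sign/magnitude: `mag` holds the absolute value scaled by
+/// 2^23 (the ONE constant below), so mag = 8388608 with sign = false encodes
+/// 1.0 and mag = 4194304 encodes 0.5. Widening the magnitude to u64 mirrors
+/// fp16x16wide, leaving headroom for intermediate results that would overflow
+/// the narrow FP8x23 type.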
+#[derive(Serde, Copy, Drop)] +struct FP8x23W { + mag: u64, + sign: bool +} + +// CONSTANTS + +const TWO: u64 = 16777216; // 2 ** 24 +const ONE: u64 = 8388608; // 2 ** 23 +const HALF: u64 = 4194304; // 2 ** 22 +const MAX: u64 = 2147483648; // 2 ** 31 + + +impl FP8x23WImpl of FixedTrait { + fn ZERO() -> FP8x23W { + return FP8x23W { mag: 0, sign: false }; + } + + fn ONE() -> FP8x23W { + return FP8x23W { mag: ONE, sign: false }; + } + + fn MAX() -> FP8x23W { + return FP8x23W { mag: MAX, sign: false }; + } + + fn new(mag: u64, sign: bool) -> FP8x23W { + return FP8x23W { mag: mag, sign: sign }; + } + + fn new_unscaled(mag: u64, sign: bool) -> FP8x23W { + return FP8x23W { mag: mag * ONE, sign: sign }; + } + + fn from_felt(val: felt252) -> FP8x23W { + let mag = integer::u64_try_from_felt252(utils::felt_abs(val)).unwrap(); + return FixedTrait::new(mag, utils::felt_sign(val)); + } + + fn abs(self: FP8x23W) -> FP8x23W { + return core::abs(self); + } + + fn acos(self: FP8x23W) -> FP8x23W { + return trig::acos_fast(self); + } + + fn acos_fast(self: FP8x23W) -> FP8x23W { + return trig::acos_fast(self); + } + + fn acosh(self: FP8x23W) -> FP8x23W { + return hyp::acosh(self); + } + + fn asin(self: FP8x23W) -> FP8x23W { + return trig::asin_fast(self); + } + + fn asin_fast(self: FP8x23W) -> FP8x23W { + return trig::asin_fast(self); + } + + fn asinh(self: FP8x23W) -> FP8x23W { + return hyp::asinh(self); + } + + fn atan(self: FP8x23W) -> FP8x23W { + return trig::atan_fast(self); + } + + fn atan_fast(self: FP8x23W) -> FP8x23W { + return trig::atan_fast(self); + } + + fn atanh(self: FP8x23W) -> FP8x23W { + return hyp::atanh(self); + } + + fn ceil(self: FP8x23W) -> FP8x23W { + return core::ceil(self); + } + + fn cos(self: FP8x23W) -> FP8x23W { + return trig::cos_fast(self); + } + + fn cos_fast(self: FP8x23W) -> FP8x23W { + return trig::cos_fast(self); + } + + fn cosh(self: FP8x23W) -> FP8x23W { + return hyp::cosh(self); + } + + fn floor(self: FP8x23W) -> FP8x23W { + return core::floor(self); + } + + // Calculates the natural exponent of x: e^x + fn exp(self: FP8x23W) -> FP8x23W { + return core::exp(self); + } + + // Calculates the binary exponent of x: 2^x + fn exp2(self: FP8x23W) -> FP8x23W { + return core::exp2(self); + } + + // Calculates the natural logarithm of x: ln(x) + // self must be greater than zero + fn ln(self: FP8x23W) -> FP8x23W { + return core::ln(self); + } + + // Calculates the binary logarithm of x: log2(x) + // self must be greather than zero + fn log2(self: FP8x23W) -> FP8x23W { + return core::log2(self); + } + + // Calculates the base 10 log of x: log10(x) + // self must be greater than zero + fn log10(self: FP8x23W) -> FP8x23W { + return core::log10(self); + } + + // Calclates the value of x^y and checks for overflow before returning + // self is a fixed point value + // b is a fixed point value + fn pow(self: FP8x23W, b: FP8x23W) -> FP8x23W { + return core::pow(self, b); + } + + fn round(self: FP8x23W) -> FP8x23W { + return core::round(self); + } + + fn sin(self: FP8x23W) -> FP8x23W { + return trig::sin_fast(self); + } + + fn sin_fast(self: FP8x23W) -> FP8x23W { + return trig::sin_fast(self); + } + + fn sinh(self: FP8x23W) -> FP8x23W { + return hyp::sinh(self); + } + + // Calculates the square root of a fixed point value + // x must be positive + fn sqrt(self: FP8x23W) -> FP8x23W { + return core::sqrt(self); + } + + fn tan(self: FP8x23W) -> FP8x23W { + return trig::tan_fast(self); + } + + fn tan_fast(self: FP8x23W) -> FP8x23W { + return trig::tan_fast(self); + } + + fn tanh(self: FP8x23W) 
-> FP8x23W { + return hyp::tanh(self); + } + + fn sign(self: FP8x23W) -> FP8x23W { + return core::sign(self); + } +} + + +impl FP8x23WPrint of PrintTrait { + fn print(self: FP8x23W) { + self.sign.print(); + self.mag.print(); + } +} + +// Into a raw felt without unscaling +impl FP8x23WIntoFelt252 of Into { + fn into(self: FP8x23W) -> felt252 { + let mag_felt = self.mag.into(); + + if self.sign { + return mag_felt * -1; + } else { + return mag_felt * 1; + } + } +} + +impl FP8x23WTryIntoU128 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP8x23WTryIntoU64 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + + +impl FP8x23WTryIntoU16 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP8x23WTryIntoU8 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP8x23WIntoI32 of Into { + fn into(self: FP8x23W) -> i32 { + _i32_into_fp(self) + } +} + +impl FP8x23WTryIntoI8 of TryInto { + fn try_into(self: FP8x23W) -> Option { + _i8_try_from_fp(self) + } +} + +impl FP8x23WPartialEq of PartialEq { + #[inline(always)] + fn eq(lhs: @FP8x23W, rhs: @FP8x23W) -> bool { + return core::eq(lhs, rhs); + } + + #[inline(always)] + fn ne(lhs: @FP8x23W, rhs: @FP8x23W) -> bool { + return core::ne(lhs, rhs); + } +} + +impl FP8x23WAdd of Add { + fn add(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::add(lhs, rhs); + } +} + +impl FP8x23WAddEq of AddEq { + #[inline(always)] + fn add_eq(ref self: FP8x23W, other: FP8x23W) { + self = Add::add(self, other); + } +} + +impl FP8x23WSub of Sub { + fn sub(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::sub(lhs, rhs); + } +} + +impl FP8x23WSubEq of SubEq { + #[inline(always)] + fn sub_eq(ref self: FP8x23W, other: FP8x23W) { + self = Sub::sub(self, other); + } +} + +impl FP8x23WMul of Mul { + fn mul(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::mul(lhs, rhs); + } +} + +impl FP8x23WMulEq of MulEq { + #[inline(always)] + fn mul_eq(ref self: FP8x23W, other: FP8x23W) { + self = Mul::mul(self, other); + } +} + +impl FP8x23WDiv of Div { + fn div(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::div(lhs, rhs); + } +} + +impl FP8x23WDivEq of DivEq { + #[inline(always)] + fn div_eq(ref self: FP8x23W, other: FP8x23W) { + self = Div::div(self, other); + } +} + +impl FP8x23WPartialOrd of PartialOrd { + #[inline(always)] + fn ge(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::ge(lhs, rhs); + } + + #[inline(always)] + fn gt(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::gt(lhs, rhs); + } + + #[inline(always)] + fn le(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::le(lhs, rhs); + } + + #[inline(always)] + fn lt(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::lt(lhs, rhs); + } +} + +impl FP8x23WNeg of Neg { + #[inline(always)] + fn neg(a: FP8x23W) -> FP8x23W { + return core::neg(a); + } +} + +impl FP8x23WRem of Rem { + #[inline(always)] + fn rem(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::rem(lhs, rhs); + } +} + +/// 
INTERNAL + +fn _i32_into_fp(x: FP8x23W) -> i32 { + i32 { mag: (x.mag / ONE).try_into().unwrap(), sign: x.sign } +} + +fn _i8_try_from_fp(x: FP8x23W) -> Option { + let unscaled_mag: Option = (x.mag / ONE).try_into(); + + match unscaled_mag { + Option::Some(val) => Option::Some(i8 { mag: unscaled_mag.unwrap(), sign: x.sign }), + Option::None(_) => Option::None(()) + } +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo new file mode 100644 index 000000000..a627803be --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo @@ -0,0 +1,39 @@ +use debug::PrintTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, TWO, FP8x23W, FP8x23WSub, FP8x23WDiv, FixedTrait, FP8x23WPrint +}; + +const DEFAULT_PRECISION: u64 = 8; // 1e-6 + +// To use `DEFAULT_PRECISION`, final arg is: `Option::None(())`. +// To use `custom_precision` of 430_u64: `Option::Some(430_u64)`. +fn assert_precise(result: FP8x23W, expected: felt252, msg: felt252, custom_precision: Option) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = (result - FixedTrait::from_felt(expected)).mag; + + if (diff > precision) { + result.print(); + assert(diff <= precision, msg); + } +} + +fn assert_relative(result: FP8x23W, expected: felt252, msg: felt252, custom_precision: Option) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = result - FixedTrait::from_felt(expected); + let rel_diff = (diff / result).mag; + + if (rel_diff > precision) { + result.print(); + assert(rel_diff <= precision, msg); + } +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo new file mode 100644 index 000000000..970c65f30 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo @@ -0,0 +1,5 @@ +mod core; +mod comp; +mod lut; +mod trig; +mod hyp; diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo new file mode 100644 index 000000000..95b329109 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo @@ -0,0 +1,76 @@ +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + FP8x23W, FixedTrait, FP8x23WPartialOrd, FP8x23WPartialEq +}; + +fn max(a: FP8x23W, b: FP8x23W) -> FP8x23W { + if (a >= b) { + return a; + } else { + return b; + } +} + +fn min(a: FP8x23W, b: FP8x23W) -> FP8x23W { + if (a <= b) { + return a; + } else { + return b; + } +} + +fn xor(a: FP8x23W, b: FP8x23W) -> bool { + if (a == FixedTrait::new(0, false) || b == FixedTrait::new(0, false)) && (a != b) { + return true; + } else { + return false; + } +} + +fn or(a: FP8x23W, b: FP8x23W) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero && b == zero { + return false; + } else { + return true; + } +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +#[test] +fn test_max() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::new_unscaled(1, true); + + assert(max(a, a) == a, 'max(a, a)'); + assert(max(a, b) == a, 'max(a, b)'); + assert(max(a, c) == a, 'max(a, c)'); + + 
assert(max(b, a) == a, 'max(b, a)'); + assert(max(b, b) == b, 'max(b, b)'); + assert(max(b, c) == b, 'max(b, c)'); + + assert(max(c, a) == a, 'max(c, a)'); + assert(max(c, b) == b, 'max(c, b)'); + assert(max(c, c) == c, 'max(c, c)'); +} + +#[test] +fn test_min() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::new_unscaled(1, true); + + assert(min(a, a) == a, 'min(a, a)'); + assert(min(a, b) == b, 'min(a, b)'); + assert(min(a, c) == c, 'min(a, c)'); + + assert(min(b, a) == b, 'min(b, a)'); + assert(min(b, b) == b, 'min(b, b)'); + assert(min(b, c) == c, 'min(b, c)'); + + assert(min(c, a) == c, 'min(c, a)'); + assert(min(c, b) == c, 'min(c, b)'); + assert(min(c, c) == c, 'min(c, c)'); +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo new file mode 100644 index 000000000..129ff02c8 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo @@ -0,0 +1,660 @@ +use core::debug::PrintTrait; +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{Into, TryInto}; +use integer::{u64_safe_divmod, u64_as_non_zero, u64_wide_mul}; + +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, MAX, FP8x23W, FP8x23WAdd, FP8x23WImpl, FP8x23WAddEq, FP8x23WSub, FP8x23WMul, + FP8x23WMulEq, FP8x23WTryIntoU128, FP8x23WPartialEq, FP8x23WPartialOrd, FP8x23WSubEq, FP8x23WNeg, + FP8x23WDiv, FP8x23WIntoFelt252, FixedTrait +}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::lut; + +// PUBLIC + +fn abs(a: FP8x23W) -> FP8x23W { + return FixedTrait::new(a.mag, false); +} + +fn add(a: FP8x23W, b: FP8x23W) -> FP8x23W { + if a.sign == b.sign { + return FixedTrait::new(a.mag + b.mag, a.sign); + } + + if a.mag == b.mag { + return FixedTrait::ZERO(); + } + + if (a.mag > b.mag) { + return FixedTrait::new(a.mag - b.mag, a.sign); + } else { + return FixedTrait::new(b.mag - a.mag, b.sign); + } +} + +fn ceil(a: FP8x23W) -> FP8x23W { + let (div, rem) = u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if rem == 0 { + return a; + } else if !a.sign { + return FixedTrait::new_unscaled(div + 1, false); + } else if div == 0 { + return FixedTrait::new_unscaled(0, false); + } else { + return FixedTrait::new_unscaled(div, true); + } +} + +fn div(a: FP8x23W, b: FP8x23W) -> FP8x23W { + let a_u64 = integer::u64_wide_mul(a.mag, ONE); + let res_u64 = a_u64 / b.mag.into(); + + // Re-apply sign + return FixedTrait::new(res_u64.try_into().unwrap(), a.sign ^ b.sign); +} + +fn eq(a: @FP8x23W, b: @FP8x23W) -> bool { + return (*a.mag == *b.mag) && (*a.sign == *b.sign); +} + +// Calculates the natural exponent of x: e^x +fn exp(a: FP8x23W) -> FP8x23W { + return exp2(FixedTrait::new(12102203, false) * a); // log2(e) * 2^23 ≈ 12102203 +} + +// Calculates the binary exponent of x: 2^x +fn exp2(a: FP8x23W) -> FP8x23W { + if (a.mag == 0) { + return FixedTrait::ONE(); + } + + let (int_part, frac_part) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + let int_res = FixedTrait::new_unscaled(lut::exp2(int_part), false); + let mut res_u = int_res; + + if frac_part != 0 { + let frac = FixedTrait::new(frac_part, false); + let r8 = FixedTrait::new(19, false) * frac; + let r7 = (r8 + FixedTrait::new(105, false)) * frac; + let r6 = (r7 + FixedTrait::new(1324, false)) * frac; + let r5 = (r6 + FixedTrait::new(11159, false)) * frac; + let r4 = (r5 + FixedTrait::new(80695, false)) * frac; + let r3 
= (r4 + FixedTrait::new(465599, false)) * frac;
+        let r2 = (r3 + FixedTrait::new(2015166, false)) * frac;
+        let r1 = (r2 + FixedTrait::new(5814540, false)) * frac;
+        res_u = res_u * (r1 + FixedTrait::ONE());
+    }
+
+    if (a.sign == true) {
+        return FixedTrait::ONE() / res_u;
+    } else {
+        return res_u;
+    }
+}
+
+fn exp2_int(exp: u64) -> FP8x23W {
+    return FixedTrait::new_unscaled(lut::exp2(exp), false);
+}
+
+fn floor(a: FP8x23W) -> FP8x23W {
+    let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+
+    if rem == 0 {
+        return a;
+    } else if !a.sign {
+        return FixedTrait::new_unscaled(div, false);
+    } else {
+        return FixedTrait::new_unscaled(div + 1, true);
+    }
+}
+
+fn ge(a: FP8x23W, b: FP8x23W) -> bool {
+    if a.sign != b.sign {
+        return !a.sign;
+    } else {
+        return (a.mag == b.mag) || ((a.mag > b.mag) ^ a.sign);
+    }
+}
+
+fn gt(a: FP8x23W, b: FP8x23W) -> bool {
+    if a.sign != b.sign {
+        return !a.sign;
+    } else {
+        return (a.mag != b.mag) && ((a.mag > b.mag) ^ a.sign);
+    }
+}
+
+fn le(a: FP8x23W, b: FP8x23W) -> bool {
+    if a.sign != b.sign {
+        return a.sign;
+    } else {
+        return (a.mag == b.mag) || ((a.mag < b.mag) ^ a.sign);
+    }
+}
+
+// Calculates the natural logarithm of x: ln(x)
+// self must be greater than zero
+fn ln(a: FP8x23W) -> FP8x23W {
+    return FixedTrait::new(5814540, false) * log2(a); // ln(2) = 0.693...
+}
+
+// Calculates the binary logarithm of x: log2(x)
+// self must be greater than zero
+fn log2(a: FP8x23W) -> FP8x23W {
+    assert(a.sign == false, 'must be positive');
+
+    if (a.mag == ONE) {
+        return FixedTrait::ZERO();
+    } else if (a.mag < ONE) {
+        // Compute true inverse binary log if 0 < x < 1
+        let div = FixedTrait::ONE() / a;
+        return -log2(div);
+    }
+
+    let whole = a.mag / ONE;
+    let (msb, div) = lut::msb(whole);
+
+    if a.mag == div * ONE {
+        return FixedTrait::new_unscaled(msb, false);
+    } else {
+        let norm = a / FixedTrait::new_unscaled(div, false);
+        let r8 = FixedTrait::new(76243, true) * norm;
+        let r7 = (r8 + FixedTrait::new(1038893, false)) * norm;
+        let r6 = (r7 + FixedTrait::new(6277679, true)) * norm;
+        let r5 = (r6 + FixedTrait::new(22135645, false)) * norm;
+        let r4 = (r5 + FixedTrait::new(50444339, true)) * norm;
+        let r3 = (r4 + FixedTrait::new(77896489, false)) * norm;
+        let r2 = (r3 + FixedTrait::new(83945943, true)) * norm;
+        let r1 = (r2 + FixedTrait::new(68407458, false)) * norm;
+        return r1 + FixedTrait::new(28734280, true) + FixedTrait::new_unscaled(msb, false);
+    }
+}
+
+// Calculates the base 10 log of x: log10(x)
+// self must be greater than zero
+fn log10(a: FP8x23W) -> FP8x23W {
+    return FixedTrait::new(2525223, false) * log2(a); // log10(2) = 0.301...
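+    // log10(x) = log10(2) * log2(x), and 2525223 / 2^23 ≈ 0.30103 = log10(2).
+    // Worked check: log10(100) = 0.30103 * log2(100) ≈ 0.30103 * 6.6439 ≈ 2,
+    // which is what test_log10 below asserts via assert_relative.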
+} + +fn lt(a: FP8x23W, b: FP8x23W) -> bool { + if a.sign != b.sign { + return a.sign; + } else { + return (a.mag != b.mag) && ((a.mag < b.mag) ^ a.sign); + } +} + +fn mul(a: FP8x23W, b: FP8x23W) -> FP8x23W { + let prod_u128 = integer::u64_wide_mul(a.mag, b.mag); + + // Re-apply sign + return FixedTrait::new((prod_u128 / ONE.into()).try_into().unwrap(), a.sign ^ b.sign); +} + +fn ne(a: @FP8x23W, b: @FP8x23W) -> bool { + return (*a.mag != *b.mag) || (*a.sign != *b.sign); +} + +fn neg(a: FP8x23W) -> FP8x23W { + if a.mag == 0 { + return a; + } else if !a.sign { + return FixedTrait::new(a.mag, !a.sign); + } else { + return FixedTrait::new(a.mag, false); + } +} + +// Calclates the value of x^y and checks for overflow before returning +// self is a FP8x23W point value +// b is a FP8x23W point value +fn pow(a: FP8x23W, b: FP8x23W) -> FP8x23W { + let (div, rem) = integer::u64_safe_divmod(b.mag, u64_as_non_zero(ONE)); + + // use the more performant integer pow when y is an int + if (rem == 0) { + return pow_int(a, b.mag / ONE, b.sign); + } + + // x^y = exp(y*ln(x)) for x > 0 will error for x < 0 + return exp(b * ln(a)); +} + +// Calclates the value of a^b and checks for overflow before returning +fn pow_int(a: FP8x23W, b: u64, sign: bool) -> FP8x23W { + let mut x = a; + let mut n = b; + + if sign == true { + x = FixedTrait::ONE() / x; + } + + if n == 0 { + return FixedTrait::ONE(); + } + + let mut y = FixedTrait::ONE(); + let two = integer::u64_as_non_zero(2); + + loop { + if n <= 1 { + break; + } + + let (div, rem) = integer::u64_safe_divmod(n, two); + + if rem == 1 { + y = x * y; + } + + x = x * x; + n = div; + }; + + return x * y; +} + +fn rem(a: FP8x23W, b: FP8x23W) -> FP8x23W { + return a - floor(a / b) * b; +} + +fn round(a: FP8x23W) -> FP8x23W { + let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if (HALF <= rem) { + return FixedTrait::new_unscaled(div + 1, a.sign); + } else { + return FixedTrait::new_unscaled(div, a.sign); + } +} + +// Calculates the square root of a FP8x23W point value +// x must be positive +fn sqrt(a: FP8x23W) -> FP8x23W { + assert(a.sign == false, 'must be positive'); + + let root = integer::u64_sqrt(a.mag.into() * ONE.into()); + return FixedTrait::new(root.into(), false); +} + +fn sub(a: FP8x23W, b: FP8x23W) -> FP8x23W { + return add(a, -b); +} + +fn sign(a: FP8x23W) -> FP8x23W { + if a.mag == 0 { + FixedTrait::new(0, false) + } else { + FixedTrait::new(ONE, a.sign) + } +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use orion::numbers::fixed_point::implementations::fp8x23wide::helpers::{ + assert_precise, assert_relative +}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::trig::{PI, HALF_PI}; + +#[test] +fn test_into() { + let a = FixedTrait::::new_unscaled(5, false); + assert(a.mag == 5 * ONE, 'invalid result'); +} + +#[test] +fn test_try_into_u128() { + // Positive unscaled + let a = FixedTrait::::new_unscaled(5, false); + assert(a.try_into().unwrap() == 5_u128, 'invalid result'); + + // Positive scaled + let b = FixedTrait::::new(5 * ONE, false); + assert(b.try_into().unwrap() == 5_u128, 'invalid result'); + + // Zero + let d = FixedTrait::::new_unscaled(0, false); + assert(d.try_into().unwrap() == 0_u128, 'invalid result'); +} + +#[test] +#[should_panic] +fn test_negative_try_into_u128() { + let a = FixedTrait::::new_unscaled(1, true); + let a: u128 = a.try_into().unwrap(); +} + +#[test] +#[available_gas(1000000)] +fn test_acos() { + let 
a = FixedTrait::::ONE(); + assert(a.acos().into() == 0, 'invalid one'); +} + +#[test] +#[available_gas(1000000)] +fn test_asin() { + let a = FixedTrait::ONE(); + assert_precise(a.asin(), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 +} + +#[test] +#[available_gas(2000000)] +fn test_atan() { + let a = FixedTrait::new(2 * ONE, false); + assert_relative(a.atan(), 9287469, 'invalid two', Option::None(())); +} + +#[test] +fn test_ceil() { + let a = FixedTrait::new(24326963, false); // 2.9 + assert(ceil(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_floor() { + let a = FixedTrait::new(24326963, false); // 2.9 + assert(floor(a).mag == 2 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_round() { + let a = FixedTrait::new(24326963, false); // 2.9 + assert(round(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +#[should_panic] +fn test_sqrt_fail() { + let a = FixedTrait::new_unscaled(25, true); + sqrt(a); +} + +#[test] +fn test_sqrt() { + let mut a = FixedTrait::new_unscaled(0, false); + assert(sqrt(a).mag == 0, 'invalid zero root'); + a = FixedTrait::new_unscaled(25, false); + assert(sqrt(a).mag == 5 * ONE, 'invalid pos root'); +} + + +#[test] +#[available_gas(100000)] +fn test_msb() { + let a = FixedTrait::::new_unscaled(100, false); + let (msb, div) = lut::msb(a.mag / ONE); + assert(msb == 6, 'invalid msb'); + assert(div == 64, 'invalid msb ceil'); +} + +#[test] +#[available_gas(600000)] +fn test_pow() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new_unscaled(4, false); + assert(pow(a, b).mag == 81 * ONE, 'invalid pos base power'); +} + +#[test] +#[available_gas(900000)] +fn test_pow_frac() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new(4194304, false); // 0.5 + assert_relative( + pow(a, b), 14529495, 'invalid pos base power', Option::None(()) + ); // 1.7320508075688772 +} + +#[test] +#[available_gas(1000000)] +fn test_exp() { + let a = FixedTrait::new_unscaled(2, false); + assert_relative(exp(a), 61983895, 'invalid exp of 2', Option::None(())); // 7.389056098793725 +} + +#[test] +#[available_gas(400000)] +fn test_exp2() { + let a = FixedTrait::new_unscaled(5, false); + assert(exp2(a).mag == 268435456, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(20000)] +fn test_exp2_int() { + assert(exp2_int(5).into() == 268435456, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(1000000)] +fn test_ln() { + let mut a = FixedTrait::new_unscaled(1, false); + assert(ln(a).mag == 0, 'invalid ln of 1'); + + a = FixedTrait::new(22802601, false); + assert_relative(ln(a), ONE.into(), 'invalid ln of 2.7...', Option::None(())); +} + +#[test] +#[available_gas(1000000)] +fn test_log2() { + let mut a = FixedTrait::new_unscaled(32, false); + assert(log2(a) == FixedTrait::new_unscaled(5, false), 'invalid log2 32'); + + a = FixedTrait::new_unscaled(10, false); + assert_relative(log2(a), 27866353, 'invalid log2 10', Option::None(())); // 3.321928094887362 +} + +#[test] +#[available_gas(1000000)] +fn test_log10() { + let a = FixedTrait::new_unscaled(100, false); + assert_relative(log10(a), 2 * ONE.into(), 'invalid log10', Option::None(())); +} + +#[test] +fn test_eq() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = eq(@a, @b); + assert(c == true, 'invalid result'); +} + +#[test] +fn test_ne() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = ne(@a, @b); + assert(c == false, 'invalid result'); +} + +#[test] 
+fn test_add() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + assert(add(a, b) == FixedTrait::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_add_eq() { + let mut a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + a += b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_sub() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + let c = a - b; + assert(c == FixedTrait::::new_unscaled(3, false), 'false result invalid'); +} + +#[test] +fn test_sub_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + a -= b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +#[available_gas(100000)] +fn test_mul_pos() { + let a = FP8x23W { mag: 24326963, sign: false }; + let b = FP8x23W { mag: 24326963, sign: false }; + let c = a * b; + assert(c.mag == 70548192, 'invalid result'); +} + +#[test] +fn test_mul_neg() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + let c = a * b; + assert(c == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_mul_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + a *= b; + assert(a == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_div() { + let a = FixedTrait::new_unscaled(10, false); + let b = FixedTrait::::new(24326963, false); // 2.9 + let c = a / b; + assert(c.mag == 28926234, 'invalid pos decimal'); // 3.4482758620689653 +} + +#[test] +fn test_le() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a <= a, 'a <= a'); + assert(a <= b == false, 'a <= b'); + assert(a <= c == false, 'a <= c'); + + assert(b <= a, 'b <= a'); + assert(b <= b, 'b <= b'); + assert(b <= c == false, 'b <= c'); + + assert(c <= a, 'c <= a'); + assert(c <= b, 'c <= b'); + assert(c <= c, 'c <= c'); +} + +#[test] +fn test_lt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a < a == false, 'a < a'); + assert(a < b == false, 'a < b'); + assert(a < c == false, 'a < c'); + + assert(b < a, 'b < a'); + assert(b < b == false, 'b < b'); + assert(b < c == false, 'b < c'); + + assert(c < a, 'c < a'); + assert(c < b, 'c < b'); + assert(c < c == false, 'c < c'); +} + +#[test] +fn test_ge() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a >= a, 'a >= a'); + assert(a >= b, 'a >= b'); + assert(a >= c, 'a >= c'); + + assert(b >= a == false, 'b >= a'); + assert(b >= b, 'b >= b'); + assert(b >= c, 'b >= c'); + + assert(c >= a == false, 'c >= a'); + assert(c >= b == false, 'c >= b'); + assert(c >= c, 'c >= c'); +} + +#[test] +fn test_gt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a > a == false, 'a > a'); + assert(a > b, 'a > b'); + assert(a > c, 'a > c'); + + assert(b > a == false, 'b > a'); + assert(b > b == false, 'b > b'); + assert(b > c, 'b > c'); + + assert(c > a == false, 'c > a'); + assert(c > b == false, 'c > b'); + assert(c > c == false, 'c > c'); +} + +#[test] 
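+// Trig smoke tests: exhaustive sin/cos/tan coverage (including the _fast LUT
+// variants) lives in the math/trig.cairo hunk later in this patch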
+#[available_gas(1000000)] +fn test_cos() { + let a = FixedTrait::::new(HALF_PI, false); + assert(a.cos().into() == 0, 'invalid half pi'); +} + +#[test] +#[available_gas(1000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(a.sin(), ONE.into(), 'invalid half pi', Option::None(())); +} + +#[test] +#[available_gas(2000000)] +fn test_tan() { + let a = FixedTrait::::new(HALF_PI / 2, false); + assert(a.tan().mag == 8388608, 'invalid quarter pi'); +} + +#[test] +#[available_gas(2000000)] +fn test_sign() { + let a = FixedTrait::::new(0, false); + assert(a.sign().mag == 0 && !a.sign().sign, 'invalid sign (0, true)'); + + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (HALF, true)'); + + let a = FixedTrait::::new(HALF, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (HALF, false)'); + + let a = FixedTrait::::new(ONE, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (ONE, true)'); + + let a = FixedTrait::::new(ONE, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (ONE, false)'); +} + +#[test] +#[should_panic] +#[available_gas(2000000)] +fn test_sign_fail() { + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag != ONE && !a.sign().sign, 'invalid sign (HALF, true)'); +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo new file mode 100644 index 000000000..ed9b66391 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo @@ -0,0 +1,159 @@ +use core::debug::PrintTrait; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, TWO, FP8x23W, FP8x23WImpl, FP8x23WAdd, FP8x23WAddEq, FP8x23WSub, FP8x23WMul, FP8x23WMulEq, + FP8x23WTryIntoU128, FP8x23WPartialEq, FP8x23WPartialOrd, FP8x23WSubEq, FP8x23WNeg, FP8x23WDiv, + FP8x23WIntoFelt252, FixedTrait +}; + +// Calculates hyperbolic cosine of a (fixed point) +fn cosh(a: FP8x23W) -> FP8x23W { + let ea = a.exp(); + return (ea + (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic sine of a (fixed point) +fn sinh(a: FP8x23W) -> FP8x23W { + let ea = a.exp(); + return (ea - (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic tangent of a (fixed point) +fn tanh(a: FP8x23W) -> FP8x23W { + let ea = a.exp(); + let ea_i = FixedTrait::ONE() / ea; + return (ea - ea_i) / (ea + ea_i); +} + +// Calculates inverse hyperbolic cosine of a (fixed point) +fn acosh(a: FP8x23W) -> FP8x23W { + let root = (a * a - FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic sine of a (fixed point) +fn asinh(a: FP8x23W) -> FP8x23W { + let root = (a * a + FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic tangent of a (fixed point) +fn atanh(a: FP8x23W) -> FP8x23W { + let one = FixedTrait::ONE(); + let ln_arg = (one + a) / (one - a); + return ln_arg.ln() / FixedTrait::new(TWO, false); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use option::OptionTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp8x23wide::helpers::assert_precise; + +#[test] +#[available_gas(10000000)] +fn test_cosh() { + let a = FixedTrait::new(TWO, false); + assert_precise(cosh(a), 31559585, 'invalid two', Option::None(())); // 3.762195691016423 + + let a = 
FixedTrait::ONE();
+    assert_precise(cosh(a), 12944299, 'invalid one', Option::None(())); // 1.5430806347841253
+
+    let a = FixedTrait::ZERO();
+    assert_precise(cosh(a), ONE.into(), 'invalid zero', Option::None(()));
+
+    let a = FixedTrait::new(ONE, true);
+    assert_precise(cosh(a), 12944299, 'invalid neg one', Option::None(())); // 1.5430806347841253
+
+    let a = FixedTrait::new(TWO, true);
+    assert_precise(cosh(a), 31559602, 'invalid neg two', Option::None(())); // 3.762195691016423
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_sinh() {
+    let a = FixedTrait::new(TWO, false);
+    assert_precise(sinh(a), 30424310, 'invalid two', Option::None(())); // 3.6268604077773023
+
+    let a = FixedTrait::ONE();
+    assert_precise(sinh(a), 9858302, 'invalid one', Option::None(())); // 1.1752011936029418
+
+    let a = FixedTrait::ZERO();
+    assert(sinh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE, true);
+    assert_precise(sinh(a), -9858302, 'invalid neg one', Option::None(())); // -1.1752011936029418
+
+    let a = FixedTrait::new(TWO, true);
+    assert_precise(sinh(a), -30424328, 'invalid neg two', Option::None(())); // -3.6268604077773023
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_tanh() {
+    let a = FixedTrait::new(TWO, false);
+    assert_precise(tanh(a), 8086849, 'invalid two', Option::None(())); // 0.9640275800745076
+
+    let a = FixedTrait::ONE();
+    assert_precise(tanh(a), 6388715, 'invalid one', Option::None(())); // 0.7615941559446443
+
+    let a = FixedTrait::ZERO();
+    assert(tanh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE, true);
+    assert_precise(tanh(a), -6388715, 'invalid neg one', Option::None(())); // -0.7615941559446443
+
+    let a = FixedTrait::new(TWO, true);
+    assert_precise(tanh(a), -8086849, 'invalid neg two', Option::None(())); // -0.9640275800745076
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_acosh() {
+    let a = FixedTrait::new(31559585, false); // 3.762195691016423
+    assert_precise(acosh(a), 16777257, 'invalid two', Option::None(()));
+
+    let a = FixedTrait::new(12944299, false); // 1.5430806347841253
+    assert_precise(acosh(a), ONE.into(), 'invalid one', Option::None(()));
+
+    let a = FixedTrait::ONE(); // 1
+    assert(acosh(a).into() == 0, 'invalid zero');
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_asinh() {
+    let a = FixedTrait::new(30424310, false); // 3.6268604077773023
+    assert_precise(asinh(a), 16777257, 'invalid two', Option::None(()));
+
+    let a = FixedTrait::new(9858302, false); // 1.1752011936029418
+    assert_precise(asinh(a), ONE.into(), 'invalid one', Option::None(()));
+
+    let a = FixedTrait::ZERO();
+    assert(asinh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(9858302, true); // -1.1752011936029418
+    assert_precise(asinh(a), -ONE.into(), 'invalid neg one', Option::None(()));
+
+    let a = FixedTrait::new(30424310, true); // -3.6268604077773023
+    assert_precise(asinh(a), -16777238, 'invalid neg two', Option::None(()));
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_atanh() {
+    let a = FixedTrait::new(7549747, false); // 0.9
+    assert_precise(atanh(a), 12349872, 'invalid 0.9', Option::None(())); // 1.4722194895832204
+
+    let a = FixedTrait::new(HALF, false); // 0.5
+    assert_precise(atanh(a), 4607914, 'invalid half', Option::None(())); // 0.5493061443340548
+
+    let a = FixedTrait::ZERO();
+    assert(atanh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(HALF, true); // -0.5
+    assert_precise(atanh(a), -4607914, 'invalid neg half', Option::None(())); // -0.5493061443340548
+
+    let a = FixedTrait::new(7549747, true); 
// 0.9 + assert_precise(atanh(a), -12349872, 'invalid -0.9', Option::None(())); // 1.4722194895832204 +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo new file mode 100644 index 000000000..157499b5b --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo @@ -0,0 +1,1229 @@ +// Calculates the most significant bit +fn msb(whole: u64) -> (u64, u64) { + if whole < 256 { + if whole < 2 { + return (0, 1); + } + if whole < 4 { + return (1, 2); + } + if whole < 8 { + return (2, 4); + } + if whole < 16 { + return (3, 8); + } + if whole < 32 { + return (4, 16); + } + if whole < 64 { + return (5, 32); + } + if whole < 128 { + return (6, 64); + } + if whole < 256 { + return (7, 128); + } + } + + return (8, 256); +} + +fn exp2(exp: u64) -> u64 { + if exp <= 16 { + if exp == 0 { + return 1; + } + if exp == 1 { + return 2; + } + if exp == 2 { + return 4; + } + if exp == 3 { + return 8; + } + if exp == 4 { + return 16; + } + if exp == 5 { + return 32; + } + if exp == 6 { + return 64; + } + if exp == 7 { + return 128; + } + if exp == 8 { + return 256; + } + if exp == 9 { + return 512; + } + if exp == 10 { + return 1024; + } + if exp == 11 { + return 2048; + } + if exp == 12 { + return 4096; + } + if exp == 13 { + return 8192; + } + if exp == 14 { + return 16384; + } + if exp == 15 { + return 32768; + } + if exp == 16 { + return 65536; + } + } else if exp <= 32 { + if exp == 17 { + return 131072; + } + if exp == 18 { + return 262144; + } + if exp == 19 { + return 524288; + } + if exp == 20 { + return 1048576; + } + if exp == 21 { + return 2097152; + } + if exp == 22 { + return 4194304; + } + } + + return 8388608; +} + +fn sin(a: u64) -> (u64, u64, u64) { + let slot = a / 51472; + + if slot < 128 { + if slot < 64 { + if slot < 32 { + if slot < 16 { + if slot == 0 { + return (0, 0, 51472); + } + if slot == 1 { + return (51472, 51472, 102941); + } + if slot == 2 { + return (102944, 102941, 154407); + } + if slot == 3 { + return (154416, 154407, 205867); + } + if slot == 4 { + return (205887, 205867, 257319); + } + if slot == 5 { + return (257359, 257319, 308761); + } + if slot == 6 { + return (308831, 308761, 360192); + } + if slot == 7 { + return (360303, 360192, 411609); + } + if slot == 8 { + return (411775, 411609, 463011); + } + if slot == 9 { + return (463247, 463011, 514396); + } + if slot == 10 { + return (514723, 514396, 565761); + } + if slot == 11 { + return (566190, 565761, 617104); + } + if slot == 12 { + return (617662, 617104, 668425); + } + if slot == 13 { + return (669134, 668425, 719720); + } + if slot == 14 { + return (720606, 719720, 770988); + } + if slot == 15 { + return (772078, 770988, 822227); + } + } else { + if slot == 16 { + return (823550, 822227, 873436); + } + if slot == 17 { + return (875022, 873436, 924611); + } + if slot == 18 { + return (926493, 924611, 975751); + } + if slot == 19 { + return (977965, 975751, 1026855); + } + if slot == 20 { + return (1029437, 1026855, 1077920); + } + if slot == 21 { + return (1080909, 1077920, 1128945); + } + if slot == 22 { + return (1132381, 1128945, 1179927); + } + if slot == 23 { + return (1183853, 1179927, 1230864); + } + if slot == 24 { + return (1235324, 1230864, 1281756); + } + if slot == 25 { + return (1286796, 1281756, 1332599); + } + if slot == 26 { + return (1338268, 1332599, 1383392); + } + if slot == 27 { + return (1389740, 1383392, 1434132); + } + if slot == 28 { + return (1441212, 1434132, 
1484819); + } + if slot == 29 { + return (1492684, 1484819, 1535450); + } + if slot == 30 { + return (1544156, 1535450, 1586023); + } + if slot == 31 { + return (1595627, 1586023, 1636536); + } + } + } else { + if slot < 48 { + if slot == 32 { + return (1647099, 1636536, 1686988); + } + if slot == 33 { + return (1698571, 1686988, 1737376); + } + if slot == 34 { + return (1750043, 1737376, 1787699); + } + if slot == 35 { + return (1801515, 1787699, 1837954); + } + if slot == 36 { + return (1852987, 1837954, 1888141); + } + if slot == 37 { + return (1904459, 1888141, 1938256); + } + if slot == 38 { + return (1955930, 1938256, 1988298); + } + if slot == 39 { + return (2007402, 1988298, 2038265); + } + if slot == 40 { + return (2058871, 2038265, 2088156); + } + if slot == 41 { + return (2110346, 2088156, 2137968); + } + if slot == 42 { + return (2161818, 2137968, 2187700); + } + if slot == 43 { + return (2213290, 2187700, 2237349); + } + if slot == 44 { + return (2264762, 2237349, 2286914); + } + if slot == 45 { + return (2316233, 2286914, 2336392); + } + if slot == 46 { + return (2367705, 2336392, 2385783); + } + if slot == 47 { + return (2419177, 2385783, 2435084); + } + } else { + if slot == 48 { + return (2470649, 2435084, 2484294); + } + if slot == 49 { + return (2522121, 2484294, 2533410); + } + if slot == 50 { + return (2573593, 2533410, 2582430); + } + if slot == 51 { + return (2625065, 2582430, 2631353); + } + if slot == 52 { + return (2676536, 2631353, 2680177); + } + if slot == 53 { + return (2728008, 2680177, 2728901); + } + if slot == 54 { + return (2779480, 2728901, 2777521); + } + if slot == 55 { + return (2830952, 2777521, 2826037); + } + if slot == 56 { + return (2882424, 2826037, 2874446); + } + if slot == 57 { + return (2933896, 2874446, 2922748); + } + if slot == 58 { + return (2985368, 2922748, 2970939); + } + if slot == 59 { + return (3036839, 2970939, 3019018); + } + if slot == 60 { + return (3088311, 3019018, 3066984); + } + if slot == 61 { + return (3139783, 3066984, 3114834); + } + if slot == 62 { + return (3191255, 3114834, 3162567); + } + if slot == 63 { + return (3242727, 3162567, 3210181); + } + } + } + } else { + if slot < 96 { + if slot < 80 { + if slot == 64 { + return (3294199, 3210181, 3257674); + } + if slot == 65 { + return (3345671, 3257674, 3305045); + } + if slot == 66 { + return (3397142, 3305045, 3352291); + } + if slot == 67 { + return (3448614, 3352291, 3399411); + } + if slot == 68 { + return (3500086, 3399411, 3446402); + } + if slot == 69 { + return (3551558, 3446402, 3493264); + } + if slot == 70 { + return (3603030, 3493264, 3539995); + } + if slot == 71 { + return (3654502, 3539995, 3586592); + } + if slot == 72 { + return (3705973, 3586592, 3633054); + } + if slot == 73 { + return (3757445, 3633054, 3679380); + } + if slot == 74 { + return (3808917, 3679380, 3725567); + } + if slot == 75 { + return (3860389, 3725567, 3771613); + } + if slot == 76 { + return (3911861, 3771613, 3817518); + } + if slot == 77 { + return (3963333, 3817518, 3863279); + } + if slot == 78 { + return (4014805, 3863279, 3908894); + } + if slot == 79 { + return (4066276, 3908894, 3954362); + } + } else { + if slot == 80 { + return (4117751, 3954362, 3999682); + } + if slot == 81 { + return (4169220, 3999682, 4044851); + } + if slot == 82 { + return (4220692, 4044851, 4089867); + } + if slot == 83 { + return (4272164, 4089867, 4134730); + } + if slot == 84 { + return (4323636, 4134730, 4179437); + } + if slot == 85 { + return (4375108, 4179437, 4223986); + } + if slot == 
86 { + return (4426579, 4223986, 4268377); + } + if slot == 87 { + return (4478051, 4268377, 4312606); + } + if slot == 88 { + return (4529523, 4312606, 4356674); + } + if slot == 89 { + return (4580995, 4356674, 4400577); + } + if slot == 90 { + return (4632474, 4400577, 4444315); + } + if slot == 91 { + return (4683939, 4444315, 4487885); + } + if slot == 92 { + return (4735411, 4487885, 4531287); + } + if slot == 93 { + return (4786882, 4531287, 4574518); + } + if slot == 94 { + return (4838354, 4574518, 4617576); + } + if slot == 95 { + return (4889826, 4617576, 4660461); + } + } + } else { + if slot < 112 { + if slot == 96 { + return (4941298, 4660461, 4703170); + } + if slot == 97 { + return (4992770, 4703170, 4745702); + } + if slot == 98 { + return (5044242, 4745702, 4788056); + } + if slot == 99 { + return (5095714, 4788056, 4830229); + } + if slot == 100 { + return (5147227, 4830229, 4872221); + } + if slot == 101 { + return (5198657, 4872221, 4914029); + } + if slot == 102 { + return (5250129, 4914029, 4955652); + } + if slot == 103 { + return (5301601, 4955652, 4997088); + } + if slot == 104 { + return (5353073, 4997088, 5038336); + } + if slot == 105 { + return (5404545, 5038336, 5079395); + } + if slot == 106 { + return (5456017, 5079395, 5120262); + } + if slot == 107 { + return (5507488, 5120262, 5160937); + } + if slot == 108 { + return (5558960, 5160937, 5201417); + } + if slot == 109 { + return (5610432, 5201417, 5241701); + } + if slot == 110 { + return (5661904, 5241701, 5281788); + } + if slot == 111 { + return (5713376, 5281788, 5321677); + } + } else { + if slot == 112 { + return (5764848, 5321677, 5361364); + } + if slot == 113 { + return (5816320, 5361364, 5400850); + } + if slot == 114 { + return (5867791, 5400850, 5440133); + } + if slot == 115 { + return (5919263, 5440133, 5479211); + } + if slot == 116 { + return (5970735, 5479211, 5518082); + } + if slot == 117 { + return (6022207, 5518082, 5556746); + } + if slot == 118 { + return (6073679, 5556746, 5595201); + } + if slot == 119 { + return (6125151, 5595201, 5633445); + } + if slot == 120 { + return (6176622, 5633445, 5671477); + } + if slot == 121 { + return (6228094, 5671477, 5709295); + } + if slot == 122 { + return (6279566, 5709295, 5746898); + } + if slot == 123 { + return (6331038, 5746898, 5784285); + } + if slot == 124 { + return (6382510, 5784285, 5821455); + } + if slot == 125 { + return (6433982, 5821455, 5858405); + } + if slot == 126 { + return (6485454, 5858405, 5895134); + } + if slot == 127 { + return (6536925, 5895134, 5931642); + } + } + } + } + } else { + if slot < 192 { + if slot < 160 { + if slot < 144 { + if slot == 128 { + return (6588397, 5931642, 5967926); + } + if slot == 129 { + return (6639869, 5967926, 6003985); + } + if slot == 130 { + return (6691345, 6003985, 6039819); + } + if slot == 131 { + return (6742813, 6039819, 6075425); + } + if slot == 132 { + return (6794285, 6075425, 6110802); + } + if slot == 133 { + return (6845757, 6110802, 6145949); + } + if slot == 134 { + return (6897228, 6145949, 6180865); + } + if slot == 135 { + return (6948700, 6180865, 6215549); + } + if slot == 136 { + return (7000172, 6215549, 6249998); + } + if slot == 137 { + return (7051644, 6249998, 6284212); + } + if slot == 138 { + return (7103116, 6284212, 6318189); + } + if slot == 139 { + return (7154588, 6318189, 6351928); + } + if slot == 140 { + return (7206060, 6351928, 6385428); + } + if slot == 141 { + return (7257531, 6385428, 6418688); + } + if slot == 142 { + return (7309003, 
6418688, 6451706); + } + if slot == 143 { + return (7360475, 6451706, 6484482); + } + } else { + if slot == 144 { + return (7411947, 6484482, 6517013); + } + if slot == 145 { + return (7463419, 6517013, 6549299); + } + if slot == 146 { + return (7514891, 6549299, 6581338); + } + if slot == 147 { + return (7566363, 6581338, 6613129); + } + if slot == 148 { + return (7617834, 6613129, 6644672); + } + if slot == 149 { + return (7669306, 6644672, 6675964); + } + if slot == 150 { + return (7720780, 6675964, 6707005); + } + if slot == 151 { + return (7772250, 6707005, 6737793); + } + if slot == 152 { + return (7823722, 6737793, 6768328); + } + if slot == 153 { + return (7875194, 6768328, 6798608); + } + if slot == 154 { + return (7926666, 6798608, 6828632); + } + if slot == 155 { + return (7978137, 6828632, 6858399); + } + if slot == 156 { + return (8029609, 6858399, 6887907); + } + if slot == 157 { + return (8081081, 6887907, 6917156); + } + if slot == 158 { + return (8132553, 6917156, 6946145); + } + if slot == 159 { + return (8184025, 6946145, 6974873); + } + if slot == 160 { + return (8235503, 6974873, 7003337); + } + } + } else { + if slot < 176 { + if slot == 161 { + return (8286968, 7003337, 7031538); + } + if slot == 162 { + return (8338440, 7031538, 7059475); + } + if slot == 163 { + return (8389912, 7059475, 7087145); + } + if slot == 164 { + return (8441384, 7087145, 7114549); + } + if slot == 165 { + return (8492856, 7114549, 7141685); + } + if slot == 166 { + return (8544328, 7141685, 7168552); + } + if slot == 167 { + return (8595800, 7168552, 7195149); + } + if slot == 168 { + return (8647271, 7195149, 7221475); + } + if slot == 169 { + return (8698743, 7221475, 7247530); + } + if slot == 170 { + return (8750215, 7247530, 7273311); + } + if slot == 171 { + return (8801687, 7273311, 7298819); + } + if slot == 172 { + return (8853159, 7298819, 7324052); + } + if slot == 173 { + return (8904631, 7324052, 7349009); + } + if slot == 174 { + return (8956103, 7349009, 7373689); + } + if slot == 175 { + return (9007574, 7373689, 7398092); + } + } else { + if slot == 176 { + return (9059046, 7398092, 7422216); + } + if slot == 177 { + return (9110518, 7422216, 7446061); + } + if slot == 178 { + return (9161990, 7446061, 7469625); + } + if slot == 179 { + return (9213462, 7469625, 7492909); + } + if slot == 180 { + return (9264934, 7492909, 7515910); + } + if slot == 181 { + return (9316406, 7515910, 7538628); + } + if slot == 182 { + return (9367877, 7538628, 7561062); + } + if slot == 183 { + return (9419349, 7561062, 7583212); + } + if slot == 184 { + return (9470821, 7583212, 7605076); + } + if slot == 185 { + return (9522293, 7605076, 7626654); + } + if slot == 186 { + return (9573765, 7626654, 7647945); + } + if slot == 187 { + return (9625237, 7647945, 7668947); + } + if slot == 188 { + return (9676709, 7668947, 7689661); + } + if slot == 189 { + return (9728180, 7689661, 7710086); + } + if slot == 190 { + return (9779651, 7710086, 7730220); + } + if slot == 191 { + return (9831124, 7730220, 7750063); + } + } + } + } else { + if slot < 224 { + if slot < 208 { + if slot == 192 { + return (9882596, 7750063, 7769615); + } + if slot == 193 { + return (9934068, 7769615, 7788874); + } + if slot == 194 { + return (9985540, 7788874, 7807839); + } + if slot == 195 { + return (10037012, 7807839, 7826511); + } + if slot == 196 { + return (10088483, 7826511, 7844888); + } + if slot == 197 { + return (10139955, 7844888, 7862970); + } + if slot == 198 { + return (10191427, 7862970, 7880755); + } + 
if slot == 199 { + return (10242899, 7880755, 7898244); + } + if slot == 200 { + return (10294373, 7898244, 7915436); + } + if slot == 201 { + return (10345843, 7915436, 7932329); + } + if slot == 202 { + return (10397315, 7932329, 7948924); + } + if slot == 203 { + return (10448786, 7948924, 7965220); + } + if slot == 204 { + return (10500258, 7965220, 7981215); + } + if slot == 205 { + return (10551730, 7981215, 7996911); + } + if slot == 206 { + return (10603202, 7996911, 8012305); + } + if slot == 207 { + return (10654674, 8012305, 8027397); + } + } else { + if slot == 208 { + return (10706146, 8027397, 8042188); + } + if slot == 209 { + return (10757617, 8042188, 8056675); + } + if slot == 210 { + return (10809089, 8056675, 8070859); + } + if slot == 211 { + return (10860561, 8070859, 8084740); + } + if slot == 212 { + return (10912033, 8084740, 8098316); + } + if slot == 213 { + return (10963505, 8098316, 8111587); + } + if slot == 214 { + return (11014977, 8111587, 8124552); + } + if slot == 215 { + return (11066449, 8124552, 8137212); + } + if slot == 216 { + return (11117920, 8137212, 8149565); + } + if slot == 217 { + return (11169392, 8149565, 8161612); + } + if slot == 218 { + return (11220864, 8161612, 8173351); + } + if slot == 219 { + return (11272336, 8173351, 8184783); + } + if slot == 220 { + return (11323808, 8184783, 8195906); + } + if slot == 221 { + return (11375280, 8195906, 8206721); + } + if slot == 222 { + return (11426752, 8206721, 8217227); + } + if slot == 223 { + return (11478223, 8217227, 8227423); + } + } + } else { + if slot < 240 { + if slot == 224 { + return (11529695, 8227423, 8237310); + } + if slot == 225 { + return (11581167, 8237310, 8246887); + } + if slot == 226 { + return (11632639, 8246887, 8256153); + } + if slot == 227 { + return (11684111, 8256153, 8265108); + } + if slot == 228 { + return (11735583, 8265108, 8273752); + } + if slot == 229 { + return (11787055, 8273752, 8282085); + } + if slot == 230 { + return (11838531, 8282085, 8290105); + } + if slot == 231 { + return (11889998, 8290105, 8297814); + } + if slot == 232 { + return (11941470, 8297814, 8305210); + } + if slot == 233 { + return (11992942, 8305210, 8312294); + } + if slot == 234 { + return (12044414, 8312294, 8319064); + } + if slot == 235 { + return (12095886, 8319064, 8325522); + } + if slot == 236 { + return (12147358, 8325522, 8331666); + } + if slot == 237 { + return (12198829, 8331666, 8337496); + } + if slot == 238 { + return (12250301, 8337496, 8343012); + } + if slot == 239 { + return (12301773, 8343012, 8348215); + } + } else { + if slot == 240 { + return (12353244, 8348215, 8353102); + } + if slot == 241 { + return (12404717, 8353102, 8357676); + } + if slot == 242 { + return (12456189, 8357676, 8361935); + } + if slot == 243 { + return (12507661, 8361935, 8365879); + } + if slot == 244 { + return (12559132, 8365879, 8369508); + } + if slot == 245 { + return (12610604, 8369508, 8372822); + } + if slot == 246 { + return (12662076, 8372822, 8375820); + } + if slot == 247 { + return (12713548, 8375820, 8378504); + } + if slot == 248 { + return (12765020, 8378504, 8380871); + } + if slot == 249 { + return (12816492, 8380871, 8382924); + } + if slot == 250 { + return (12867964, 8382924, 8384660); + } + if slot == 251 { + return (12919435, 8384660, 8386082); + } + if slot == 252 { + return (12970907, 8386082, 8387187); + } + if slot == 253 { + return (13022379, 8387187, 8387976); + } + if slot == 254 { + return (13073851, 8387976, 8388450); + } + } + } + } + } + + return 
(13125323, 8388450, 8388608); +} + +fn atan(a: u64) -> (u64, u64, u64) { + let slot = a / 58720; + + if slot == 0 { + return (0, 0, 58719); + } + if slot == 1 { + return (58720, 58719, 117433); + } + if slot == 2 { + return (117441, 117433, 176135); + } + if slot == 3 { + return (176161, 176135, 234820); + } + if slot == 4 { + return (234881, 234820, 293481); + } + if slot == 5 { + return (293601, 293481, 352115); + } + if slot == 6 { + return (352322, 352115, 410713); + } + if slot == 7 { + return (411042, 410713, 469272); + } + if slot == 8 { + return (469762, 469272, 527785); + } + if slot == 9 { + return (528482, 527785, 586246); + } + if slot == 10 { + return (587201, 586246, 644651); + } + if slot == 11 { + return (645923, 644651, 702993); + } + if slot == 12 { + return (704643, 702993, 761267); + } + if slot == 13 { + return (763363, 761267, 819467); + } + if slot == 14 { + return (822084, 819467, 877588); + } + if slot == 15 { + return (880804, 877588, 935625); + } + if slot == 16 { + return (939524, 935625, 993572); + } + if slot == 17 { + return (998244, 993572, 1051424); + } + if slot == 18 { + return (1056965, 1051424, 1109175); + } + if slot == 19 { + return (1115685, 1109175, 1166821); + } + if slot == 20 { + return (1174411, 1166821, 1224357); + } + if slot == 21 { + return (1233125, 1224357, 1281776); + } + if slot == 22 { + return (1291846, 1281776, 1339075); + } + if slot == 23 { + return (1350566, 1339075, 1396248); + } + if slot == 24 { + return (1409286, 1396248, 1453290); + } + if slot == 25 { + return (1468006, 1453290, 1510197); + } + if slot == 26 { + return (1526727, 1510197, 1566964); + } + if slot == 27 { + return (1585447, 1566964, 1623585); + } + if slot == 28 { + return (1644167, 1623585, 1680058); + } + if slot == 29 { + return (1702887, 1680058, 1736376); + } + if slot == 30 { + return (1761612, 1736376, 1792537); + } + if slot == 31 { + return (1820328, 1792537, 1848534); + } + if slot == 32 { + return (1879048, 1848534, 1904364); + } + if slot == 33 { + return (1937768, 1904364, 1960024); + } + if slot == 34 { + return (1996489, 1960024, 2015508); + } + if slot == 35 { + return (2055209, 2015508, 2070813); + } + if slot == 36 { + return (2113929, 2070813, 2125935); + } + if slot == 37 { + return (2172649, 2125935, 2180869); + } + if slot == 38 { + return (2231370, 2180869, 2235613); + } + if slot == 39 { + return (2290090, 2235613, 2290163); + } + if slot == 40 { + return (2348813, 2290163, 2344515); + } + if slot == 41 { + return (2407530, 2344515, 2398665); + } + if slot == 42 { + return (2466251, 2398665, 2452611); + } + if slot == 43 { + return (2524971, 2452611, 2506348); + } + if slot == 44 { + return (2583691, 2506348, 2559875); + } + if slot == 45 { + return (2642412, 2559875, 2613187); + } + if slot == 46 { + return (2701132, 2613187, 2666281); + } + if slot == 47 { + return (2759852, 2666281, 2719156); + } + if slot == 48 { + return (2818572, 2719156, 2771807); + } + if slot == 49 { + return (2877293, 2771807, 2824233); + } + if slot == 50 { + return (2936014, 2824233, 2876431); + } + if slot == 51 { + return (2994733, 2876431, 2928397); + } + if slot == 52 { + return (3053453, 2928397, 2980130); + } + if slot == 53 { + return (3112174, 2980130, 3031628); + } + if slot == 54 { + return (3170894, 3031628, 3082888); + } + if slot == 55 { + return (3229614, 3082888, 3133907); + } + if slot == 56 { + return (3288334, 3133907, 3184685); + } + if slot == 57 { + return (3347055, 3184685, 3235218); + } + if slot == 58 { + return (3405775, 3235218, 
3285506); + } + if slot == 59 { + return (3464495, 3285506, 3335545); + } + if slot == 60 { + return (3523224, 3335545, 3385336); + } + if slot == 61 { + return (3581936, 3385336, 3434875); + } + if slot == 62 { + return (3640656, 3434875, 3484161); + } + if slot == 63 { + return (3699376, 3484161, 3533193); + } + if slot == 64 { + return (3758096, 3533193, 3581970); + } + if slot == 65 { + return (3816817, 3581970, 3630491); + } + if slot == 66 { + return (3875537, 3630491, 3678753); + } + if slot == 67 { + return (3934257, 3678753, 3726756); + } + if slot == 68 { + return (3992977, 3726756, 3774499); + } + if slot == 69 { + return (4051698, 3774499, 3821981); + } + if slot == 70 { + return (4110418, 3821981, 3869201); + } + if slot == 71 { + return (4169138, 3869201, 3916159); + } + if slot == 72 { + return (4227858, 3916159, 3962853); + } + if slot == 73 { + return (4286579, 3962853, 4009282); + } + if slot == 74 { + return (4345299, 4009282, 4055447); + } + if slot == 75 { + return (4404019, 4055447, 4101347); + } + if slot == 76 { + return (4462739, 4101347, 4146981); + } + if slot == 77 { + return (4521460, 4146981, 4192350); + } + if slot == 78 { + return (4580180, 4192350, 4237451); + } + if slot == 79 { + return (4638900, 4237451, 4282286); + } + if slot == 80 { + return (4697620, 4282286, 4326855); + } + if slot == 81 { + return (4756341, 4326855, 4371156); + } + if slot == 82 { + return (4815061, 4371156, 4415191); + } + if slot == 83 { + return (4873781, 4415191, 4458958); + } + if slot == 84 { + return (4932502, 4458958, 4502459); + } + if slot == 85 { + return (4991222, 4502459, 4545693); + } + if slot == 86 { + return (5049942, 4545693, 4588660); + } + if slot == 87 { + return (5108662, 4588660, 4631361); + } + if slot == 88 { + return (5167383, 4631361, 4673795); + } + if slot == 89 { + return (5226103, 4673795, 4715964); + } + if slot == 90 { + return (5284823, 4715964, 4757868); + } + if slot == 91 { + return (5343543, 4757868, 4799506); + } + if slot == 92 { + return (5402264, 4799506, 4840880); + } + if slot == 93 { + return (5460984, 4840880, 4881990); + } + if slot == 94 { + return (5519704, 4881990, 4922837); + } + if slot == 95 { + return (5578424, 4922837, 4963420); + } + if slot == 96 { + return (5637145, 4963420, 5003742); + } + if slot == 97 { + return (5695865, 5003742, 5043802); + } + if slot == 98 { + return (5754585, 5043802, 5083601); + } + + return (5813305, 5083601, 5123141); +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo new file mode 100644 index 000000000..025b79bb2 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo @@ -0,0 +1,448 @@ +use debug::PrintTrait; +use integer::{u64_safe_divmod, u64_as_non_zero}; +use option::OptionTrait; + +use orion::numbers::fixed_point::implementations::fp8x23wide::math::lut; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, TWO, FP8x23W, FP8x23WImpl, FP8x23WAdd, FP8x23WSub, FP8x23WMul, FP8x23WDiv, + FP8x23WIntoFelt252, FixedTrait +}; + +// CONSTANTS + +const TWO_PI: u64 = 52707178; +const PI: u64 = 26353589; +const HALF_PI: u64 = 13176795; + +// PUBLIC + +// Calculates arccos(a) for -1 <= a <= 1 (fixed point) +// arccos(a) = arcsin(sqrt(1 - a^2)) - arctan identity has discontinuity at zero +fn acos(a: FP8x23W) -> FP8x23W { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin(asin_arg); + + if (a.sign) { + 
return FixedTrait::new(PI, false) - asin_res;
+    } else {
+        return asin_res;
+    }
+}
+
+fn acos_fast(a: FP8x23W) -> FP8x23W {
+    let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1
+    let asin_res = asin_fast(asin_arg);
+
+    if (a.sign) {
+        return FixedTrait::new(PI, false) - asin_res;
+    } else {
+        return asin_res;
+    }
+}
+
+// Calculates arcsin(a) for -1 <= a <= 1 (fixed point)
+// arcsin(a) = arctan(a / sqrt(1 - a^2))
+fn asin(a: FP8x23W) -> FP8x23W {
+    if (a.mag == ONE) {
+        return FixedTrait::new(HALF_PI, a.sign);
+    }
+
+    let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1
+    return atan(a / div);
+}
+
+fn asin_fast(a: FP8x23W) -> FP8x23W {
+    if (a.mag == ONE) {
+        return FixedTrait::new(HALF_PI, a.sign);
+    }
+
+    let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1
+    return atan_fast(a / div);
+}
+
+// Calculates arctan(a) (fixed point)
+// See https://stackoverflow.com/a/50894477 for range adjustments
+fn atan(a: FP8x23W) -> FP8x23W {
+    let mut at = a.abs();
+    let mut shift = false;
+    let mut invert = false;
+
+    // Invert value when a > 1
+    if (at.mag > ONE) {
+        at = FixedTrait::ONE() / at;
+        invert = true;
+    }
+
+    // Account for lack of precision in polynomial when a > 0.7
+    if (at.mag > 5872026) {
+        let sqrt3_3 = FixedTrait::new(4843165, false); // sqrt(3) / 3
+        at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3);
+        shift = true;
+    }
+
+    // Evaluate the degree-10 approximation polynomial in Horner form
+    let r10 = FixedTrait::new(15363, true) * at;
+    let r9 = (r10 + FixedTrait::new(392482, true)) * at;
+    let r8 = (r9 + FixedTrait::new(1629064, false)) * at;
+    let r7 = (r8 + FixedTrait::new(2197820, true)) * at;
+    let r6 = (r7 + FixedTrait::new(366693, false)) * at;
+    let r5 = (r6 + FixedTrait::new(1594324, false)) * at;
+    let r4 = (r5 + FixedTrait::new(11519, false)) * at;
+    let r3 = (r4 + FixedTrait::new(2797104, true)) * at;
+    let r2 = (r3 + FixedTrait::new(34, false)) * at;
+    let mut res = (r2 + FixedTrait::new(8388608, false)) * at;
+
+    // Adjust for sign change, inversion, and shift
+    if (shift) {
+        res = res + FixedTrait::new(4392265, false); // pi / 6
+    }
+
+    if (invert) {
+        res = res - FixedTrait::new(HALF_PI, false);
+    }
+
+    return FixedTrait::new(res.mag, a.sign);
+}
+
+fn atan_fast(a: FP8x23W) -> FP8x23W {
+    let mut at = a.abs();
+    let mut shift = false;
+    let mut invert = false;
+
+    // Invert value when a > 1
+    if (at.mag > ONE) {
+        at = FixedTrait::ONE() / at;
+        invert = true;
+    }
+
+    // Account for lack of precision in polynomial when a > 0.7
+    if (at.mag > 5872026) {
+        let sqrt3_3 = FixedTrait::new(4843165, false); // sqrt(3) / 3
+        at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3);
+        shift = true;
+    }
+
+    // Interpolate linearly between the bracketing LUT entries
+    let (start, low, high) = lut::atan(at.mag);
+    let partial_step = FixedTrait::new(at.mag - start, false) / FixedTrait::new(58720, false);
+    let mut res = partial_step * FixedTrait::new(high - low, false) + FixedTrait::new(low, false);
+
+    // Adjust for sign change, inversion, and shift
+    if (shift) {
+        res = res + FixedTrait::new(4392265, false); // pi / 6
+    }
+
+    if (invert) {
+        res = res - FixedTrait::<FP8x23W>::new(HALF_PI, false);
+    }
+
+    return FixedTrait::new(res.mag, a.sign);
+}
+
+// Calculates cos(a) with a in radians (fixed point)
+fn cos(a: FP8x23W) -> FP8x23W {
+    return sin(FixedTrait::new(HALF_PI, false) - a);
+}
+
+fn cos_fast(a: FP8x23W) -> FP8x23W {
+    return sin_fast(FixedTrait::new(HALF_PI, false) - a);
+}
+
+// Calculates sin(a) with a in radians (fixed point)
+fn sin(a: FP8x23W) -> FP8x23W {
+    let a1 = a.mag % TWO_PI;
+    let (whole_rem, partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI));
+    let a2 = 
FixedTrait::new(partial_rem, false); + let partial_sign = whole_rem == 1; + + let loop_res = a2 * _sin_loop(a2, 7, FixedTrait::ONE()); + return FixedTrait::new(loop_res.mag, a.sign ^ partial_sign && loop_res.mag != 0); +} + +fn sin_fast(a: FP8x23W) -> FP8x23W { + let a1 = a.mag % TWO_PI; + let (whole_rem, mut partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let partial_sign = whole_rem == 1; + + if partial_rem >= HALF_PI { + partial_rem = PI - partial_rem; + } + + let (start, low, high) = lut::sin(partial_rem); + let partial_step = FixedTrait::new(partial_rem - start, false) / FixedTrait::new(51472, false); + let res = partial_step * (FixedTrait::new(high, false) - FixedTrait::new(low, false)) + + FixedTrait::::new(low, false); + + return FixedTrait::new(res.mag, a.sign ^ partial_sign && res.mag != 0); +} + +// Calculates tan(a) with a in radians (fixed point) +fn tan(a: FP8x23W) -> FP8x23W { + let sinx = sin(a); + let cosx = cos(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +fn tan_fast(a: FP8x23W) -> FP8x23W { + let sinx = sin_fast(a); + let cosx = cos_fast(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +// Helper function to calculate Taylor series for sin +fn _sin_loop(a: FP8x23W, i: u64, acc: FP8x23W) -> FP8x23W { + let div = (2 * i + 2) * (2 * i + 3); + let term = a * a * acc / FixedTrait::new_unscaled(div, false); + let new_acc = FixedTrait::ONE() - term; + + if (i == 0) { + return new_acc; + } + + return _sin_loop(a, i - 1, new_acc); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp8x23wide::helpers::{ + assert_precise, assert_relative +}; + +#[test] +#[available_gas(3000000)] +fn test_acos() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos(a), 8784530, 'invalid half', error); // 1.0471975506263043 + + let a = FixedTrait::ZERO(); + assert_relative(acos(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos(a), 17569060, 'invalid neg half', error); // 2.094395102963489 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[available_gas(3000000)] +fn test_acos_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos_fast(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos_fast(a), 8784530, 'invalid half', error); // 1.0471975506263043 + + let a = FixedTrait::ZERO(); + assert_relative(acos_fast(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos_fast(a), 17569060, 'invalid neg half', error); // 2.094395102963489 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos_fast(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[should_panic] +#[available_gas(1000000)] +fn test_acos_fail() { + let a = FixedTrait::new(2 * ONE, true); + acos(a); +} + +#[test] +#[available_gas(1400000)] +fn test_atan_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(2 * ONE, false); + assert_relative(atan_fast(a), 9287437, 'invalid two', error); + + let a = FixedTrait::ONE(); + 
assert_relative(atan_fast(a), 6588397, 'invalid one', error);
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(atan_fast(a), 3889358, 'invalid half', error);
+
+    let a = FixedTrait::ZERO();
+    assert(atan_fast(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(atan_fast(a), -3889358, 'invalid neg half', error);
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(atan_fast(a), -6588397, 'invalid neg one', error);
+
+    let a = FixedTrait::new(2 * ONE, true);
+    assert_relative(atan_fast(a), -9287437, 'invalid neg two', error);
+}
+
+#[test]
+#[available_gas(2600000)]
+fn test_atan() {
+    let a = FixedTrait::new(2 * ONE, false);
+    assert_relative(atan(a), 9287437, 'invalid two', Option::None(()));
+
+    let a = FixedTrait::ONE();
+    assert_relative(atan(a), 6588397, 'invalid one', Option::None(()));
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(atan(a), 3889358, 'invalid half', Option::None(()));
+
+    let a = FixedTrait::ZERO();
+    assert(atan(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(atan(a), -3889358, 'invalid neg half', Option::None(()));
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(atan(a), -6588397, 'invalid neg one', Option::None(()));
+
+    let a = FixedTrait::new(2 * ONE, true);
+    assert_relative(atan(a), -9287437, 'invalid neg two', Option::None(()));
+}
+
+#[test]
+#[available_gas(3000000)]
+fn test_asin() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::ONE();
+    assert_relative(asin(a), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(asin(a), 4392265, 'invalid half', error);
+
+    let a = FixedTrait::ZERO();
+    assert_precise(asin(a), 0, 'invalid zero', Option::None(()));
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(asin(a), -4392265, 'invalid neg half', error);
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(asin(a), -HALF_PI.into(), 'invalid neg one', Option::None(())); // -PI / 2
+}
+
+#[test]
+#[should_panic]
+#[available_gas(1000000)]
+fn test_asin_fail() {
+    let a = FixedTrait::new(2 * ONE, false);
+    asin(a);
+}
+
+#[test]
+#[available_gas(6000000)]
+fn test_cos() {
+    let a = FixedTrait::new(HALF_PI, false);
+    assert(cos(a).into() == 0, 'invalid half pi');
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_relative(cos(a), 5931642, 'invalid quarter pi', Option::None(())); // 0.7071067811865475
+
+    let a = FixedTrait::new(PI, false);
+    assert_relative(cos(a), -1 * ONE.into(), 'invalid pi', Option::None(()));
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(cos(a), 0, 'invalid neg half pi', Option::None(()));
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_relative(cos(a), -2308239, 'invalid 17', Option::None(())); // -0.2751631780463348
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_relative(cos(a), -2308236, 'invalid -17', Option::None(())); // -0.2751631780463348
+}
+
+#[test]
+#[available_gas(6000000)]
+fn test_cos_fast() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::new(HALF_PI, false);
+    assert(cos_fast(a).into() == 0, 'invalid half pi');
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(cos_fast(a), 5931642, 'invalid quarter pi', error); // 0.7071067811865475
+
+    let a = FixedTrait::new(PI, false);
+    assert_precise(cos_fast(a), -1 * ONE.into(), 'invalid pi', error);
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(cos_fast(a), 0, 'invalid neg 
half pi', Option::None(())); + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(cos_fast(a), -2308239, 'invalid 17', error); // -0.2751631780463348 +} + +#[test] +#[available_gas(6000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(sin(a), ONE.into(), 'invalid half pi', Option::None(())); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(sin(a), 5931642, 'invalid quarter pi', Option::None(())); // 0.7071067811865475 + + let a = FixedTrait::new(PI, false); + assert(sin(a).into() == 0, 'invalid pi'); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise( + sin(a), -ONE.into(), 'invalid neg half pi', Option::None(()) + ); // 0.9999999999939766 + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(sin(a), -8064787, 'invalid 17', Option::None(())); // -0.9613974918793389 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(sin(a), 8064787, 'invalid -17', Option::None(())); // 0.9613974918793389 +} + +#[test] +#[available_gas(1000000)] +fn test_sin_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(HALF_PI, false); + assert_precise(sin_fast(a), ONE.into(), 'invalid half pi', error); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(sin_fast(a), 5931642, 'invalid quarter pi', error); // 0.7071067811865475 + + let a = FixedTrait::new(PI, false); + assert(sin_fast(a).into() == 0, 'invalid pi'); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise(sin_fast(a), -ONE.into(), 'invalid neg half pi', error); // 0.9999999999939766 + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(sin_fast(a), -8064787, 'invalid 17', error); // -0.9613974918793389 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(sin_fast(a), 8064787, 'invalid -17', error); // 0.9613974918793389 +} + +#[test] +#[available_gas(8000000)] +fn test_tan() { + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(tan(a), ONE.into(), 'invalid quarter pi', Option::None(())); + + let a = FixedTrait::new(PI, false); + assert_precise(tan(a), 0, 'invalid pi', Option::None(())); + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(tan(a), 29309069, 'invalid 17', Option::None(())); // 3.493917677159002 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(tan(a), -29309106, 'invalid -17', Option::None(())); // -3.493917677159002 +} From 8a55e5c730d6b2400165c517c251d6e3b603192b Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 23 Oct 2023 11:00:19 +0300 Subject: [PATCH 43/78] implement tensor_fp8x23wide --- src/numbers.cairo | 165 ++++++++ .../implementations/fp8x23wide/core.cairo | 21 +- .../nn/implementations/nn_fp8x23.cairo | 6 +- src/operators/tensor/implementations.cairo | 3 +- .../implementations/tensor_fp8x23wide.cairo | 376 ++++++++++++++++++ 5 files changed, 568 insertions(+), 3 deletions(-) create mode 100644 src/operators/tensor/implementations/tensor_fp8x23wide.cairo diff --git a/src/numbers.cairo b/src/numbers.cairo index 02dd5b344..0a533e937 100644 --- a/src/numbers.cairo +++ b/src/numbers.cairo @@ -213,6 +213,171 @@ impl FP8x23Number of NumberTrait { } } +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{FP8x23WImpl, FP8x23W}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::core as core_fp8x23wide; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::comp as comp_fp8x23wide; + +impl FP8x23WNumber of NumberTrait { + fn new(mag: u64, sign: bool) -> FP8x23W { + 
FP8x23WImpl::new(mag, sign) + } + + fn new_unscaled(mag: u64, sign: bool) -> FP8x23W { + FP8x23WImpl::new_unscaled(mag, sign) + } + + fn from_felt(val: felt252) -> FP8x23W { + FP8x23WImpl::from_felt(val) + } + + fn ceil(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::ceil(self) + } + + fn exp(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::exp(self) + } + + fn exp2(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::exp2(self) + } + + fn floor(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::floor(self) + } + + fn ln(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::ln(self) + } + + fn log2(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::log2(self) + } + + fn log10(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::log10(self) + } + + fn pow(self: FP8x23W, b: FP8x23W) -> FP8x23W { + FP8x23WImpl::pow(self, b) + } + + fn round(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::round(self) + } + + fn sqrt(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::sqrt(self) + } + + fn acos(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::acos(self) + } + + fn asin(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::asin(self) + } + + fn atan(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::atan(self) + } + + fn cos(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::cos(self) + } + + fn sin(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::sin(self) + } + + fn tan(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::tan(self) + } + + fn acosh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::acosh(self) + } + + fn asinh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::asinh(self) + } + + fn atanh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::atanh(self) + } + + fn cosh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::cosh(self) + } + + fn sinh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::sinh(self) + } + + fn tanh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::tanh(self) + } + + fn zero() -> FP8x23W { + FP8x23WImpl::ZERO() + } + fn is_zero(self: FP8x23W) -> bool { + core_fp8x23wide::eq(@self, @FP8x23WImpl::ZERO()) + } + + fn one() -> FP8x23W { + FP8x23WImpl::ONE() + } + + fn neg_one() -> FP8x23W { + FP8x23W { mag: core_fp8x23wide::ONE, sign: true } + } + + fn is_one(self: FP8x23W) -> bool { + core_fp8x23wide::eq(@self, @FP8x23WImpl::ONE()) + } + + fn abs(self: FP8x23W) -> FP8x23W { + core_fp8x23wide::abs(self) + } + + fn min_value() -> FP8x23W { + FP8x23W { mag: core_fp8x23wide::MAX, sign: true } + } + + fn max_value() -> FP8x23W { + FP8x23W { mag: core_fp8x23wide::MAX, sign: false } + } + + fn min(self: FP8x23W, other: FP8x23W) -> FP8x23W { + comp_fp8x23wide::min(self, other) + } + + fn max(self: FP8x23W, other: FP8x23W) -> FP8x23W { + comp_fp8x23wide::max(self, other) + } + + fn mag(self: FP8x23W) -> u64 { + self.mag + } + + fn is_neg(self: FP8x23W) -> bool { + self.sign + } + + fn xor(lhs: FP8x23W, rhs: FP8x23W) -> bool { + comp_fp8x23wide::xor(lhs, rhs) + } + + fn or(lhs: FP8x23W, rhs: FP8x23W) -> bool { + comp_fp8x23wide::or(lhs, rhs) + } + + fn sign(self: FP8x23W) -> FP8x23W { + core_fp8x23wide::sign(self) + } +} + use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16Impl, FP16x16}; use orion::numbers::fixed_point::implementations::fp16x16::math::core as core_fp16x16; use orion::numbers::fixed_point::implementations::fp16x16::math::comp as comp_fp16x16; diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo index 36b64ce5e..3fe3cd3cb 100644 --- a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo @@ -5,7 +5,7 @@ use result::{ResultTrait, 
ResultTraitImpl}; use traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; -use orion::numbers::fixed_point::core::{FixedTrait}; +use orion::numbers::{fixed_point::core::{FixedTrait}, FP8x23}; use orion::numbers::fixed_point::implementations::fp8x23wide::math::{core, trig, hyp}; use orion::numbers::fixed_point::utils; @@ -205,6 +205,25 @@ impl FP8x23WIntoFelt252 of Into { } } +impl FP8x23IntoFP8x23W of Into { + fn into(self: FP8x23) -> FP8x23W { + FP8x23W { mag: self.mag.into(), sign: self.sign } + } +} + +impl FP8x23WTryIntoFP8x23 of TryInto { + fn try_into(self: FP8x23W) -> Option { + match self.mag.try_into() { + Option::Some(val) => { + Option::Some(FP8x23 { mag: val, sign: self.sign }) + }, + Option::None(_) => { + Option::None(()) + } + } + } +} + impl FP8x23WTryIntoU128 of TryInto { fn try_into(self: FP8x23W) -> Option { if self.sign { diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index 305eeaba2..510f8cebd 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -7,6 +7,10 @@ use orion::numbers::fixed_point::implementations::fp8x23::core::FP8x23; use orion::operators::tensor::implementations::tensor_fp8x23::{ FP8x23Tensor, FP8x23TensorDiv, FP8x23TensorAdd }; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + FP8x23WImpl, FP8x23WTryIntoFP8x23, FP8x23W, FP8x23IntoFP8x23W +}; +use orion::operators::tensor::implementations::tensor_fp8x23wide::{FP8x23WTensor}; impl FP8x23NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -18,7 +22,7 @@ impl FP8x23NN of NNTrait { } fn softmax(tensor: @Tensor, axis: usize) -> Tensor { - functional::softmax::softmax(tensor, axis) + functional::softmax::softmaxWide::(tensor, axis) } fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { diff --git a/src/operators/tensor/implementations.cairo b/src/operators/tensor/implementations.cairo index 0df3dcdec..a96030744 100644 --- a/src/operators/tensor/implementations.cairo +++ b/src/operators/tensor/implementations.cairo @@ -5,4 +5,5 @@ mod tensor_fp8x23; mod tensor_fp16x16; mod tensor_fp64x64; mod tensor_fp32x32; -mod tensor_fp16x16wide; \ No newline at end of file +mod tensor_fp16x16wide; +mod tensor_fp8x23wide; \ No newline at end of file diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo new file mode 100644 index 000000000..91ca9338d --- /dev/null +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -0,0 +1,376 @@ +use array::ArrayTrait; +use array::SpanTrait; +use option::OptionTrait; +use traits::{TryInto, Into}; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::core::{ + new_tensor, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, +}; +use orion::operators::tensor::{math, linalg, quantization, core}; +use orion::numbers::{i8, i32, NumberTrait, FP8x23W}; +use orion::operators::tensor::implementations::{tensor_i8::I8Tensor, tensor_u32::U32Tensor}; + +impl FP8x23WTensor of TensorTrait { + fn new(shape: Span, data: Span) -> Tensor { + new_tensor(shape, data) + } + + fn at(self: @Tensor, indices: Span) -> FP8x23W { + *at_tensor(self, indices) + } + + fn min(self: @Tensor) -> FP8x23W { + math::min::min_in_tensor::(*self.data) + } + + fn max(self: @Tensor) -> FP8x23W { + math::max::max_in_tensor(*self.data) + } + + fn stride(self: @Tensor) -> Span { + 
stride(*self.shape) + } + + fn ravel_index(self: @Tensor, indices: Span) -> usize { + ravel_index(*self.shape, indices) + } + + fn unravel_index(self: @Tensor, index: usize) -> Span { + unravel_index(index, *self.shape) + } + + fn reshape(self: @Tensor, target_shape: Span) -> Tensor { + reshape(self, target_shape) + } + + fn reduce_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_sum::reduce_sum(self, axis, keepdims) + } + + fn argmax( + self: @Tensor, axis: usize, keepdims: Option, select_last_index: Option + ) -> Tensor { + math::argmax::argmax(self, axis, keepdims, select_last_index) + } + + fn argmin( + self: @Tensor, axis: usize, keepdims: Option, select_last_index: Option + ) -> Tensor { + math::argmin::argmin(self, axis, keepdims, select_last_index) + } + + fn transpose(self: @Tensor, axes: Span) -> Tensor { + linalg::transpose::transpose(self, axes) + } + + fn matmul(self: @Tensor, other: @Tensor) -> Tensor { + linalg::matmul::matmul(self, other) + } + + fn exp(self: @Tensor) -> Tensor { + math::exp::exp(*self) + } + + fn log(self: @Tensor) -> Tensor { + math::log::log(*self) + } + + fn equal(self: @Tensor, other: @Tensor) -> Tensor { + math::equal::equal(self, other) + } + + fn greater(self: @Tensor, other: @Tensor) -> Tensor { + math::greater::greater(self, other) + } + + fn greater_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::greater_equal::greater_equal(self, other) + } + + fn less(self: @Tensor, other: @Tensor) -> Tensor { + math::less::less(self, other) + } + + fn less_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::less_equal::less_equal(self, other) + } + + fn abs(self: @Tensor) -> Tensor { + math::abs::abs(*self) + } + + fn ceil(self: @Tensor) -> Tensor { + math::ceil::ceil(*self) + } + + fn sin(self: @Tensor) -> Tensor { + math::sin::sin(*self) + } + + fn cos(self: @Tensor) -> Tensor { + math::cos::cos(*self) + } + + fn asin(self: @Tensor) -> Tensor { + math::asin::asin(*self) + } + + fn cumsum( + self: @Tensor, axis: usize, exclusive: Option, reverse: Option + ) -> Tensor { + math::cumsum::cumsum(self, axis, exclusive, reverse) + } + + fn flatten(self: @Tensor, axis: usize) -> Tensor { + math::flatten::flatten(self, axis) + } + + fn sinh(self: @Tensor) -> Tensor { + math::sinh::sinh(*self) + } + + fn tanh(self: @Tensor) -> Tensor { + math::tanh::tanh(*self) + } + + fn cosh(self: @Tensor) -> Tensor { + math::cosh::cosh(*self) + } + + fn acosh(self: @Tensor) -> Tensor { + math::acosh::acosh(*self) + } + + fn asinh(self: @Tensor) -> Tensor { + math::asinh::asinh(*self) + } + + fn atan(self: @Tensor) -> Tensor { + math::atan::atan(*self) + } + + fn xor(self: @Tensor, other: @Tensor) -> Tensor { + math::xor::xor(self, other) + } + + fn or(self: @Tensor, other: @Tensor) -> Tensor { + math::or::or(self, other) + } + + fn acos(self: @Tensor) -> Tensor { + math::acos::acos(*self) + } + + fn onehot( + self: @Tensor, depth: usize, axis: Option, values: Span + ) -> Tensor { + panic(array!['not supported!']) + } + + fn sqrt(self: @Tensor) -> Tensor { + math::sqrt::sqrt(*self) + } + + fn concat(tensors: Span>, axis: usize,) -> Tensor { + math::concat::concat(tensors, axis) + } + + fn quantize_linear( + self: @Tensor, y_scale: @Tensor, y_zero_point: @Tensor + ) -> Tensor:: { + quantization::quantize_linear::quantize_linear( + self, + y_scale, + y_zero_point, + NumberTrait::new_unscaled(128, true), + NumberTrait::new_unscaled(127, false) + ) + } + + fn dequantize_linear( + self: @Tensor, x_scale: @Tensor, x_zero_point: @Tensor + ) -> 
Tensor:: { + panic(array!['not supported!']) + } + + fn slice( + self: @Tensor, + starts: Span, + ends: Span, + axes: Option>, + steps: Option> + ) -> Tensor { + core::slice::(self, starts, ends, axes, steps) + } + + fn gather( + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather::gather(self, indices, axis) + } + + fn nonzero(self: @Tensor) -> Tensor { + core::nonzero(self) + } + + fn squeeze(self: @Tensor, axes: Option>) -> Tensor { + core::squeeze(self, axes) + } + + fn unsqueeze(self: @Tensor, axes: Span) -> Tensor { + core::unsqueeze(self, axes) + } + + fn sign(self: @Tensor) -> Tensor { + math::sign::sign(*self) + } + + fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { + core::clip(self, min, max) + } +} + +/// Implements addition for `Tensor` using the `Add` trait. +impl FP8x23WTensorAdd< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TAdd: Add, + impl TCopy: Copy, + impl TDrop: Drop +> of Add> { + /// Adds two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise addition. + fn add(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::add(@lhs, @rhs) + } +} + +/// Implements subtraction for `Tensor` using the `Sub` trait. +impl FP8x23WTensorSub< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TSub: Sub, + impl TCopy: Copy, + impl TDrop: Drop +> of Sub> { + /// Subtracts two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise subtraction. + fn sub(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::sub(@lhs, @rhs) + } +} + +/// Implements multiplication for `Tensor` using the `Mul` trait. +impl FP8x23WTensorMul< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TMul: Mul, + impl TCopy: Copy, + impl TDrop: Drop +> of Mul> { + /// Multiplies two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise multiplication. + fn mul(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::mul(@lhs, @rhs) + } +} + +/// Implements division for `Tensor` using the `Div` trait. +impl FP8x23WTensorDiv< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TDiv: Div, + impl TCopy: Copy, + impl TDrop: Drop +> of Div> { + /// Divides two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise division. + fn div(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::div(@lhs, @rhs) + } +} + +/// Implements partial equal for two `Tensor` using the `PartialEq` trait. 
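+/// Note that this equality is approximate rather than bit-exact: `tensor_eq`
+/// below first compares the shapes element by element, then compares the data
+/// through `relative_eq`, which tolerates a small relative error (the
+/// `PRECISION` constant, roughly 0.009 in FP8x23W terms) so that minor rounding
+/// noise in wide computations does not break equality.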
+impl FP8x23WTensorPartialEq of PartialEq> { + fn eq(lhs: @Tensor, rhs: @Tensor) -> bool { + tensor_eq(*lhs, *rhs) + } + + fn ne(lhs: @Tensor, rhs: @Tensor) -> bool { + !tensor_eq(*lhs, *rhs) + } +} + +impl U32TryIntoU32 of TryInto { + fn try_into(self: u64) -> Option { + Option::Some(self) + } +} + +// Internals + +const PRECISION: u64 = 75497; // 0.009 + +fn relative_eq(lhs: @FP8x23W, rhs: @FP8x23W) -> bool { + let diff = *lhs - *rhs; + + let rel_diff = if *lhs.mag != 0 { + (diff / *lhs).mag + } else { + diff.mag + }; + + rel_diff <= PRECISION +} + +fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { + let mut is_eq = true; + + loop { + if lhs.shape.len() == 0 || !is_eq { + break; + } + + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; + + if !is_eq { + return false; + } + + loop { + if lhs.data.len() == 0 || !is_eq { + break; + } + + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; + + return is_eq; +} + From 19d5e99d5b7e487776d6c42dc5b7b4823514c0eb Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 23 Oct 2023 11:10:27 +0300 Subject: [PATCH 44/78] add logsoftmaxwide --- src/operators/nn/functional/logsoftmax.cairo | 28 +++++++++++++++++++ .../nn/implementations/nn_fp16x16.cairo | 2 +- .../nn/implementations/nn_fp8x23.cairo | 2 +- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/logsoftmax.cairo b/src/operators/nn/functional/logsoftmax.cairo index 6d19cbb62..bd38d138c 100644 --- a/src/operators/nn/functional/logsoftmax.cairo +++ b/src/operators/nn/functional/logsoftmax.cairo @@ -2,6 +2,8 @@ use array::SpanTrait; use orion::numbers::NumberTrait; use orion::operators::tensor::core::{Tensor, TensorTrait}; +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast}; /// Cf: NNTrait::logsoftmax docstring fn logsoftmax< @@ -16,3 +18,29 @@ fn logsoftmax< return logsoftmax; } + +/// Cf: NNTrait::logsoftmax docstring +fn logsoftmaxWide< + T, + TMAG, + W, + WMAG, + impl TTensor: TensorTrait, + impl WTensor: TensorTrait, + impl TDiv: Div, + impl TIntoW: Into, + impl WTryIntoT: TryInto, + impl TCopy: Copy, + impl TDrop: Drop, + impl WCopy: Copy, + impl WDrop: Drop, + impl TFixed: FixedTrait, + impl WFixed: FixedTrait, +>( + z: @Tensor, axis: usize +) -> Tensor { + let exp_tensor: Tensor = exp_upcast(*z); + let sum = exp_tensor.reduce_sum(axis, true); + let softmax = div_downcast(@exp_tensor, @sum); + softmax.log() +} \ No newline at end of file diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index b940d8742..de81cde6d 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -28,7 +28,7 @@ impl FP16x16NN of NNTrait { } fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { - functional::logsoftmax::logsoftmax(tensor, axis) + functional::logsoftmax::logsoftmaxWide::(tensor, axis) } fn softsign(tensor: @Tensor) -> Tensor { diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index 510f8cebd..d837b8fef 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -26,7 +26,7 @@ impl FP8x23NN of NNTrait { } fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { - functional::logsoftmax::logsoftmax(tensor, axis) + functional::logsoftmax::logsoftmaxWide::(tensor, 
axis)
     }
 
     fn softsign(tensor: @Tensor) -> Tensor {

From d025293f3115d309fd2bc2baf92ccce427e2ac9b Mon Sep 17 00:00:00 2001
From: 0x73e <132935850+0x73e@users.noreply.github.com>
Date: Mon, 23 Oct 2023 04:51:50 -0400
Subject: [PATCH 45/78] feat: hard_sigmoid

---
 src/operators/nn/core.cairo | 53 +++++++++++++++++++
 src/operators/nn/functional.cairo | 1 +
 .../nn/functional/hard_sigmoid.cairo | 43 +++++++++++++++
 .../nn/implementations/nn_fp16x16.cairo | 4 ++
 .../nn/implementations/nn_fp32x32.cairo | 4 ++
 .../nn/implementations/nn_fp64x64.cairo | 4 ++
 .../nn/implementations/nn_fp8x23.cairo | 4 ++
 src/operators/nn/implementations/nn_i32.cairo | 4 ++
 src/operators/nn/implementations/nn_i8.cairo | 4 ++
 src/operators/nn/implementations/nn_u32.cairo | 4 ++
 10 files changed, 125 insertions(+)
 create mode 100644 src/operators/nn/functional/hard_sigmoid.cairo

diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo
index 476616e30..bd40930e8 100644
--- a/src/operators/nn/core.cairo
+++ b/src/operators/nn/core.cairo
@@ -451,4 +451,57 @@ trait NNTrait {
     /// ```
     ///
     fn leaky_relu(inputs: @Tensor, alpha: @T) -> Tensor;
+    /// # NNTrait::hard_sigmoid
+    ///
+    /// ```rust
+    /// fn hard_sigmoid(tensor: @Tensor, alpha: @T, beta: @T) -> Tensor;
+    /// ```
+    ///
+    /// Applies the HardSigmoid function to an n-dimensional input tensor.
+    ///
+    /// $$
+    /// \text{HardSigmoid}(x_i) = \text{max}(0, \text{min}(alpha * x_i + beta, 1))
+    /// $$
+    ///
+    /// ## Args
+    ///
+    /// * `tensor`(`@Tensor`) - The input tensor.
+    ///
+    /// ## Returns
+    ///
+    /// A Tensor of fixed point numbers with the same shape as the input Tensor.
+    ///
+    /// ## Type Constraints
+    ///
+    /// Constrain input and output types to fixed point tensors.
+    ///
+    /// ## Examples
+    ///
+    /// ```rust
+    /// use array::{ArrayTrait, SpanTrait};
+    ///
+    /// use orion::operators::tensor::{TensorTrait, Tensor, FP8x23};
+    /// use orion::operators::nn::{NNTrait, FP8x23NN};
+    /// use orion::numbers::{FP8x23, FixedTrait};
+    ///
+    /// fn hard_sigmoid_example() -> Tensor {
+    ///     let tensor = TensorTrait::::new(
+    ///         shape: array![2, 2].span(),
+    ///         data: array![
+    ///             FixedTrait::new(0, false),
+    ///             FixedTrait::new(1, false),
+    ///             FixedTrait::new(2, false),
+    ///             FixedTrait::new(3, false),
+    ///         ]
+    ///         .span(),
+    ///     );
+    ///
+    ///     return NNTrait::hard_sigmoid(@tensor);
+    /// }
+    /// >>> [[4194304,6132564],[7388661,7990771]]
+    /// // The fixed point representation of
+    /// // [[0.5, 0.7310586],[0.88079703, 0.95257413]]
+    /// ```
+    ///
+    fn hard_sigmoid(tensor: @Tensor, alpha: @T, beta: @T) -> Tensor;
 }

diff --git a/src/operators/nn/functional.cairo b/src/operators/nn/functional.cairo
index 272475c30..39a743304 100644
--- a/src/operators/nn/functional.cairo
+++ b/src/operators/nn/functional.cairo
@@ -6,3 +6,4 @@ mod softsign;
 mod softplus;
 mod linear;
 mod logsoftmax;
+mod hard_sigmoid;

diff --git a/src/operators/nn/functional/hard_sigmoid.cairo b/src/operators/nn/functional/hard_sigmoid.cairo
new file mode 100644
index 000000000..8f7f3dd10
--- /dev/null
+++ b/src/operators/nn/functional/hard_sigmoid.cairo
@@ -0,0 +1,43 @@
+use core::traits::Into;
+use array::ArrayTrait;
+use array::SpanTrait;
+use option::OptionTrait;
+
+
+use orion::numbers::fixed_point::core::FixedTrait;
+use orion::operators::tensor::core::{Tensor, TensorTrait};
+use orion::numbers::NumberTrait;
+
+/// Cf: NNTrait::hard_sigmoid docstring
+fn hard_sigmoid<
+    T,
+    MAG,
+    impl TNumber: NumberTrait,
+    impl TTensor: TensorTrait,
+    impl TPartialOrd: PartialOrd,
+    impl TAdd: Add,
+    impl
TMul: Mul, + impl TDiv: Div, + impl TCopy: Copy, + impl TDrop: Drop, +>( + mut x: Tensor, alpha: @T, beta: @T +) -> Tensor { + let mut data_result = ArrayTrait::::new(); + + loop { + match x.data.pop_front() { + Option::Some(item) => { + let temp = (*item) * (*alpha) + (*beta); + let result = temp.min(NumberTrait::one()).max(NumberTrait::zero()); + data_result.append(result); + }, + Option::None(_) => { + break; + } + }; + }; + + return TensorTrait::new(x.shape, data_result.span()); +} + diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index a0094de29..178cca7a7 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -42,4 +42,8 @@ impl FP16x16NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @FP16x16) -> Tensor { functional::leaky_relu::leaky_relu(*inputs, alpha) } + + fn hard_sigmoid(tensor: @Tensor, alpha: @FP16x16, beta: @FP16x16) -> Tensor { + functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) + } } diff --git a/src/operators/nn/implementations/nn_fp32x32.cairo b/src/operators/nn/implementations/nn_fp32x32.cairo index a550caa31..a4f9cc6cd 100644 --- a/src/operators/nn/implementations/nn_fp32x32.cairo +++ b/src/operators/nn/implementations/nn_fp32x32.cairo @@ -42,4 +42,8 @@ impl FP32x32NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @FP32x32) -> Tensor { functional::leaky_relu::leaky_relu(*inputs, alpha) } + + fn hard_sigmoid(tensor: @Tensor, alpha: @FP32x32, beta: @FP32x32) -> Tensor { + functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) + } } diff --git a/src/operators/nn/implementations/nn_fp64x64.cairo b/src/operators/nn/implementations/nn_fp64x64.cairo index bcd4be5db..d5d4ffe24 100644 --- a/src/operators/nn/implementations/nn_fp64x64.cairo +++ b/src/operators/nn/implementations/nn_fp64x64.cairo @@ -42,4 +42,8 @@ impl FP64x64NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @FP64x64) -> Tensor { functional::leaky_relu::leaky_relu(*inputs, alpha) } + + fn hard_sigmoid(tensor: @Tensor, alpha: @FP64x64, beta: @FP64x64) -> Tensor { + functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) + } } diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index 305eeaba2..87a89b0cc 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -42,4 +42,8 @@ impl FP8x23NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @FP8x23) -> Tensor { functional::leaky_relu::leaky_relu(*inputs, alpha) } + + fn hard_sigmoid(tensor: @Tensor, alpha: @FP8x23, beta: @FP8x23) -> Tensor { + functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) + } } diff --git a/src/operators/nn/implementations/nn_i32.cairo b/src/operators/nn/implementations/nn_i32.cairo index c74dfc268..d4d0b59e2 100644 --- a/src/operators/nn/implementations/nn_i32.cairo +++ b/src/operators/nn/implementations/nn_i32.cairo @@ -38,4 +38,8 @@ impl I32NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @i32) -> Tensor { panic(array!['not supported!']) } + + fn hard_sigmoid(tensor: @Tensor, alpha: @i32, beta: @i32) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_i8.cairo b/src/operators/nn/implementations/nn_i8.cairo index 12e9812a1..c220152bf 100644 --- a/src/operators/nn/implementations/nn_i8.cairo +++ b/src/operators/nn/implementations/nn_i8.cairo @@ -38,4 +38,8 @@ impl I8NN of NNTrait { fn 
leaky_relu(inputs: @Tensor, alpha: @i8) -> Tensor { panic(array!['not supported!']) } + + fn hard_sigmoid(tensor: @Tensor, alpha: @i8, beta: @i8) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_u32.cairo b/src/operators/nn/implementations/nn_u32.cairo index 0da74d544..f6944f4f8 100644 --- a/src/operators/nn/implementations/nn_u32.cairo +++ b/src/operators/nn/implementations/nn_u32.cairo @@ -37,4 +37,8 @@ impl U32NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @u32) -> Tensor { panic(array!['not supported!']) } + + fn hard_sigmoid(tensor: @Tensor, alpha: @u32, beta: @u32) -> Tensor { + panic(array!['not supported!']) + } } From a632260adadc1349d925a944b98794636f198e16 Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Mon, 23 Oct 2023 05:12:11 -0400 Subject: [PATCH 46/78] docgen + nodegen --- docs/framework/operators/tensor/tensor.neg.md | 39 +++++++++++++++++++ tests/src/nodes.cairo | 4 ++ tests/src/nodes/neg_fp16x16.cairo | 20 ++++++++++ tests/src/nodes/neg_fp16x16/input_0.cairo | 18 +++++++++ tests/src/nodes/neg_fp16x16/output_0.cairo | 18 +++++++++ tests/src/nodes/neg_fp8x23.cairo | 20 ++++++++++ tests/src/nodes/neg_fp8x23/input_0.cairo | 18 +++++++++ tests/src/nodes/neg_fp8x23/output_0.cairo | 18 +++++++++ tests/src/nodes/neg_i32.cairo | 20 ++++++++++ tests/src/nodes/neg_i32/input_0.cairo | 17 ++++++++ tests/src/nodes/neg_i32/output_0.cairo | 17 ++++++++ tests/src/nodes/neg_i8.cairo | 20 ++++++++++ tests/src/nodes/neg_i8/input_0.cairo | 17 ++++++++ tests/src/nodes/neg_i8/output_0.cairo | 17 ++++++++ 14 files changed, 263 insertions(+) create mode 100644 docs/framework/operators/tensor/tensor.neg.md create mode 100644 tests/src/nodes/neg_fp16x16.cairo create mode 100644 tests/src/nodes/neg_fp16x16/input_0.cairo create mode 100644 tests/src/nodes/neg_fp16x16/output_0.cairo create mode 100644 tests/src/nodes/neg_fp8x23.cairo create mode 100644 tests/src/nodes/neg_fp8x23/input_0.cairo create mode 100644 tests/src/nodes/neg_fp8x23/output_0.cairo create mode 100644 tests/src/nodes/neg_i32.cairo create mode 100644 tests/src/nodes/neg_i32/input_0.cairo create mode 100644 tests/src/nodes/neg_i32/output_0.cairo create mode 100644 tests/src/nodes/neg_i8.cairo create mode 100644 tests/src/nodes/neg_i8/input_0.cairo create mode 100644 tests/src/nodes/neg_i8/output_0.cairo diff --git a/docs/framework/operators/tensor/tensor.neg.md b/docs/framework/operators/tensor/tensor.neg.md new file mode 100644 index 000000000..f5c9b3816 --- /dev/null +++ b/docs/framework/operators/tensor/tensor.neg.md @@ -0,0 +1,39 @@ +#tensor.neg + +```rust + fn neg(self: @Tensor) -> Tensor; +``` + +Computes the negation of all elements in the input tensor. + +## Args + +* `self`(`@Tensor`) - The input tensor. + + +## Returns + +A new `Tensor` of the same shape as the input tensor with +the negation of all elements in the input tensor. 
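+
+For the sign-magnitude number types used across Orion (such as `i32` with its
+`mag`/`sign` fields), negation amounts to flipping the `sign` flag while leaving
+`mag` unchanged. A minimal sketch of that per-element rule (illustrative only:
+`neg_element` is a hypothetical helper named here for the example, while the
+real operator delegates to `NumberTrait::neg`, which may special-case zero):
+
+```rust
+use orion::numbers::i32;
+
+// Illustrative only: negate one element by flipping its sign flag.
+fn neg_element(x: i32) -> i32 {
+    i32 { mag: x.mag, sign: !x.sign }
+}
+```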
+ +## Example + +```rust +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor, I32Tensor}; +use orion::numbers::{i32, IntegerTrait}; + +fn neg_example() -> Tensor { + let tensor = TensorTrait::new( + shape: array![3].span(), + data: array![ + IntegerTrait::new(1, true), IntegerTrait::new(2, true), IntegerTrait::new(3, false) + ] + .span(), + ); + + return tensor.neg(); +} +>>> [1, 2, -3] +``` diff --git a/tests/src/nodes.cairo b/tests/src/nodes.cairo index 398edc4bf..2d4cc383a 100644 --- a/tests/src/nodes.cairo +++ b/tests/src/nodes.cairo @@ -433,3 +433,7 @@ mod clip_i8_2d; mod clip_i8_3d; mod clip_u32_2d; mod clip_u32_3d; +mod neg_fp16x16; +mod neg_fp8x23; +mod neg_i32; +mod neg_i8; diff --git a/tests/src/nodes/neg_fp16x16.cairo b/tests/src/nodes/neg_fp16x16.cairo new file mode 100644 index 000000000..5c0731454 --- /dev/null +++ b/tests/src/nodes/neg_fp16x16.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_neg_fp16x16() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.neg(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/neg_fp16x16/input_0.cairo b/tests/src/nodes/neg_fp16x16/input_0.cairo new file mode 100644 index 000000000..dc17867bd --- /dev/null +++ b/tests/src/nodes/neg_fp16x16/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1507328, sign: true }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 2097152, sign: true }); + data.append(FP16x16 { mag: 7274496, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/neg_fp16x16/output_0.cairo b/tests/src/nodes/neg_fp16x16/output_0.cairo new file mode 100644 index 000000000..cac798f63 --- /dev/null +++ b/tests/src/nodes/neg_fp16x16/output_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 983040, sign: true }); + data.append(FP16x16 { mag: 2097152, sign: false }); + data.append(FP16x16 { mag: 7274496, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/neg_fp8x23.cairo b/tests/src/nodes/neg_fp8x23.cairo new file mode 100644 index 000000000..349c6fc08 --- /dev/null +++ b/tests/src/nodes/neg_fp8x23.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::assert_eq; 
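+
+// Note: each generated node test loads a precomputed input/output tensor pair
+// and checks the operator result against it with Orion's `assert_eq` helper.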
+ +#[test] +#[available_gas(2000000000)] +fn test_neg_fp8x23() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.neg(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/neg_fp8x23/input_0.cairo b/tests/src/nodes/neg_fp8x23/input_0.cairo new file mode 100644 index 000000000..5a59a5874 --- /dev/null +++ b/tests/src/nodes/neg_fp8x23/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1015021568, sign: true }); + data.append(FP8x23 { mag: 109051904, sign: true }); + data.append(FP8x23 { mag: 637534208, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/neg_fp8x23/output_0.cairo b/tests/src/nodes/neg_fp8x23/output_0.cairo new file mode 100644 index 000000000..10fd45912 --- /dev/null +++ b/tests/src/nodes/neg_fp8x23/output_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1015021568, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 637534208, sign: true }); + data.append(FP8x23 { mag: 92274688, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/neg_i32.cairo b/tests/src/nodes/neg_i32.cairo new file mode 100644 index 000000000..fa738e698 --- /dev/null +++ b/tests/src/nodes/neg_i32.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I32Tensor; +use orion::operators::tensor::I32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_neg_i32() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.neg(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/neg_i32/input_0.cairo b/tests/src/nodes/neg_i32/input_0.cairo new file mode 100644 index 000000000..46b571d16 --- /dev/null +++ b/tests/src/nodes/neg_i32/input_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 105, sign: true }); + data.append(i32 { mag: 124, sign: true }); + data.append(i32 { mag: 53, sign: true }); + data.append(i32 { mag: 77, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/neg_i32/output_0.cairo b/tests/src/nodes/neg_i32/output_0.cairo new file mode 100644 index 000000000..de0636cda --- /dev/null +++ 
b/tests/src/nodes/neg_i32/output_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 105, sign: false }); + data.append(i32 { mag: 124, sign: false }); + data.append(i32 { mag: 53, sign: false }); + data.append(i32 { mag: 77, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/neg_i8.cairo b/tests/src/nodes/neg_i8.cairo new file mode 100644 index 000000000..decb2e9aa --- /dev/null +++ b/tests/src/nodes/neg_i8.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_neg_i8() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.neg(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/neg_i8/input_0.cairo b/tests/src/nodes/neg_i8/input_0.cairo new file mode 100644 index 000000000..6f0c3c00a --- /dev/null +++ b/tests/src/nodes/neg_i8/input_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 89, sign: false }); + data.append(i8 { mag: 18, sign: true }); + data.append(i8 { mag: 113, sign: false }); + data.append(i8 { mag: 63, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/neg_i8/output_0.cairo b/tests/src/nodes/neg_i8/output_0.cairo new file mode 100644 index 000000000..e339959f3 --- /dev/null +++ b/tests/src/nodes/neg_i8/output_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 89, sign: true }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 113, sign: true }); + data.append(i8 { mag: 63, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file From a6cb818903454feaeb7656a1449c6e30cde42e46 Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Mon, 23 Oct 2023 07:12:41 -0400 Subject: [PATCH 47/78] docgen + nodegen --- .../neural-network/nn.hard_sigmoid.md | 53 +++++++++++++++++++ nodegen/node/hard_sigmoid.py | 41 ++++++++++++++ src/operators/nn/core.cairo | 22 ++++---- tests/src/nodes.cairo | 2 + tests/src/nodes/hard_sigmoid_fp16x16.cairo | 20 +++++++ .../nodes/hard_sigmoid_fp16x16/input_0.cairo | 18 +++++++ .../nodes/hard_sigmoid_fp16x16/output_0.cairo | 18 +++++++ tests/src/nodes/hard_sigmoid_fp8x23.cairo | 37 +++++++++++++ .../nodes/hard_sigmoid_fp8x23/input_0.cairo | 18 +++++++ .../nodes/hard_sigmoid_fp8x23/output_0.cairo 
| 18 +++++++
 10 files changed, 237 insertions(+)
 create mode 100644 docs/framework/operators/neural-network/nn.hard_sigmoid.md
 create mode 100644 nodegen/node/hard_sigmoid.py
 create mode 100644 tests/src/nodes/hard_sigmoid_fp16x16.cairo
 create mode 100644 tests/src/nodes/hard_sigmoid_fp16x16/input_0.cairo
 create mode 100644 tests/src/nodes/hard_sigmoid_fp16x16/output_0.cairo
 create mode 100644 tests/src/nodes/hard_sigmoid_fp8x23.cairo
 create mode 100644 tests/src/nodes/hard_sigmoid_fp8x23/input_0.cairo
 create mode 100644 tests/src/nodes/hard_sigmoid_fp8x23/output_0.cairo

diff --git a/docs/framework/operators/neural-network/nn.hard_sigmoid.md b/docs/framework/operators/neural-network/nn.hard_sigmoid.md
new file mode 100644
index 000000000..be4338109
--- /dev/null
+++ b/docs/framework/operators/neural-network/nn.hard_sigmoid.md
@@ -0,0 +1,53 @@
+# NNTrait::hard_sigmoid
+
+```rust
+ fn hard_sigmoid(tensor: @Tensor, alpha: @T, beta: @T) -> Tensor;
+```
+
+Applies the HardSigmoid function to an n-dimensional input tensor.
+
+$$
+\text{HardSigmoid}(x_i) = \text{max}(0, \text{min}(alpha * x_i + beta, 1))
+$$
+
+## Args
+
+* `tensor`(`@Tensor`) - The input tensor.
+* `alpha`(`@T`) - value of alpha.
+* `beta`(`@T`) - value of beta.
+
+## Returns
+
+A Tensor of fixed point numbers with the same shape as the input Tensor.
+
+## Type Constraints
+
+Constrain input and output types to fixed point tensors.
+
+## Examples
+
+```rust
+use array::{ArrayTrait, SpanTrait};
+
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::nn::{NNTrait, FP16x16NN};
+use orion::numbers::{FP16x16, FixedTrait};
+
+fn hard_sigmoid_example() -> Tensor {
+    let tensor = TensorTrait::::new(
+        shape: array![2, 2].span(),
+        data: array![
+            FixedTrait::new(0, false),
+            FixedTrait::new(13107, false),
+            FixedTrait::new(32768, false),
+            FixedTrait::new(65536, false),
+        ]
+        .span(),
+    );
+    let alpha = FixedTrait::new(13107, false);
+    let beta = FixedTrait::new(32768, false);
+
+    return NNTrait::hard_sigmoid(@tensor, @alpha, @beta);
+}
+>>> [[32768, 35389],[39321, 45875]]
+```

diff --git a/nodegen/node/hard_sigmoid.py b/nodegen/node/hard_sigmoid.py
new file mode 100644
index 000000000..11de116d2
--- /dev/null
+++ b/nodegen/node/hard_sigmoid.py
@@ -0,0 +1,41 @@
+import numpy as np
+from nodegen.node import RunAll
+from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl, Trait
+
+class Hard_sigmoid(RunAll):
+
+    @staticmethod
+    def fp8x23():
+        alpha = 0.2
+        beta = 0.5
+        x = np.random.uniform(-3, 3, (2, 2)).astype(np.float32)
+        y = np.maximum(0, np.minimum(1, alpha * x + beta))
+
+        x = Tensor(Dtype.FP8x23, x.shape, to_fp(
+            x.flatten(), FixedImpl.FP8x23))
+        y = Tensor(Dtype.FP8x23, y.shape, to_fp(
+            y.flatten(), FixedImpl.FP8x23))
+
+        name = "hard_sigmoid_fp8x23"
+        make_node([x], [y], name)
+        make_test([x], y, "NNTrait::hard_sigmoid(@input_0, @FixedTrait::new(1677721, false), @FixedTrait::new(4194304, false))",
+                  name, Trait.NN)
+
+    @staticmethod
+    def fp16x16():
+        alpha = 0.2
+        beta = 0.5
+        x = np.random.uniform(-3, 3, (2, 2)).astype(np.float32)
+        y = np.maximum(0, np.minimum(1, alpha * x + beta))
+
+        x = Tensor(Dtype.FP16x16, x.shape, to_fp(
+            x.flatten(), FixedImpl.FP16x16))
+        y = Tensor(Dtype.FP16x16, y.shape, to_fp(
+            y.flatten(), FixedImpl.FP16x16))
+
+        name = "hard_sigmoid_fp16x16"
+        make_node([x], [y], name)
+        make_test([x], y, "NNTrait::hard_sigmoid(@input_0, @FixedTrait::new(13107, false), @FixedTrait::new(32768, false))",
+                  name, Trait.NN)
+
+

diff --git 
a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo index bd40930e8..ce49540a0 100644 --- a/src/operators/nn/core.cairo +++ b/src/operators/nn/core.cairo @@ -466,6 +466,8 @@ trait NNTrait { /// ## Args /// /// * `tensor`(`@Tensor`) - The input tensor. + /// * `alpha`(`@T`) - value of alpha. + /// * `beta`(`@T`) - value of beta. /// /// ## Returns /// @@ -482,25 +484,25 @@ trait NNTrait { /// /// use orion::operators::tensor::{TensorTrait, Tensor, FP8x23}; /// use orion::operators::nn::{NNTrait, FP8x23NN}; - /// use orion::numbers::{FP8x23, FixedTrait}; + /// use orion::numbers::{FP16x16, FixedTrait}; /// - /// fn hard_sigmoid_example() -> Tensor { - /// let tensor = TensorTrait::::new( + /// fn hard_sigmoid_example() -> Tensor { + /// let tensor = TensorTrait::::new( /// shape: array![2, 2].span(), /// data: array![ /// FixedTrait::new(0, false), - /// FixedTrait::new(1, false), - /// FixedTrait::new(2, false), - /// FixedTrait::new(3, false), + /// FixedTrait::new(13107, false), + /// FixedTrait::new(32768, false), + /// FixedTrait::new(65536, false), /// ] /// .span(), /// ); + /// let alpha = FixedTrait::new(13107, false); + /// let beta = FixedTrait::new(32768, false); /// - /// return NNTrait::hard_sigmoid(@tensor); + /// return NNTrait::hard_sigmoid(@tensor, @alpha, @beta); /// } - /// >>> [[4194304,6132564],[7388661,7990771]] - /// // The fixed point representation of - /// // [[0.5, 0.7310586],[0.88079703, 0.95257413]] + /// >>> [[32768, 35389],[39321, 45875]] /// ``` /// fn hard_sigmoid(tensor: @Tensor, alpha: @T, beta: @T) -> Tensor; diff --git a/tests/src/nodes.cairo b/tests/src/nodes.cairo index 398edc4bf..4c3b04cfb 100644 --- a/tests/src/nodes.cairo +++ b/tests/src/nodes.cairo @@ -433,3 +433,5 @@ mod clip_i8_2d; mod clip_i8_3d; mod clip_u32_2d; mod clip_u32_3d; +mod hard_sigmoid_fp8x23; +mod hard_sigmoid_fp16x16; diff --git a/tests/src/nodes/hard_sigmoid_fp16x16.cairo b/tests/src/nodes/hard_sigmoid_fp16x16.cairo new file mode 100644 index 000000000..e2e7faf89 --- /dev/null +++ b/tests/src/nodes/hard_sigmoid_fp16x16.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP16x16NN; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_hard_sigmoid_fp16x16() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = NNTrait::hard_sigmoid(@input_0, @FixedTrait::new(13107, false), @FixedTrait::new(32768, false)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/hard_sigmoid_fp16x16/input_0.cairo b/tests/src/nodes/hard_sigmoid_fp16x16/input_0.cairo new file mode 100644 index 000000000..dd5b12108 --- /dev/null +++ b/tests/src/nodes/hard_sigmoid_fp16x16/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 35866, sign: true }); + data.append(FP16x16 { mag: 152077, sign: false }); + data.append(FP16x16 { mag: 17807, sign: true }); + data.append(FP16x16 { mag: 93701, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git 
a/tests/src/nodes/hard_sigmoid_fp16x16/output_0.cairo b/tests/src/nodes/hard_sigmoid_fp16x16/output_0.cairo new file mode 100644 index 000000000..31f36b587 --- /dev/null +++ b/tests/src/nodes/hard_sigmoid_fp16x16/output_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 25594, sign: false }); + data.append(FP16x16 { mag: 63183, sign: false }); + data.append(FP16x16 { mag: 29206, sign: false }); + data.append(FP16x16 { mag: 51508, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/hard_sigmoid_fp8x23.cairo b/tests/src/nodes/hard_sigmoid_fp8x23.cairo new file mode 100644 index 000000000..5fc213fda --- /dev/null +++ b/tests/src/nodes/hard_sigmoid_fp8x23.cairo @@ -0,0 +1,37 @@ +mod input_0; +mod output_0; + + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP8x23NN; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_hard_sigmoid_fp8x23() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = NNTrait::hard_sigmoid(@input_0, @FixedTrait::new(1677721, false), @FixedTrait::new(4194304, false)); + + assert_eq(y, z); +} + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP8x23NN; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_hard_sigmoid_fp8x23() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = NNTrait::hard_sigmoid(@input_0, @FixedTrait::new(1677721, false), @FixedTrait::new(4194304, false)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/hard_sigmoid_fp8x23/input_0.cairo b/tests/src/nodes/hard_sigmoid_fp8x23/input_0.cairo new file mode 100644 index 000000000..e8756adf0 --- /dev/null +++ b/tests/src/nodes/hard_sigmoid_fp8x23/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 12689105, sign: true }); + data.append(FP8x23 { mag: 6909640, sign: false }); + data.append(FP8x23 { mag: 13798595, sign: true }); + data.append(FP8x23 { mag: 9114792, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/hard_sigmoid_fp8x23/output_0.cairo b/tests/src/nodes/hard_sigmoid_fp8x23/output_0.cairo new file mode 100644 index 000000000..a55c6e091 --- /dev/null +++ b/tests/src/nodes/hard_sigmoid_fp8x23/output_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data 
= ArrayTrait::new(); + data.append(FP8x23 { mag: 1656483, sign: false }); + data.append(FP8x23 { mag: 5576232, sign: false }); + data.append(FP8x23 { mag: 1434585, sign: false }); + data.append(FP8x23 { mag: 2371345, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file From 612ed6733c31fee12dfeb5c53bf926d855a150fa Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Mon, 23 Oct 2023 07:37:01 -0400 Subject: [PATCH 48/78] fix test --- tests/src/nodes/hard_sigmoid_fp8x23.cairo | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/src/nodes/hard_sigmoid_fp8x23.cairo b/tests/src/nodes/hard_sigmoid_fp8x23.cairo index 5fc213fda..7db64baac 100644 --- a/tests/src/nodes/hard_sigmoid_fp8x23.cairo +++ b/tests/src/nodes/hard_sigmoid_fp8x23.cairo @@ -18,20 +18,3 @@ fn test_hard_sigmoid_fp8x23() { assert_eq(y, z); } - -use orion::operators::nn::NNTrait; -use orion::numbers::FixedTrait; -use orion::operators::nn::FP8x23NN; -use orion::operators::tensor::FP8x23TensorPartialEq; -use orion::utils::assert_eq; - -#[test] -#[available_gas(2000000000)] -fn test_hard_sigmoid_fp8x23() { - let input_0 = input_0::input_0(); - let z = output_0::output_0(); - - let y = NNTrait::hard_sigmoid(@input_0, @FixedTrait::new(1677721, false), @FixedTrait::new(4194304, false)); - - assert_eq(y, z); -} \ No newline at end of file From 6462c61a4c11b39d36065e2717129a2c97e79fc2 Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Mon, 23 Oct 2023 08:43:03 -0400 Subject: [PATCH 49/78] feat: thresholded relu --- .../neural-network/nn.thresholded_relu.md | 47 ++++++++++++++++++ nodegen/node/thresholded_relu.py | 44 +++++++++++++++++ src/operators/nn/core.cairo | 49 +++++++++++++++++++ src/operators/nn/functional.cairo | 1 + .../nn/functional/thresholded_relu.cairo | 38 ++++++++++++++ .../nn/implementations/nn_fp16x16.cairo | 4 ++ .../nn/implementations/nn_fp32x32.cairo | 4 ++ .../nn/implementations/nn_fp64x64.cairo | 4 ++ .../nn/implementations/nn_fp8x23.cairo | 4 ++ src/operators/nn/implementations/nn_i32.cairo | 4 ++ src/operators/nn/implementations/nn_i8.cairo | 4 ++ src/operators/nn/implementations/nn_u32.cairo | 4 ++ tests/src/nodes.cairo | 2 + .../src/nodes/thresholded_relu_fp16x16.cairo | 20 ++++++++ .../thresholded_relu_fp16x16/input_0.cairo | 18 +++++++ .../thresholded_relu_fp16x16/output_0.cairo | 18 +++++++ tests/src/nodes/thresholded_relu_fp8x23.cairo | 20 ++++++++ .../thresholded_relu_fp8x23/input_0.cairo | 18 +++++++ .../thresholded_relu_fp8x23/output_0.cairo | 18 +++++++ 19 files changed, 321 insertions(+) create mode 100644 docs/framework/operators/neural-network/nn.thresholded_relu.md create mode 100644 nodegen/node/thresholded_relu.py create mode 100644 src/operators/nn/functional/thresholded_relu.cairo create mode 100644 tests/src/nodes/thresholded_relu_fp16x16.cairo create mode 100644 tests/src/nodes/thresholded_relu_fp16x16/input_0.cairo create mode 100644 tests/src/nodes/thresholded_relu_fp16x16/output_0.cairo create mode 100644 tests/src/nodes/thresholded_relu_fp8x23.cairo create mode 100644 tests/src/nodes/thresholded_relu_fp8x23/input_0.cairo create mode 100644 tests/src/nodes/thresholded_relu_fp8x23/output_0.cairo diff --git a/docs/framework/operators/neural-network/nn.thresholded_relu.md b/docs/framework/operators/neural-network/nn.thresholded_relu.md new file mode 100644 index 000000000..85521efc4 --- /dev/null +++ b/docs/framework/operators/neural-network/nn.thresholded_relu.md @@ 
-0,0 +1,47 @@
+# NNTrait::thresholded_relu
+
+```rust
+ fn thresholded_relu(inputs: @Tensor, alpha: @T) -> Tensor
+```
+
+Applies the thresholded rectified linear unit (Thresholded ReLU) activation function element-wise to a given tensor.
+
+The Thresholded ReLU function is defined as f(x) = x if x > alpha, f(x) = 0 otherwise, where x is the input element.
+
+## Args
+* `tensor`(`@Tensor`) - A snapshot of a tensor to which the Thresholded ReLU function will be applied.
+* `alpha`(`@T`) - A snapshot of a fixed point scalar that defines the alpha value of the Thresholded ReLU function.
+
+## Returns
+A new fixed point tensor with the same shape as the input tensor and the Thresholded ReLU function applied element-wise.
+
+## Type Constraints
+
+Constrain input and output types to fixed point tensors.
+
+## Examples
+
+```rust
+use array::{ArrayTrait, SpanTrait};
+
+use orion::operators::tensor::{TensorTrait, Tensor, FP8x23};
+use orion::operators::nn::{NNTrait, FP8x23NN};
+use orion::numbers::{FP8x23, FixedTrait};
+
+fn thresholded_relu_example() -> Tensor {
+    let tensor = TensorTrait::::new(
+        shape: array![2, 2].span(),
+        data: array![
+            FixedTrait::new(0, false),
+            FixedTrait::new(256, false),
+            FixedTrait::new(512, false),
+            FixedTrait::new(513, false),
+        ]
+        .span(),
+    );
+    let alpha = FixedTrait::from_felt(256); // 1.0
+
+    return NNTrait::thresholded_relu(@tensor, @alpha);
+}
+>>> [[0, 0], [512, 513]]
+```

diff --git a/nodegen/node/thresholded_relu.py b/nodegen/node/thresholded_relu.py
new file mode 100644
index 000000000..12715aa84
--- /dev/null
+++ b/nodegen/node/thresholded_relu.py
@@ -0,0 +1,44 @@
+import numpy as np
+from nodegen.node import RunAll
+from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl, Trait
+
+
+class Thresholded_relu(RunAll):
+
+    @staticmethod
+    def leaky_thresholded_fp8x23():
+
+        alpha = 1.0
+
+        x = np.random.uniform(-5, 7, (2, 2)).astype(np.float64)
+        y = np.clip(x, alpha, np.inf)
+        y[y == alpha] = 0
+
+        x = Tensor(Dtype.FP8x23, x.shape, to_fp(
+            x.flatten(), FixedImpl.FP8x23))
+        y = Tensor(Dtype.FP8x23, y.shape, to_fp(
+            y.flatten(), FixedImpl.FP8x23))
+
+        name = "thresholded_relu_fp8x23"
+        make_node([x], [y], name)
+        make_test([x], y, "NNTrait::thresholded_relu(@input_0, @FixedTrait::new(256, false))",
+                  name, Trait.NN)
+
+    @staticmethod
+    def leaky_thresholded_fp16x16():
+
+        alpha = 1.0
+
+        x = np.random.uniform(-5, 7, (2, 2)).astype(np.float64)
+        y = np.clip(x, alpha, np.inf)
+        y[y == alpha] = 0
+
+        x = Tensor(Dtype.FP16x16, x.shape, to_fp(
+            x.flatten(), FixedImpl.FP16x16))
+        y = Tensor(Dtype.FP16x16, y.shape, to_fp(
+            y.flatten(), FixedImpl.FP16x16))
+
+        name = "thresholded_relu_fp16x16"
+        make_node([x], [y], name)
+        make_test([x], y, "NNTrait::thresholded_relu(@input_0, @FixedTrait::new(65536, false))",
+                  name, Trait.NN)

diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo
index 476616e30..7d7a5ebc8 100644
--- a/src/operators/nn/core.cairo
+++ b/src/operators/nn/core.cairo
@@ -451,4 +451,57 @@ trait NNTrait {
     /// ```
     ///
     fn leaky_relu(inputs: @Tensor, alpha: @T) -> Tensor;
+    /// # NNTrait::thresholded_relu
+    ///
+    /// ```rust
+    /// fn thresholded_relu(inputs: @Tensor, alpha: @T) -> Tensor
+    /// ```
+    ///
+    /// Applies the thresholded rectified linear unit (Thresholded ReLU) activation function element-wise to a given tensor.
+    ///
+    /// The Thresholded ReLU function is defined as f(x) = x if x > alpha, f(x) = 0 otherwise, where x is the input element.
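+    ///
+    /// For example, with alpha = 1.0 an input of [-1.5, 0.5, 1.0, 2.3] maps to
+    /// [0.0, 0.0, 0.0, 2.3]: only elements strictly greater than alpha pass through.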
+    ///
+    /// ## Args
+    /// * `tensor`(`@Tensor`) - A snapshot of a tensor to which the Thresholded ReLU function will be applied.
+    /// * `alpha`(`@T`) - A snapshot of a fixed point scalar that defines the alpha value of the Thresholded ReLU function.
+    ///
+    /// ## Returns
+    /// A new fixed point tensor with the same shape as the input tensor and the Thresholded ReLU function applied element-wise.
+    ///
+    /// ## Type Constraints
+    ///
+    /// Constrain input and output types to fixed point tensors.
+    ///
+    /// ## Examples
+    ///
+    /// ```rust
+    /// use array::{ArrayTrait, SpanTrait};
+    ///
+    /// use orion::operators::tensor::{TensorTrait, Tensor, FP8x23};
+    /// use orion::operators::nn::{NNTrait, FP8x23NN};
+    /// use orion::numbers::{FP8x23, FixedTrait};
+    ///
+    /// fn thresholded_relu_example() -> Tensor {
+    ///     let tensor = TensorTrait::::new(
+    ///         shape: array![2, 2].span(),
+    ///         data: array![
+    ///             FixedTrait::new(0, false),
+    ///             FixedTrait::new(256, false),
+    ///             FixedTrait::new(512, false),
+    ///             FixedTrait::new(513, false),
+    ///         ]
+    ///         .span(),
+    ///     );
+    ///     let alpha = FixedTrait::from_felt(256); // 1.0
+    ///
+    ///     return NNTrait::thresholded_relu(@tensor, @alpha);
+    /// }
+    /// >>> [[0, 0], [512, 513]]
+    /// ```
+    ///
+    fn thresholded_relu(tensor: @Tensor, alpha: @T) -> Tensor;
 }

diff --git a/src/operators/nn/functional.cairo b/src/operators/nn/functional.cairo
index 272475c30..12916528c 100644
--- a/src/operators/nn/functional.cairo
+++ b/src/operators/nn/functional.cairo
@@ -6,3 +6,4 @@ mod softsign;
 mod softplus;
 mod linear;
 mod logsoftmax;
+mod thresholded_relu;

diff --git a/src/operators/nn/functional/thresholded_relu.cairo b/src/operators/nn/functional/thresholded_relu.cairo
new file mode 100644
index 000000000..1f10db96f
--- /dev/null
+++ b/src/operators/nn/functional/thresholded_relu.cairo
@@ -0,0 +1,38 @@
+use array::ArrayTrait;
+use array::SpanTrait;
+use option::OptionTrait;
+
+use orion::numbers::NumberTrait;
+use orion::operators::tensor::core::{Tensor, TensorTrait};
+
+/// Cf: NNTrait::thresholded_relu docstring
+fn thresholded_relu<
+    T,
+    MAG,
+    impl TTensor: TensorTrait,
+    impl TNumber: NumberTrait,
+    impl TPartialOrd: PartialOrd,
+    impl TCopy: Copy,
+    impl TDrop: Drop
+>(
+    mut z: Tensor, alpha: @T
+) -> Tensor {
+    let mut data_result = ArrayTrait::::new();
+
+    loop {
+        match z.data.pop_front() {
+            Option::Some(item) => {
+                if (*item) < (*alpha) {
+                    data_result.append(NumberTrait::zero());
+                } else {
+                    data_result.append(*item);
+                };
+            },
+            Option::None(_) => {
+                break;
+            }
+        };
+    };
+
+    return TensorTrait::new(z.shape, data_result.span());
+}

diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo
index a0094de29..e06894d08 100644
--- a/src/operators/nn/implementations/nn_fp16x16.cairo
+++ b/src/operators/nn/implementations/nn_fp16x16.cairo
@@ -42,4 +42,8 @@ impl FP16x16NN of NNTrait {
     fn leaky_relu(inputs: @Tensor, alpha: @FP16x16) -> Tensor {
         functional::leaky_relu::leaky_relu(*inputs, alpha)
     }
+
+    fn thresholded_relu(tensor: @Tensor, alpha: @FP16x16) -> Tensor {
+        functional::thresholded_relu::thresholded_relu(*tensor, alpha)
+    }
 }

diff --git a/src/operators/nn/implementations/nn_fp32x32.cairo b/src/operators/nn/implementations/nn_fp32x32.cairo
index a550caa31..adc9bfd87 100644
--- a/src/operators/nn/implementations/nn_fp32x32.cairo
+++ b/src/operators/nn/implementations/nn_fp32x32.cairo
@@ -42,4 +42,8 @@ impl FP32x32NN of NNTrait {
     fn leaky_relu(inputs: @Tensor, alpha: @FP32x32) -> Tensor {
functional::leaky_relu::leaky_relu(*inputs, alpha) } + + fn thresholded_relu(tensor: @Tensor, alpha: @FP32x32) -> Tensor { + functional::thresholded_relu::thresholded_relu(*tensor, alpha) + } } diff --git a/src/operators/nn/implementations/nn_fp64x64.cairo b/src/operators/nn/implementations/nn_fp64x64.cairo index bcd4be5db..30ac9f549 100644 --- a/src/operators/nn/implementations/nn_fp64x64.cairo +++ b/src/operators/nn/implementations/nn_fp64x64.cairo @@ -42,4 +42,8 @@ impl FP64x64NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @FP64x64) -> Tensor { functional::leaky_relu::leaky_relu(*inputs, alpha) } + + fn thresholded_relu(tensor: @Tensor, alpha: @FP64x64) -> Tensor { + functional::thresholded_relu::thresholded_relu(*tensor, alpha) + } } diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index 305eeaba2..d8a37b030 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -42,4 +42,8 @@ impl FP8x23NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @FP8x23) -> Tensor { functional::leaky_relu::leaky_relu(*inputs, alpha) } + + fn thresholded_relu(tensor: @Tensor, alpha: @FP8x23) -> Tensor { + functional::thresholded_relu::thresholded_relu(*tensor, alpha) + } } diff --git a/src/operators/nn/implementations/nn_i32.cairo b/src/operators/nn/implementations/nn_i32.cairo index c74dfc268..9e22dce81 100644 --- a/src/operators/nn/implementations/nn_i32.cairo +++ b/src/operators/nn/implementations/nn_i32.cairo @@ -38,4 +38,8 @@ impl I32NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @i32) -> Tensor { panic(array!['not supported!']) } + + fn thresholded_relu(tensor: @Tensor, alpha: @i32) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_i8.cairo b/src/operators/nn/implementations/nn_i8.cairo index 12e9812a1..88273a7f3 100644 --- a/src/operators/nn/implementations/nn_i8.cairo +++ b/src/operators/nn/implementations/nn_i8.cairo @@ -38,4 +38,8 @@ impl I8NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @i8) -> Tensor { panic(array!['not supported!']) } + + fn thresholded_relu(tensor: @Tensor, alpha: @i8) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_u32.cairo b/src/operators/nn/implementations/nn_u32.cairo index 0da74d544..602b3d938 100644 --- a/src/operators/nn/implementations/nn_u32.cairo +++ b/src/operators/nn/implementations/nn_u32.cairo @@ -37,4 +37,8 @@ impl U32NN of NNTrait { fn leaky_relu(inputs: @Tensor, alpha: @u32) -> Tensor { panic(array!['not supported!']) } + + fn thresholded_relu(tensor: @Tensor, alpha: @u32) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/tests/src/nodes.cairo b/tests/src/nodes.cairo index 398edc4bf..9c1925af9 100644 --- a/tests/src/nodes.cairo +++ b/tests/src/nodes.cairo @@ -433,3 +433,5 @@ mod clip_i8_2d; mod clip_i8_3d; mod clip_u32_2d; mod clip_u32_3d; +mod thresholded_relu_fp16x16; +mod thresholded_relu_fp8x23; diff --git a/tests/src/nodes/thresholded_relu_fp16x16.cairo b/tests/src/nodes/thresholded_relu_fp16x16.cairo new file mode 100644 index 000000000..881e285eb --- /dev/null +++ b/tests/src/nodes/thresholded_relu_fp16x16.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP16x16NN; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] 
+fn test_thresholded_relu_fp16x16() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = NNTrait::thresholded_relu(@input_0, @FixedTrait::new(65536, false)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/thresholded_relu_fp16x16/input_0.cairo b/tests/src/nodes/thresholded_relu_fp16x16/input_0.cairo new file mode 100644 index 000000000..8b3c534c8 --- /dev/null +++ b/tests/src/nodes/thresholded_relu_fp16x16/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 240273, sign: true }); + data.append(FP16x16 { mag: 61472, sign: true }); + data.append(FP16x16 { mag: 255480, sign: false }); + data.append(FP16x16 { mag: 300914, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/thresholded_relu_fp16x16/output_0.cairo b/tests/src/nodes/thresholded_relu_fp16x16/output_0.cairo new file mode 100644 index 000000000..cc607e6ec --- /dev/null +++ b/tests/src/nodes/thresholded_relu_fp16x16/output_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 255480, sign: false }); + data.append(FP16x16 { mag: 300914, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/thresholded_relu_fp8x23.cairo b/tests/src/nodes/thresholded_relu_fp8x23.cairo new file mode 100644 index 000000000..ae79fc5c8 --- /dev/null +++ b/tests/src/nodes/thresholded_relu_fp8x23.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP8x23NN; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_thresholded_relu_fp8x23() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = NNTrait::thresholded_relu(@input_0, @FixedTrait::new(256, false)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/thresholded_relu_fp8x23/input_0.cairo b/tests/src/nodes/thresholded_relu_fp8x23/input_0.cairo new file mode 100644 index 000000000..c9cd35061 --- /dev/null +++ b/tests/src/nodes/thresholded_relu_fp8x23/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 47000614, sign: false }); + data.append(FP8x23 { mag: 18049683, sign: false }); + data.append(FP8x23 { mag: 45758723, sign: false }); + data.append(FP8x23 { mag: 
45541560, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file

diff --git a/tests/src/nodes/thresholded_relu_fp8x23/output_0.cairo b/tests/src/nodes/thresholded_relu_fp8x23/output_0.cairo
new file mode 100644
index 000000000..b4d86023f
--- /dev/null
+++ b/tests/src/nodes/thresholded_relu_fp8x23/output_0.cairo
@@ -0,0 +1,18 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP8x23Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP8x23;
+
+fn output_0() -> Tensor {
+    let mut shape = ArrayTrait::::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP8x23 { mag: 47000614, sign: false });
+    data.append(FP8x23 { mag: 18049683, sign: false });
+    data.append(FP8x23 { mag: 45758723, sign: false });
+    data.append(FP8x23 { mag: 45541560, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file

From 19266efbc50544eb825611f8858cf75b83d249cd Mon Sep 17 00:00:00 2001
From: 0x73e <132935850+0x73e@users.noreply.github.com>
Date: Mon, 23 Oct 2023 08:49:25 -0400
Subject: [PATCH 50/78] fix docs and tests

---
 .../operators/neural-network/nn.thresholded_relu.md | 6 +++---
 nodegen/node/thresholded_relu.py | 4 ++--
 src/operators/nn/core.cairo | 6 +++---
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/docs/framework/operators/neural-network/nn.thresholded_relu.md b/docs/framework/operators/neural-network/nn.thresholded_relu.md
index 85521efc4..8c1206ce1 100644
--- a/docs/framework/operators/neural-network/nn.thresholded_relu.md
+++ b/docs/framework/operators/neural-network/nn.thresholded_relu.md
@@ -1,7 +1,7 @@
 # NNTrait::thresholded_relu
 
 ```rust
- fn thresholded_relu(inputs: @Tensor, alpha: @T) -> Tensor
+ fn thresholded_relu(tensor: @Tensor, alpha: @T) -> Tensor
 ```
 
 Applies the thresholded rectified linear unit (Thresholded ReLU) activation function element-wise to a given tensor.
@@ -35,7 +35,7 @@ fn thresholded_relu_example() -> Tensor {
         FixedTrait::new(0, false),
         FixedTrait::new(256, false),
         FixedTrait::new(512, false),
-        FixedTrait::new(513, false),
+        FixedTrait::new(257, false),
     ]
     .span(),
 );
@@ -43,5 +43,5 @@ fn thresholded_relu_example() -> Tensor {
     return NNTrait::thresholded_relu(@tensor, @alpha);
 }
->>> [[0, 0], [512, 513]]
+>>> [[0, 0], [512, 257]]
 ```

diff --git a/nodegen/node/thresholded_relu.py b/nodegen/node/thresholded_relu.py
index 12715aa84..9785cc779 100644
--- a/nodegen/node/thresholded_relu.py
+++ b/nodegen/node/thresholded_relu.py
@@ -6,7 +6,7 @@ class Thresholded_relu(RunAll):
 
     @staticmethod
-    def leaky_thresholded_fp8x23():
+    def thresholded_relu_fp8x23():
 
         alpha = 1.0
 
@@ -25,7 +25,7 @@ def leaky_thresholded_fp8x23():
 
     @staticmethod
-    def leaky_thresholded_fp16x16():
+    def thresholded_relu_fp16x16():
 
         alpha = 1.0
 

diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo
index 7d7a5ebc8..6af204a3a 100644
--- a/src/operators/nn/core.cairo
+++ b/src/operators/nn/core.cairo
@@ -454,7 +454,7 @@ trait NNTrait {
     /// # NNTrait::thresholded_relu
     ///
     /// ```rust
-    /// fn thresholded_relu(inputs: @Tensor, alpha: @T) -> Tensor
+    /// fn thresholded_relu(tensor: @Tensor, alpha: @T) -> Tensor
     /// ```
     ///
     /// Applies the thresholded rectified linear unit (Thresholded ReLU) activation function element-wise to a given tensor.
@@ -488,7 +488,7 @@ trait NNTrait { /// FixedTrait::new(0, false), /// FixedTrait::new(256, false), /// FixedTrait::new(512, false), - /// FixedTrait::new(513, false), + /// FixedTrait::new(257, false), /// ] /// .span(), /// ); @@ -496,7 +496,7 @@ trait NNTrait { /// /// return NNTrait::thresholded_relu(@tensor, @alpha); /// } - /// >>> [[0, 0], [512, 513]] + /// >>> [[0, 0], [512, 257]] /// ``` /// fn thresholded_relu(tensor: @Tensor, alpha: @T) -> Tensor; From 44700ab67b7cd03a3efb0d904367ae7047160d5b Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Mon, 23 Oct 2023 08:54:09 -0400 Subject: [PATCH 51/78] minor bug fix --- src/operators/nn/functional/thresholded_relu.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/thresholded_relu.cairo b/src/operators/nn/functional/thresholded_relu.cairo index 1f10db96f..dd4ced172 100644 --- a/src/operators/nn/functional/thresholded_relu.cairo +++ b/src/operators/nn/functional/thresholded_relu.cairo @@ -22,7 +22,7 @@ fn thresholded_relu< loop { match z.data.pop_front() { Option::Some(item) => { - if (*item) < (*alpha) { + if (*item) <= (*alpha) { data_result.append(NumberTrait::zero()); } else { data_result.append(*item); From 094353038fc30ac2aeefc7042aff7682ddd4dccc Mon Sep 17 00:00:00 2001 From: 0xd3bs Date: Mon, 23 Oct 2023 16:34:13 -0300 Subject: [PATCH 52/78] feat: add identity --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 3 +- docs/framework/operators/tensor/README.md | 1 + .../operators/tensor/tensor.identity.md | 33 +++++++ nodegen/node/identity.py | 86 +++++++++++++++++++ src/operators/tensor/core.cairo | 44 ++++++++++ .../implementations/tensor_fp16x16.cairo | 5 ++ .../implementations/tensor_fp16x16wide.cairo | 5 ++ .../implementations/tensor_fp32x32.cairo | 5 ++ .../implementations/tensor_fp64x64.cairo | 5 ++ .../implementations/tensor_fp8x23.cairo | 5 ++ .../implementations/tensor_fp8x23wide.cairo | 5 ++ .../tensor/implementations/tensor_i32.cairo | 5 ++ .../tensor/implementations/tensor_i8.cairo | 5 ++ .../tensor/implementations/tensor_u32.cairo | 5 ++ tests/src/nodes.cairo | 5 ++ tests/src/nodes/identity_fP16x16.cairo | 19 ++++ .../src/nodes/identity_fP16x16/input_0.cairo | 18 ++++ .../src/nodes/identity_fP16x16/output_0.cairo | 18 ++++ tests/src/nodes/identity_fP8x23.cairo | 20 +++++ tests/src/nodes/identity_fP8x23/input_0.cairo | 18 ++++ .../src/nodes/identity_fP8x23/output_0.cairo | 18 ++++ tests/src/nodes/identity_i32.cairo | 20 +++++ tests/src/nodes/identity_i32/input_0.cairo | 17 ++++ tests/src/nodes/identity_i32/output_0.cairo | 17 ++++ tests/src/nodes/identity_i8.cairo | 20 +++++ tests/src/nodes/identity_i8/input_0.cairo | 17 ++++ tests/src/nodes/identity_i8/output_0.cairo | 17 ++++ tests/src/nodes/identity_u32.cairo | 20 +++++ tests/src/nodes/identity_u32/input_0.cairo | 16 ++++ tests/src/nodes/identity_u32/output_0.cairo | 16 ++++ 31 files changed, 488 insertions(+), 1 deletion(-) create mode 100644 docs/framework/operators/tensor/tensor.identity.md create mode 100644 nodegen/node/identity.py create mode 100644 tests/src/nodes/identity_fP16x16.cairo create mode 100644 tests/src/nodes/identity_fP16x16/input_0.cairo create mode 100644 tests/src/nodes/identity_fP16x16/output_0.cairo create mode 100644 tests/src/nodes/identity_fP8x23.cairo create mode 100644 tests/src/nodes/identity_fP8x23/input_0.cairo create mode 100644 tests/src/nodes/identity_fP8x23/output_0.cairo create mode 100644 tests/src/nodes/identity_i32.cairo create mode 100644
tests/src/nodes/identity_i32/input_0.cairo create mode 100644 tests/src/nodes/identity_i32/output_0.cairo create mode 100644 tests/src/nodes/identity_i8.cairo create mode 100644 tests/src/nodes/identity_i8/input_0.cairo create mode 100644 tests/src/nodes/identity_i8/output_0.cairo create mode 100644 tests/src/nodes/identity_u32.cairo create mode 100644 tests/src/nodes/identity_u32/input_0.cairo create mode 100644 tests/src/nodes/identity_u32/output_0.cairo diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 07de3f249..b5d966102 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -86,6 +86,7 @@ * [tensor.unsqueeze](framework/operators/tensor/tensor.unsqueeze.md) * [tensor.sign](framework/operators/tensor/tensor.sign.md) * [tensor.clip](framework/operators/tensor/tensor.clip.md) + * [tensor.identity](framework/operators/tensor/tensor.identity.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index a7b415cc6..a18efebf5 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -57,5 +57,6 @@ You can see below the list of current supported ONNX Operators: | [Unsqueeze](operators/tensor/tensor.unsqueeze.md) | :white\_check\_mark: | | [Sign](operators/tensor/tensor.sign.md) | :white\_check\_mark: | | [Clip](operators/tensor/tensor.clip.md) | :white\_check\_mark: | +| [Identity](operators/tensor/tensor.identity.md) | :white\_check\_mark: | -Current Operators support: **50/156 (32%)** +Current Operators support: **51/156 (33%)** diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index b33f330c7..a7261616e 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -82,6 +82,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.unsqueeze`](tensor.unsqueeze.md) | Inserts single-dimensional entries to the shape of an input tensor. | | [`tensor.sign`](tensor.sign.md) | Calculates the sign of the given input tensor element-wise. | | [`tensor.clip`](tensor.clip.md) | Clip operator limits the given input within an interval. | +| [`tensor.identity`](tensor.identity.md) | Return a Tensor with the same shape and contents as input. | ## Arithmetic Operations diff --git a/docs/framework/operators/tensor/tensor.identity.md b/docs/framework/operators/tensor/tensor.identity.md new file mode 100644 index 000000000..902972c84 --- /dev/null +++ b/docs/framework/operators/tensor/tensor.identity.md @@ -0,0 +1,33 @@ +# tensor.identity + +```rust + fn identity(self: @Tensor) -> Tensor; +``` + +Return a Tensor with the same shape and contents as input. + +## Args + +* `self`(`@Tensor`) - Input tensor. + +## Returns + +A new `Tensor` to copy input into. 
+ +## Example + +```rust +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor, FP16x16Tensor}; + +fn identity_example() -> Tensor { + let tensor = TensorTrait::::new( + shape: array![2, 2].span(), + data: array![1, 2, 3, 4].span(), + ); + let t_identity = tensor.identity(); + t_identity +} +>>> [[1 2] [3 4]] // A Tensor with the same shape and contents as input +``` diff --git a/nodegen/node/identity.py b/nodegen/node/identity.py new file mode 100644 index 000000000..eb46b93e9 --- /dev/null +++ b/nodegen/node/identity.py @@ -0,0 +1,86 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl + + +class Identity(RunAll): + + @staticmethod + def identity_fP8x23(): + def identity(): + x = np.array([[1, 2], [3, 4]]) + y = x + + x = Tensor(Dtype.FP8x23, x.shape, x.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, y.flatten()) + + name = "identity_fP8x23" + make_node([x], [y], name) + + make_test( + [x], y, "input_0.identity()", name) + identity() + + @staticmethod + def identity_fP16x16(): + def identity(): + x = np.array([[1, 2], [3, 4]]) + y = x + + x = Tensor(Dtype.FP16x16, x.shape, x.flatten()) + y = Tensor(Dtype.FP16x16, y.shape, y.flatten()) + + name = "identity_fP16x16" + make_node([x], [y], name) + + make_test( + [x], y, "input_0.identity()", name) + identity() + + @staticmethod + def identity_i8(): + def identity(): + x = np.array([[1, 2], [3, 4]]) + y = x + + x = Tensor(Dtype.I8, x.shape, x.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "identity_i8" + make_node([x], [y], name) + + make_test( + [x], y, "input_0.identity()", name) + identity() + + @staticmethod + def identity_i32(): + def identity(): + x = np.array([[1, 2], [3, 4]]) + y = x + + x = Tensor(Dtype.I32, x.shape, x.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "identity_i32" + make_node([x], [y], name) + + make_test( + [x], y, "input_0.identity()", name) + identity() + + @staticmethod + def identity_u32(): + def identity(): + x = np.array([[1, 2], [3, 4]]) + y = x + + x = Tensor(Dtype.U32, x.shape, x.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "identity_u32" + make_node([x], [y], name) + + make_test( + [x], y, "input_0.identity()", name) + identity() \ No newline at end of file diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 535ed699c..7d5300a00 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -78,6 +78,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde { /// # tensor.new @@ -2669,6 +2670,41 @@ trait TensorTrait { /// ``` /// fn sign(self: @Tensor) -> Tensor; + /// # tensor.identity + /// + /// ```rust + /// fn identity(self: @Tensor) -> Tensor; + /// ``` + /// + /// Return a Tensor with the same shape and contents as input. + /// + /// ## Args + /// + /// * `self`(`@Tensor`) - Input tensor. + /// + /// ## Returns + /// + /// A new `Tensor` to copy input into. 
+ /// + /// ## Example + /// + /// ```rust + /// use array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, FP16x16Tensor}; + /// + /// fn identity_example() -> Tensor { + /// let tensor = TensorTrait::::new( + /// shape: array![2, 2].span(), + /// data: array![1, 2, 3, 4].span(), + /// ); + /// let t_identity = tensor.identity(); + /// t_identity + /// } + /// >>> [[1 2] [3 4]] // A Tensor with the same shape and contents as input + /// ``` + /// + fn identity(self: @Tensor) -> Tensor; } /// Cf: TensorTrait::new docstring @@ -3243,3 +3279,11 @@ fn clip< return Tensor:: { shape: *self.shape, data: return_data.span() }; } + +/// Cf: TensorTrait::identity docstring +fn identity +( + self: @Tensor +) -> Tensor { + Tensor:: { shape: *self.shape, data: *self.data } +} \ No newline at end of file diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ac803904e..8833f3a0c 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -229,6 +229,11 @@ impl FP16x16Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 0a89fe72d..2d402df8a 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -237,6 +237,11 @@ impl FP16x16WTensor of TensorTrait { ) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 19557eba1..c0a60cf91 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -230,6 +230,11 @@ impl FP32x32Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 686d319b1..80d80277b 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -230,6 +230,11 @@ impl FP64x64Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. 
diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index f46e96fb5..6ec80ad50 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -229,6 +229,11 @@ impl FP8x23Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 91ca9338d..67b117019 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -229,6 +229,11 @@ impl FP8x23WTensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index dd6ae6f59..e66f14701 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -228,6 +228,11 @@ impl I32Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index a4518cefb..6673caddb 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -227,6 +227,11 @@ impl I8Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 4d26a8634..4c8fad8a8 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -221,6 +221,11 @@ impl U32Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn identity(self: @Tensor) -> Tensor { + core::identity(self) + } + } /// Implements addition for `Tensor` using the `Add` trait. 
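Worth noting about the `identity` registrations above: the core implementation just re-wraps the input's `shape` and `data` spans (`Tensor { shape: *self.shape, data: *self.data }`), so it is a view-level copy with no per-element work. Below is a minimal usage sketch in the style of the generated tests that follow; the function name is illustrative and not part of the patch:

```rust
use array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor, U32TensorPartialEq};
use orion::utils::assert_eq;

fn identity_is_a_view_level_copy() {
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(2);
    shape.append(2);

    let mut data = ArrayTrait::new();
    data.append(1_u32);
    data.append(2_u32);
    data.append(3_u32);
    data.append(4_u32);

    let t = TensorTrait::new(shape.span(), data.span());
    // `identity` takes a snapshot, so `t` remains usable afterwards; the
    // returned tensor shares the same underlying shape and data buffers.
    let y = t.identity();
    assert_eq(y, t);
}
```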
diff --git a/tests/src/nodes.cairo b/tests/src/nodes.cairo index 398edc4bf..18a877f8b 100644 --- a/tests/src/nodes.cairo +++ b/tests/src/nodes.cairo @@ -433,3 +433,8 @@ mod clip_i8_2d; mod clip_i8_3d; mod clip_u32_2d; mod clip_u32_3d; +mod identity_fP16x16; +mod identity_fP8x23; +mod identity_i32; +mod identity_i8; +mod identity_u32; diff --git a/tests/src/nodes/identity_fP16x16.cairo b/tests/src/nodes/identity_fP16x16.cairo new file mode 100644 index 000000000..5a3fa9893 --- /dev/null +++ b/tests/src/nodes/identity_fP16x16.cairo @@ -0,0 +1,19 @@ +mod input_0; +mod output_0; + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_identity_fP16x16() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.identity(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/identity_fP16x16/input_0.cairo b/tests/src/nodes/identity_fP16x16/input_0.cairo new file mode 100644 index 000000000..621d26f75 --- /dev/null +++ b/tests/src/nodes/identity_fP16x16/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1, sign: false }); + data.append(FP16x16 { mag: 2, sign: false }); + data.append(FP16x16 { mag: 3, sign: false }); + data.append(FP16x16 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_fP16x16/output_0.cairo b/tests/src/nodes/identity_fP16x16/output_0.cairo new file mode 100644 index 000000000..c43d35cc6 --- /dev/null +++ b/tests/src/nodes/identity_fP16x16/output_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1, sign: false }); + data.append(FP16x16 { mag: 2, sign: false }); + data.append(FP16x16 { mag: 3, sign: false }); + data.append(FP16x16 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_fP8x23.cairo b/tests/src/nodes/identity_fP8x23.cairo new file mode 100644 index 000000000..0d525dd66 --- /dev/null +++ b/tests/src/nodes/identity_fP8x23.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_identity_fP8x23() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.identity(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/identity_fP8x23/input_0.cairo b/tests/src/nodes/identity_fP8x23/input_0.cairo new file mode 100644 index 
000000000..788d470dc --- /dev/null +++ b/tests/src/nodes/identity_fP8x23/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1, sign: false }); + data.append(FP8x23 { mag: 2, sign: false }); + data.append(FP8x23 { mag: 3, sign: false }); + data.append(FP8x23 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_fP8x23/output_0.cairo b/tests/src/nodes/identity_fP8x23/output_0.cairo new file mode 100644 index 000000000..91c661c60 --- /dev/null +++ b/tests/src/nodes/identity_fP8x23/output_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1, sign: false }); + data.append(FP8x23 { mag: 2, sign: false }); + data.append(FP8x23 { mag: 3, sign: false }); + data.append(FP8x23 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_i32.cairo b/tests/src/nodes/identity_i32.cairo new file mode 100644 index 000000000..a9071a4d5 --- /dev/null +++ b/tests/src/nodes/identity_i32.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I32Tensor; +use orion::operators::tensor::I32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_identity_i32() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.identity(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/identity_i32/input_0.cairo b/tests/src/nodes/identity_i32/input_0.cairo new file mode 100644 index 000000000..235eb099e --- /dev/null +++ b/tests/src/nodes/identity_i32/input_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_i32/output_0.cairo b/tests/src/nodes/identity_i32/output_0.cairo new file mode 100644 index 000000000..a9b5e2a13 --- /dev/null +++ b/tests/src/nodes/identity_i32/output_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = 
ArrayTrait::new(); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_i8.cairo b/tests/src/nodes/identity_i8.cairo new file mode 100644 index 000000000..29af4b5ea --- /dev/null +++ b/tests/src/nodes/identity_i8.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_identity_i8() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.identity(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/identity_i8/input_0.cairo b/tests/src/nodes/identity_i8/input_0.cairo new file mode 100644 index 000000000..93b08b452 --- /dev/null +++ b/tests/src/nodes/identity_i8/input_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_i8/output_0.cairo b/tests/src/nodes/identity_i8/output_0.cairo new file mode 100644 index 000000000..1aac8f583 --- /dev/null +++ b/tests/src/nodes/identity_i8/output_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_u32.cairo b/tests/src/nodes/identity_u32.cairo new file mode 100644 index 000000000..548b41804 --- /dev/null +++ b/tests/src/nodes/identity_u32.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_identity_u32() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.identity(); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/identity_u32/input_0.cairo b/tests/src/nodes/identity_u32/input_0.cairo new file mode 100644 index 000000000..9594bde53 --- /dev/null +++ b/tests/src/nodes/identity_u32/input_0.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> 
Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/identity_u32/output_0.cairo b/tests/src/nodes/identity_u32/output_0.cairo new file mode 100644 index 000000000..c760602db --- /dev/null +++ b/tests/src/nodes/identity_u32/output_0.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file From 298eac22d1a56c12b8364b40961609ddcbcc553a Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Mon, 23 Oct 2023 21:32:52 -0400 Subject: [PATCH 53/78] add wide impls --- src/numbers.cairo | 8 ++++++++ .../tensor/implementations/tensor_fp16x16wide.cairo | 4 ++++ .../tensor/implementations/tensor_fp8x23wide.cairo | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/src/numbers.cairo b/src/numbers.cairo index 77afa1409..c861c39e8 100644 --- a/src/numbers.cairo +++ b/src/numbers.cairo @@ -346,6 +346,10 @@ impl FP8x23WNumber of NumberTrait { core_fp8x23wide::abs(self) } + fn neg(self: FP8x23W) -> FP8x23W { + core_fp8x23wide::neg(self) + } + fn min_value() -> FP8x23W { FP8x23W { mag: core_fp8x23wide::MAX, sign: true } } @@ -680,6 +684,10 @@ impl FP16x16WNumber of NumberTrait { core_fp16x16wide::abs(self) } + fn neg(self: FP16x16W) -> FP16x16W { + core_fp16x16wide::neg(self) + } + fn min_value() -> FP16x16W { FP16x16W { mag: core_fp16x16wide::MAX, sign: true } } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 0a89fe72d..69a39703c 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -106,6 +106,10 @@ impl FP16x16WTensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { math::ceil::ceil(*self) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 91ca9338d..6a9358a09 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -100,6 +100,10 @@ impl FP8x23WTensor of TensorTrait { math::abs::abs(*self) } + fn neg(self: @Tensor) -> Tensor { + math::neg::neg(*self) + } + fn ceil(self: @Tensor) -> Tensor { math::ceil::ceil(*self) } From f862acb7e2dbba7535608ec4fd00e6e363b24ce6 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Tue, 24 Oct 2023 10:45:54 +0300 Subject: [PATCH 54/78] add missing element in doc --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 1 + docs/framework/operators/tensor/README.md | 1 + src/operators/tensor/core.cairo | 1 + 4 files changed, 4 insertions(+) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 07de3f249..50dac5b30 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -59,6 +59,7 @@ * [tensor.less](framework/operators/tensor/tensor.less.md) * 
[tensor.less\_equal](framework/operators/tensor/tensor.less\_equal.md) * [tensor.abs](framework/operators/tensor/tensor.abs.md) + * [tensor.neg](framework/operators/tensor/tensor.neg.md) * [tensor.ceil](framework/operators/tensor/tensor.ceil.md) * [tensor.cumsum](framework/operators/tensor/tensor.cumsum.md) * [tensor.sin](framework/operators/tensor/tensor.sin.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index a7b415cc6..41e40258d 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -18,6 +18,7 @@ You can see below the list of current supported ONNX Operators: | [Less](operators/tensor/tensor.less.md) | :white\_check\_mark: | | [LessOrEqual](operators/tensor/tensor.less\_equal.md) | :white\_check\_mark: | | [Abs](operators/tensor/tensor.abs.md) | :white\_check\_mark: | +| [Neg](operators/tensor/tensor.neg.md) | :white\_check\_mark: | | [Ceil](operators/tensor/tensor.ceil.md) | :white\_check\_mark: | | [Exp](operators/tensor/tensor.exp.md) | :white\_check\_mark: | | [Ln](operators/tensor/tensor.log.md) | :white\_check\_mark: | diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index b33f330c7..2e5d727d6 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -60,6 +60,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.exp`](tensor.exp.md) | Computes the exponential of all elements of the input tensor. | | [`tensor.log`](tensor.log.md) | Computes the natural log of all elements of the input tensor. | | [`tensor.abs`](tensor.abs.md) | Computes the absolute value of all elements in the input tensor. | +| [`tensor.neg`](tensor.neg.md) | Computes the negation of all elements in the input tensor. | | [`tensor.ceil`](tensor.ceil.md) | Rounds up the value of each element in the input tensor. | | [`tensor.sqrt`](tensor.sqrt.md) | Computes the square root of all elements of the input tensor. | | [`tensor.sin`](tensor.sin.md) | Computes the sine of all elements of the input tensor. | diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 4b9c9ed52..bfddbc8fc 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -56,6 +56,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde Date: Tue, 24 Oct 2023 04:31:08 -0400 Subject: [PATCH 55/78] docs --- docs/framework/operators/neural-network/README.md | 1 + src/operators/nn/core.cairo | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/framework/operators/neural-network/README.md b/docs/framework/operators/neural-network/README.md index 71b4b17f0..83dafaf16 100644 --- a/docs/framework/operators/neural-network/README.md +++ b/docs/framework/operators/neural-network/README.md @@ -31,4 +31,5 @@ Orion supports currently these `NN` types. | [`nn.softsign`](nn.softsign.md) | Applies the Softsign function element-wise. | | [`nn.softplus`](nn.softplus.md) | Applies the Softplus function element-wise. | | [`nn.linear`](nn.linear.md) | Performs a linear transformation of the input tensor using the provided weights and bias. | +| [`nn.thresholded_relu`](nn.thresholded\_relu.md) | performs the thresholded relu activation function element-wise. | diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo index 6af204a3a..9d29f750c 100644 --- a/src/operators/nn/core.cairo +++ b/src/operators/nn/core.cairo @@ -10,6 +10,7 @@ use orion::operators::tensor::core::Tensor; /// softsign - Applies the Softsign function element-wise. 
/// softplus - Applies the Softplus function element-wise. /// linear - Performs a linear transformation of the input tensor using the provided weights and bias. +/// thresholded_relu - performs the thresholded relu activation function element-wise. trait NNTrait { /// # NNTrait::relu /// From 010e30a46130c48cd8d5d4ba92969721b8b9d2e8 Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Tue, 24 Oct 2023 04:32:46 -0400 Subject: [PATCH 56/78] docs --- docs/framework/operators/neural-network/README.md | 1 + src/operators/nn/core.cairo | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/framework/operators/neural-network/README.md b/docs/framework/operators/neural-network/README.md index 71b4b17f0..9c6a9fe56 100644 --- a/docs/framework/operators/neural-network/README.md +++ b/docs/framework/operators/neural-network/README.md @@ -31,4 +31,5 @@ Orion supports currently these `NN` types. | [`nn.softsign`](nn.softsign.md) | Applies the Softsign function element-wise. | | [`nn.softplus`](nn.softplus.md) | Applies the Softplus function element-wise. | | [`nn.linear`](nn.linear.md) | Performs a linear transformation of the input tensor using the provided weights and bias. | +| [`nn.hard_sigmoid`](nn.hard\_sigmoid.md) | Applies the Hard Sigmoid function to an n-dimensional input tensor. | diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo index ce49540a0..7511cedf1 100644 --- a/src/operators/nn/core.cairo +++ b/src/operators/nn/core.cairo @@ -10,6 +10,7 @@ use orion::operators::tensor::core::Tensor; /// softsign - Applies the Softsign function element-wise. /// softplus - Applies the Softplus function element-wise. /// linear - Performs a linear transformation of the input tensor using the provided weights and bias. +/// hard_sigmoid - Applies the Hard Sigmoid function to an n-dimensional input tensor. 
trait NNTrait { /// # NNTrait::relu /// From 10325f9a358aa8093a0eab64ad2f9afd1897d41e Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Tue, 24 Oct 2023 04:37:09 -0400 Subject: [PATCH 57/78] summary + compatibility docs --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 07de3f249..422ba6fa1 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -95,6 +95,7 @@ * [nn.softsign](framework/operators/neural-network/nn.softsign.md) * [nn.softplus](framework/operators/neural-network/nn.softplus.md) * [nn.linear](framework/operators/neural-network/nn.linear.md) + * [nn.hard_sigmoid](framework/operators/neural-network/nn.hard_sigmoid.md) * [Machine Learning](framework/operators/machine-learning/README.md) * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index a7b415cc6..d435662d6 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -39,6 +39,7 @@ You can see below the list of current supported ONNX Operators: | [Softsign](operators/neural-network/nn.softsign.md) | :white\_check\_mark: | | [Softplus](operators/neural-network/nn.softplus.md) | :white\_check\_mark: | | [Linear](operators/neural-network/nn.linear.md) | :white\_check\_mark: | +| [HardSigmoid](operators/neural-network/nn.hard_sigmoid.md) | :white\_check\_mark: | | [Sinh](operators/tensor/tensor.sinh.md) | :white\_check\_mark: | | [Asinh](operators/tensor/tensor.asinh.md) | :white\_check\_mark: | | [Cosh](operators/tensor/tensor.cosh.md) | :white\_check\_mark: | From cf1d465acfc1658000d6a2b68444d140326f0996 Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Tue, 24 Oct 2023 04:39:58 -0400 Subject: [PATCH 58/78] summary + compatibility --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 07de3f249..a33d1193e 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -95,6 +95,7 @@ * [nn.softsign](framework/operators/neural-network/nn.softsign.md) * [nn.softplus](framework/operators/neural-network/nn.softplus.md) * [nn.linear](framework/operators/neural-network/nn.linear.md) + * [nn.thresholded\_relu](framework/operators/neural-network/nn.thresholded_relu.md) * [Machine Learning](framework/operators/machine-learning/README.md) * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index a7b415cc6..0936a4168 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -33,6 +33,7 @@ You can see below the list of current supported ONNX Operators: | [Flatten](operators/tensor/tensor.flatten.md) | :white\_check\_mark: | | [Relu](operators/neural-network/nn.relu.md) | :white\_check\_mark: | | [LeakyRelu](operators/neural-network/nn.leaky\_relu.md) | :white\_check\_mark: | +|[ThresholdedRelu](operators/neural-network/nn.thresholded\_relu.md)| :white\_check\_mark: | | [Sigmoid](operators/neural-network/nn.sigmoid.md) | :white\_check\_mark: | | [Softmax](operators/neural-network/nn.softmax.md) | :white\_check\_mark: | | 
[LogSoftmax](operators/neural-network/nn.logsoftmax.md) | :white\_check\_mark: | From ce0d4f5f1baabb6bbcff712bb22475bc4f7a4b3d Mon Sep 17 00:00:00 2001 From: 0x73e <132935850+0x73e@users.noreply.github.com> Date: Tue, 24 Oct 2023 04:41:17 -0400 Subject: [PATCH 59/78] formatting --- docs/SUMMARY.md | 2 +- docs/framework/compatibility.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 422ba6fa1..bac31f592 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -95,7 +95,7 @@ * [nn.softsign](framework/operators/neural-network/nn.softsign.md) * [nn.softplus](framework/operators/neural-network/nn.softplus.md) * [nn.linear](framework/operators/neural-network/nn.linear.md) - * [nn.hard_sigmoid](framework/operators/neural-network/nn.hard_sigmoid.md) + * [nn.hard\_sigmoid](framework/operators/neural-network/nn.hard\_sigmoid.md) * [Machine Learning](framework/operators/machine-learning/README.md) * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index d435662d6..3b3e611a7 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -39,7 +39,7 @@ You can see below the list of current supported ONNX Operators: | [Softsign](operators/neural-network/nn.softsign.md) | :white\_check\_mark: | | [Softplus](operators/neural-network/nn.softplus.md) | :white\_check\_mark: | | [Linear](operators/neural-network/nn.linear.md) | :white\_check\_mark: | -| [HardSigmoid](operators/neural-network/nn.hard_sigmoid.md) | :white\_check\_mark: | +| [HardSigmoid](operators/neural-network/nn.hard\_sigmoid.md) | :white\_check\_mark: | | [Sinh](operators/tensor/tensor.sinh.md) | :white\_check\_mark: | | [Asinh](operators/tensor/tensor.asinh.md) | :white\_check\_mark: | | [Cosh](operators/tensor/tensor.cosh.md) | :white\_check\_mark: | From 71d96622e1a0cdec17cea16858bffdf9532b3065 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:30:08 +0000 Subject: [PATCH 60/78] docs: update README.md [skip ci] --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f4ca1efa..177ca9746 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ # Orion: An Open-source Framework for Validity and ZK ML ✨ -[![All Contributors](https://img.shields.io/badge/all_contributors-15-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-16-orange.svg?style=flat-square)](#contributors-) Orion is an open-source, community-driven framework dedicated to Provable Machine Learning. It provides essential components and a new ONNX runtime for building verifiable Machine Learning models using [STARKs](https://starkware.co/stark/). @@ -82,10 +82,13 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d BemTG
[contributors table: BemTG 💻 📖, danilowhk 💻, Falco R 💻, +dincerguner 💻, Rich Warner 💻, Daniel Bejarano 📖, vikkydataseo 📖, -dincerguner 💻, +0x73e 💻]
From 0ac3587d516534d9ec2b7c3c57d409034bec7de8 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:30:09 +0000 Subject: [PATCH 61/78] docs: update .all-contributorsrc [skip ci] --- .all-contributorsrc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index fc389d5f5..f4a5e47e0 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -143,6 +143,15 @@ "contributions": [ "doc" ] + }, + { + "login": "0x73e", + "name": "0x73e", + "avatar_url": "https://avatars.githubusercontent.com/u/132935850?v=4", + "profile": "https://github.com/0x73e", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, @@ -152,4 +161,4 @@ "projectName": "orion", "projectOwner": "gizatechxyz", "commitType": "docs" -} \ No newline at end of file +} From fd0dc3e3f0fac3a92777d9a5d0a1b9ca75fb56fb Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:31:38 +0000 Subject: [PATCH 62/78] docs: update README.md [skip ci] --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f4ca1efa..4b8ddad0c 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ # Orion: An Open-source Framework for Validity and ZK ML ✨ -[![All Contributors](https://img.shields.io/badge/all_contributors-15-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-16-orange.svg?style=flat-square)](#contributors-) Orion is an open-source, community-driven framework dedicated to Provable Machine Learning. It provides essential components and a new ONNX runtime for building verifiable Machine Learning models using [STARKs](https://starkware.co/stark/). @@ -82,10 +82,13 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
[contributors table: BemTG 💻 📖, danilowhk 💻, Falco R 💻, +dincerguner 💻, Rich Warner 💻, Daniel Bejarano 📖, vikkydataseo 📖, -dincerguner 💻, +0xfulanito 💻]
From 5181d097e59bce9e6c54837cbd27d94bd708c936 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:31:39 +0000 Subject: [PATCH 63/78] docs: update .all-contributorsrc [skip ci] --- .all-contributorsrc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index fc389d5f5..8b4ca987a 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -143,6 +143,15 @@ "contributions": [ "doc" ] + }, + { + "login": "0xfulanito", + "name": "0xfulanito", + "avatar_url": "https://avatars.githubusercontent.com/u/145947367?v=4", + "profile": "https://github.com/0xfulanito", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, @@ -152,4 +161,4 @@ "projectName": "orion", "projectOwner": "gizatechxyz", "commitType": "docs" -} \ No newline at end of file +} From ea2df67ddf178c407ce62b930da4d3591da588c5 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:34:27 +0000 Subject: [PATCH 64/78] docs: update README.md [skip ci] --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f4ca1efa..650031fc0 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ # Orion: An Open-source Framework for Validity and ZK ML ✨ -[![All Contributors](https://img.shields.io/badge/all_contributors-15-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-16-orange.svg?style=flat-square)](#contributors-) Orion is an open-source, community-driven framework dedicated to Provable Machine Learning. It provides essential components and a new ONNX runtime for building verifiable Machine Learning models using [STARKs](https://starkware.co/stark/). @@ -82,10 +82,13 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
[contributors table: BemTG 💻 📖, danilowhk 💻, Falco R 💻, +dincerguner 💻, Rich Warner 💻, Daniel Bejarano 📖, vikkydataseo 📖, -dincerguner 💻, +Charlotte 💻]
From 2633a67c724488dd0a55a099eebdcb4bc914a54d Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:34:28 +0000 Subject: [PATCH 65/78] docs: update .all-contributorsrc [skip ci] --- .all-contributorsrc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index fc389d5f5..157873870 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -143,6 +143,15 @@ "contributions": [ "doc" ] + }, + { + "login": "chachaleo", + "name": "Charlotte", + "avatar_url": "https://avatars.githubusercontent.com/u/49371958?v=4", + "profile": "https://github.com/chachaleo", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, @@ -152,4 +161,4 @@ "projectName": "orion", "projectOwner": "gizatechxyz", "commitType": "docs" -} \ No newline at end of file +} From 4232778554e6d50b7f80d0a4475ef0962d13dc89 Mon Sep 17 00:00:00 2001 From: raphaelDkhn <113879115+raphaelDkhn@users.noreply.github.com> Date: Tue, 24 Oct 2023 12:37:33 +0300 Subject: [PATCH 66/78] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bbe23a8c7..7bf3eab04 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ # Orion: An Open-source Framework for Validity and ZK ML ✨ -[![All Contributors](https://img.shields.io/badge/all_contributors-16-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-18-orange.svg?style=flat-square)](#contributors-) Orion is an open-source, community-driven framework dedicated to Provable Machine Learning. It provides essential components and a new ONNX runtime for building verifiable Machine Learning models using [STARKs](https://starkware.co/stark/).
From d2b4ed0934f685dd9955d32a9948d470306f66f9 Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Tue, 24 Oct 2023 14:07:56 +0200 Subject: [PATCH 67/78] Implement 'and' operator --- src/numbers.cairo | 65 +++++++++++++++++++ .../implementations/fp16x16/math/comp.cairo | 9 +++ .../implementations/fp32x32/comp.cairo | 9 +++ .../implementations/fp64x64/comp.cairo | 9 +++ .../implementations/fp8x23/math/comp.cairo | 9 +++ src/operators/tensor/core.cairo | 1 + .../implementations/tensor_fp16x16.cairo | 4 ++ .../implementations/tensor_fp32x32.cairo | 4 ++ .../implementations/tensor_fp64x64.cairo | 4 ++ .../implementations/tensor_fp8x23.cairo | 4 ++ .../tensor/implementations/tensor_i32.cairo | 4 ++ .../tensor/implementations/tensor_i8.cairo | 4 ++ .../tensor/implementations/tensor_u32.cairo | 4 ++ src/operators/tensor/math.cairo | 1 + src/operators/tensor/math/and.cairo | 46 +++++++++++++ 15 files changed, 177 insertions(+) create mode 100644 src/operators/tensor/math/and.cairo diff --git a/src/numbers.cairo b/src/numbers.cairo index 04ad6efa5..ee89ff4b1 100644 --- a/src/numbers.cairo +++ b/src/numbers.cairo @@ -46,6 +46,7 @@ trait NumberTrait { fn xor(lhs: T, rhs: T) -> bool; fn or(lhs: T, rhs: T) -> bool; fn sign(self: T) -> T; + fn and(lhs: T, rhs: T) -> bool; } use orion::numbers::fixed_point::implementations::fp8x23::core::{FP8x23Impl, FP8x23}; @@ -211,6 +212,10 @@ impl FP8x23Number of NumberTrait { fn sign(self: FP8x23) -> FP8x23 { core_fp8x23::sign(self) } + + fn and(lhs: FP8x23, rhs: FP8x23) -> bool { + comp_fp8x23::and(lhs, rhs) + } } use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16Impl, FP16x16}; @@ -376,6 +381,10 @@ impl FP16x16Number of NumberTrait { fn sign(self: FP16x16) -> FP16x16 { core_fp16x16::sign(self) } + + fn and(lhs: FP16x16, rhs: FP16x16) -> bool { + comp_fp16x16::and(lhs, rhs) + } } use orion::numbers::fixed_point::implementations::fp64x64::core::{FP64x64Impl, FP64x64}; @@ -542,6 +551,10 @@ impl FP64x64Number of NumberTrait { fn sign(self: FP64x64) -> FP64x64 { FP64x64Impl::sign(self) } + + fn and(lhs: FP64x64, rhs: FP64x64) -> bool { + comp_fp64x64::and(lhs, rhs) + } } use orion::numbers::fixed_point::implementations::fp32x32::core::{FP32x32Impl, FP32x32}; @@ -708,6 +721,10 @@ impl FP32x32Number of NumberTrait { fn sign(self: FP32x32) -> FP32x32 { FP32x32Impl::sign(self) } + + fn and(lhs: FP32x32, rhs: FP32x32) -> bool { + comp_fp32x32::and(lhs, rhs) + } } use orion::numbers::signed_integer::i8 as i8_core; @@ -880,6 +897,14 @@ impl I8Number of NumberTrait { fn sign(self: i8) -> i8 { i8_core::i8_sign(self) } + + fn and(lhs: i8, rhs: i8) -> bool { + if (lhs.mag == 0 || rhs.mag == 0) { + return false; + } else { + return true; + } + } } use orion::numbers::signed_integer::i16 as i16_core; @@ -1052,6 +1077,14 @@ impl i16Number of NumberTrait { fn sign(self: i16) -> i16 { i16_core::i16_sign(self) } + + fn and(lhs: i16, rhs: i16) -> bool { + if (lhs.mag == 0 || rhs.mag == 0) { + return false; + } else { + return true; + } + } } use orion::numbers::signed_integer::i32 as i32_core; @@ -1224,6 +1257,14 @@ impl i32Number of NumberTrait { fn sign(self: i32) -> i32 { i32_core::i32_sign(self) } + + fn and(lhs: i32, rhs: i32) -> bool { + if (lhs.mag == 0 || rhs.mag == 0) { + return false; + } else { + return true; + } + } } use orion::numbers::signed_integer::i64 as i64_core; @@ -1396,6 +1437,14 @@ impl i64Number of NumberTrait { fn sign(self: i64) -> i64 { i64_core::i64_sign(self) } + + fn and(lhs: i64, rhs: i64) -> bool { + if (lhs.mag == 0 || rhs.mag == 0) 
{ + return false; + } else { + return true; + } + } } use orion::numbers::signed_integer::i128 as i128_core; @@ -1569,6 +1618,14 @@ impl i128Number of NumberTrait { fn sign(self: i128) -> i128 { i128_core::i128_sign(self) } + + fn and(lhs: i128, rhs: i128) -> bool { + if (lhs.mag == 0 || rhs.mag == 0) { + return false; + } else { + return true; + } + } } impl u32Number of NumberTrait { @@ -1747,4 +1804,12 @@ impl u32Number of NumberTrait { fn sign(self: u32) -> u32 { panic(array!['not supported!']) } + + fn and(lhs: u32, rhs: u32) -> bool { + if (lhs == 0 || rhs == 0) { + return false; + } else { + return true; + } + } } diff --git a/src/numbers/fixed_point/implementations/fp16x16/math/comp.cairo b/src/numbers/fixed_point/implementations/fp16x16/math/comp.cairo index 3aa3219b5..cf1e6ec08 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/math/comp.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/math/comp.cairo @@ -35,6 +35,15 @@ fn or(a: FP16x16, b: FP16x16) -> bool { } } +fn and(a: FP16x16, b: FP16x16) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero || b == zero { + return false; + } else { + return true; + } +} + // Tests -------------------------------------------------------------------------------------------------------------- #[test] diff --git a/src/numbers/fixed_point/implementations/fp32x32/comp.cairo b/src/numbers/fixed_point/implementations/fp32x32/comp.cairo index 60cae2ab3..de2b7d9a8 100644 --- a/src/numbers/fixed_point/implementations/fp32x32/comp.cairo +++ b/src/numbers/fixed_point/implementations/fp32x32/comp.cairo @@ -17,3 +17,12 @@ fn or(a: FP32x32, b: FP32x32) -> bool { return true; } } + +fn and(a: FP32x32, b: FP32x32) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero || b == zero { + return false; + } else { + return true; + } +} diff --git a/src/numbers/fixed_point/implementations/fp64x64/comp.cairo b/src/numbers/fixed_point/implementations/fp64x64/comp.cairo index d26a3180f..9cae44352 100644 --- a/src/numbers/fixed_point/implementations/fp64x64/comp.cairo +++ b/src/numbers/fixed_point/implementations/fp64x64/comp.cairo @@ -17,3 +17,12 @@ fn or(a: FP64x64, b: FP64x64) -> bool { return true; } } + +fn and(a: FP64x64, b: FP64x64) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero || b == zero { + return false; + } else { + return true; + } +} diff --git a/src/numbers/fixed_point/implementations/fp8x23/math/comp.cairo b/src/numbers/fixed_point/implementations/fp8x23/math/comp.cairo index 53b1f2e1d..c2158965a 100644 --- a/src/numbers/fixed_point/implementations/fp8x23/math/comp.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23/math/comp.cairo @@ -35,6 +35,15 @@ fn or(a: FP8x23, b: FP8x23) -> bool { } } +fn and(a: FP8x23, b: FP8x23) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero || b == zero { + return false; + } else { + return true; + } +} + // Tests -------------------------------------------------------------------------------------------------------------- #[test] diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 535ed699c..a161087db 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -2669,6 +2669,7 @@ trait TensorTrait { /// ``` /// fn sign(self: @Tensor) -> Tensor; + fn and(self: @Tensor, other: @Tensor) -> Tensor; } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 
ac803904e..c81b08bfc 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -229,6 +229,10 @@ impl FP16x16Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 19557eba1..45329dd74 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -230,6 +230,10 @@ impl FP32x32Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 686d319b1..2a1e0a88f 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -230,6 +230,10 @@ impl FP64x64Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index f46e96fb5..0ea96041c 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -229,6 +229,10 @@ impl FP8x23Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index dd6ae6f59..66d13c111 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -228,6 +228,10 @@ impl I32Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index a4518cefb..d1dfa1cb0 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -227,6 +227,10 @@ impl I8Tensor of TensorTrait { fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { core::clip(self, min, max) } + + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } } /// Implements addition for `Tensor` using the `Add` trait. 
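For context on the tensor-level entry point these implementations register: `TensorTrait::and` delegates to `math::and::and` (the implementation appears further below), which broadcasts its two inputs and, element by element, emits `1` when both values are non-zero and `0` otherwise, producing a `Tensor<usize>`. A small usage sketch under those semantics; the values and function name are illustrative, not part of the patch:

```rust
use array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};

fn and_example() -> Tensor<usize> {
    // element-wise: [0 && 1, 1 && 1, 2 && 0, 0 && 0]
    let a = TensorTrait::<u32>::new(shape: array![2, 2].span(), data: array![0, 1, 2, 0].span());
    let b = TensorTrait::<u32>::new(shape: array![2, 2].span(), data: array![1, 1, 0, 0].span());

    // zero is the only falsy value, so this returns [0, 1, 0, 0]
    a.and(@b)
}
```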
diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo
index 4d26a8634..e073086b1 100644
--- a/src/operators/tensor/implementations/tensor_u32.cairo
+++ b/src/operators/tensor/implementations/tensor_u32.cairo
@@ -221,6 +221,10 @@ impl U32Tensor of TensorTrait<u32> {
     fn clip(self: @Tensor<u32>, min: Option<u32>, max: Option<u32>) -> Tensor<u32> {
         core::clip(self, min, max)
     }
+
+    fn and(self: @Tensor<u32>, other: @Tensor<u32>) -> Tensor<usize> {
+        math::and::and(self, other)
+    }
 }
 
 /// Implements addition for `Tensor<u32>` using the `Add` trait.
diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo
index 5ac834280..cc101b696 100644
--- a/src/operators/tensor/math.cairo
+++ b/src/operators/tensor/math.cairo
@@ -32,3 +32,4 @@ mod sqrt;
 mod concat;
 mod gather;
 mod sign;
+mod and;
diff --git a/src/operators/tensor/math/and.cairo b/src/operators/tensor/math/and.cairo
new file mode 100644
index 000000000..dba0b0b1d
--- /dev/null
+++ b/src/operators/tensor/math/and.cairo
@@ -0,0 +1,46 @@
+use array::ArrayTrait;
+use option::OptionTrait;
+use array::SpanTrait;
+
+use orion::numbers::NumberTrait;
+use orion::operators::tensor::core::{Tensor, TensorTrait, unravel_index};
+use orion::operators::tensor::helpers::{
+    broadcast_shape, broadcast_index_mapping, len_from_shape, check_compatibility
+};
+
+fn and<
+    T,
+    MAG,
+    impl TNumber: NumberTrait<T, MAG>,
+    impl UsizeFTensor: TensorTrait<usize>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>
+>(
+    y: @Tensor<T>, z: @Tensor<T>
+) -> Tensor<usize> {
+    let broadcasted_shape = broadcast_shape(*y.shape, *z.shape);
+    let mut result: Array<usize> = ArrayTrait::new();
+
+    let num_elements = len_from_shape(broadcasted_shape);
+
+    let mut n: usize = 0;
+    loop {
+        let indices_broadcasted = unravel_index(n, broadcasted_shape);
+
+        let indices_self = broadcast_index_mapping(*y.shape, indices_broadcasted);
+        let indices_other = broadcast_index_mapping(*z.shape, indices_broadcasted);
+
+        if NumberTrait::and(*(*y.data)[indices_self], *(*z.data)[indices_other]) {
+            result.append(1);
+        } else {
+            result.append(0);
+        }
+
+        n += 1;
+        if n == num_elements {
+            break ();
+        };
+    };
+
+    return TensorTrait::new(broadcasted_shape, result.span());
+}
From bbf9833c437b83c7b4dd1fd4dde194717321628a Mon Sep 17 00:00:00 2001
From: Daniel Voronov
Date: Tue, 24 Oct 2023 15:16:58 +0200
Subject: [PATCH 68/78] Add nodegen tests

---
 nodegen/node/and.py | 168 ++++++++++++++++++
 tests/src/nodes.cairo | 10 ++
 tests/src/nodes/and_fp16x16.cairo | 22 +++
 tests/src/nodes/and_fp16x16/input_0.cairo | 42 +++++
 tests/src/nodes/and_fp16x16/input_1.cairo | 42 +++++
 tests/src/nodes/and_fp16x16/output_0.cairo | 40 +++++
 tests/src/nodes/and_fp16x16_broadcast.cairo | 22 +++
 .../nodes/and_fp16x16_broadcast/input_0.cairo | 18 ++
 .../nodes/and_fp16x16_broadcast/input_1.cairo | 16 ++
 .../and_fp16x16_broadcast/output_0.cairo | 16 ++
 tests/src/nodes/and_fp8x23.cairo | 22 +++
 tests/src/nodes/and_fp8x23/input_0.cairo | 42 +++++
 tests/src/nodes/and_fp8x23/input_1.cairo | 42 +++++
 tests/src/nodes/and_fp8x23/output_0.cairo | 40 +++++
 tests/src/nodes/and_fp8x23_broadcast.cairo | 22 +++
 .../nodes/and_fp8x23_broadcast/input_0.cairo | 18 ++
 .../nodes/and_fp8x23_broadcast/input_1.cairo | 16 ++
 .../nodes/and_fp8x23_broadcast/output_0.cairo | 16 ++
 tests/src/nodes/and_i32.cairo | 22 +++
 tests/src/nodes/and_i32/input_0.cairo | 41 +++++
 tests/src/nodes/and_i32/input_1.cairo | 41 +++++
 tests/src/nodes/and_i32/output_0.cairo | 40 +++++
 tests/src/nodes/and_i32_broadcast.cairo | 22 +++
.../src/nodes/and_i32_broadcast/input_0.cairo | 17 ++ .../src/nodes/and_i32_broadcast/input_1.cairo | 15 ++ .../nodes/and_i32_broadcast/output_0.cairo | 16 ++ tests/src/nodes/and_i8.cairo | 22 +++ tests/src/nodes/and_i8/input_0.cairo | 41 +++++ tests/src/nodes/and_i8/input_1.cairo | 41 +++++ tests/src/nodes/and_i8/output_0.cairo | 40 +++++ tests/src/nodes/and_i8_broadcast.cairo | 22 +++ .../src/nodes/and_i8_broadcast/input_0.cairo | 17 ++ .../src/nodes/and_i8_broadcast/input_1.cairo | 15 ++ .../src/nodes/and_i8_broadcast/output_0.cairo | 16 ++ tests/src/nodes/and_u32.cairo | 22 +++ tests/src/nodes/and_u32/input_0.cairo | 40 +++++ tests/src/nodes/and_u32/input_1.cairo | 40 +++++ tests/src/nodes/and_u32/output_0.cairo | 40 +++++ tests/src/nodes/and_u32_broadcast.cairo | 22 +++ .../src/nodes/and_u32_broadcast/input_0.cairo | 16 ++ .../src/nodes/and_u32_broadcast/input_1.cairo | 14 ++ .../nodes/and_u32_broadcast/output_0.cairo | 16 ++ 42 files changed, 1252 insertions(+) create mode 100644 nodegen/node/and.py create mode 100644 tests/src/nodes/and_fp16x16.cairo create mode 100644 tests/src/nodes/and_fp16x16/input_0.cairo create mode 100644 tests/src/nodes/and_fp16x16/input_1.cairo create mode 100644 tests/src/nodes/and_fp16x16/output_0.cairo create mode 100644 tests/src/nodes/and_fp16x16_broadcast.cairo create mode 100644 tests/src/nodes/and_fp16x16_broadcast/input_0.cairo create mode 100644 tests/src/nodes/and_fp16x16_broadcast/input_1.cairo create mode 100644 tests/src/nodes/and_fp16x16_broadcast/output_0.cairo create mode 100644 tests/src/nodes/and_fp8x23.cairo create mode 100644 tests/src/nodes/and_fp8x23/input_0.cairo create mode 100644 tests/src/nodes/and_fp8x23/input_1.cairo create mode 100644 tests/src/nodes/and_fp8x23/output_0.cairo create mode 100644 tests/src/nodes/and_fp8x23_broadcast.cairo create mode 100644 tests/src/nodes/and_fp8x23_broadcast/input_0.cairo create mode 100644 tests/src/nodes/and_fp8x23_broadcast/input_1.cairo create mode 100644 tests/src/nodes/and_fp8x23_broadcast/output_0.cairo create mode 100644 tests/src/nodes/and_i32.cairo create mode 100644 tests/src/nodes/and_i32/input_0.cairo create mode 100644 tests/src/nodes/and_i32/input_1.cairo create mode 100644 tests/src/nodes/and_i32/output_0.cairo create mode 100644 tests/src/nodes/and_i32_broadcast.cairo create mode 100644 tests/src/nodes/and_i32_broadcast/input_0.cairo create mode 100644 tests/src/nodes/and_i32_broadcast/input_1.cairo create mode 100644 tests/src/nodes/and_i32_broadcast/output_0.cairo create mode 100644 tests/src/nodes/and_i8.cairo create mode 100644 tests/src/nodes/and_i8/input_0.cairo create mode 100644 tests/src/nodes/and_i8/input_1.cairo create mode 100644 tests/src/nodes/and_i8/output_0.cairo create mode 100644 tests/src/nodes/and_i8_broadcast.cairo create mode 100644 tests/src/nodes/and_i8_broadcast/input_0.cairo create mode 100644 tests/src/nodes/and_i8_broadcast/input_1.cairo create mode 100644 tests/src/nodes/and_i8_broadcast/output_0.cairo create mode 100644 tests/src/nodes/and_u32.cairo create mode 100644 tests/src/nodes/and_u32/input_0.cairo create mode 100644 tests/src/nodes/and_u32/input_1.cairo create mode 100644 tests/src/nodes/and_u32/output_0.cairo create mode 100644 tests/src/nodes/and_u32_broadcast.cairo create mode 100644 tests/src/nodes/and_u32_broadcast/input_0.cairo create mode 100644 tests/src/nodes/and_u32_broadcast/input_1.cairo create mode 100644 tests/src/nodes/and_u32_broadcast/output_0.cairo diff --git a/nodegen/node/and.py b/nodegen/node/and.py new file mode 100644 
index 000000000..0a1741a21 --- /dev/null +++ b/nodegen/node/and.py @@ -0,0 +1,168 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl + + +class And(RunAll): + @staticmethod + def and_u32(): + def default(): + x = np.random.randint(0, 6, (3, 3, 3)).astype(np.uint32) + y = np.random.randint(0, 6, (3, 3, 3)).astype(np.uint32) + z = np.logical_and(x, y) + + x = Tensor(Dtype.U32, x.shape, x.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_u32" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + def broadcast(): + x = np.random.randint(0, 6, (2, 2)).astype(np.uint32) + y = np.random.randint(0, 6, (1, 2)).astype(np.uint32) + z = np.logical_and(x, y) + + x = Tensor(Dtype.U32, x.shape, x.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_u32_broadcast" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + default() + broadcast() + + @staticmethod + def and_i32(): + def default(): + x = np.random.randint(-3, 3, (3, 3, 3)).astype(np.int32) + y = np.random.randint(-3, 3, (3, 3, 3)).astype(np.int32) + z = np.logical_and(x, y) + + x = Tensor(Dtype.I32, x.shape, x.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_i32" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + def broadcast(): + x = np.random.randint(-3, 3, (2, 2)).astype(np.int32) + y = np.random.randint(-3, 3, (1, 2)).astype(np.int32) + z = np.logical_and(x, y) + + x = Tensor(Dtype.I32, x.shape, x.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_i32_broadcast" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + default() + broadcast() + + @staticmethod + def and_i8(): + def default(): + x = np.random.randint(-3, 3, (3, 3, 3)).astype(np.int8) + y = np.random.randint(-3, 3, (3, 3, 3)).astype(np.int8) + z = np.logical_and(x, y) + + x = Tensor(Dtype.I8, x.shape, x.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_i8" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + def broadcast(): + x = np.random.randint(-3, 3, (2, 2)).astype(np.int8) + y = np.random.randint(-3, 3, (1, 2)).astype(np.int8) + z = np.logical_and(x, y) + + x = Tensor(Dtype.I8, x.shape, x.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_i8_broadcast" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + default() + broadcast() + + @staticmethod + def and_fp8x23(): + def default(): + x = np.random.randint(-3, 3, (3, 3, 3)).astype(np.float64) + y = np.random.randint(-3, 3, (3, 3, 3)).astype(np.float64) + z = np.logical_and(x, y) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_fp8x23" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + def broadcast(): + x = np.random.randint(-3, 3, (2, 2)).astype(np.float64) + y = np.random.randint(-3, 3, (1, 2)).astype(np.float64) + z = 
np.logical_and(x, y) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_fp8x23_broadcast" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + default() + broadcast() + + @staticmethod + def and_fp16x16(): + def default(): + x = np.random.randint(-3, 3, (3, 3, 3)).astype(np.float64) + y = np.random.randint(-3, 3, (3, 3, 3)).astype(np.float64) + z = np.logical_and(x, y) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp( + x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_fp16x16" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + def broadcast(): + x = np.random.randint(-3, 3, (2, 2)).astype(np.float64) + y = np.random.randint(-3, 3, (1, 2)).astype(np.float64) + z = np.logical_and(x, y) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp( + x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + + name = "and_fp16x16_broadcast" + make_node([x, y], [z], name) + make_test([x, y], z, "input_0.and(@input_1)", name) + + default() + broadcast() diff --git a/tests/src/nodes.cairo b/tests/src/nodes.cairo index 398edc4bf..a0c1c5781 100644 --- a/tests/src/nodes.cairo +++ b/tests/src/nodes.cairo @@ -433,3 +433,13 @@ mod clip_i8_2d; mod clip_i8_3d; mod clip_u32_2d; mod clip_u32_3d; +mod and_fp16x16; +mod and_fp16x16_broadcast; +mod and_fp8x23; +mod and_fp8x23_broadcast; +mod and_i32; +mod and_i32_broadcast; +mod and_i8; +mod and_i8_broadcast; +mod and_u32; +mod and_u32_broadcast; diff --git a/tests/src/nodes/and_fp16x16.cairo b/tests/src/nodes/and_fp16x16.cairo new file mode 100644 index 000000000..87b32da17 --- /dev/null +++ b/tests/src/nodes/and_fp16x16.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_fp16x16() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp16x16/input_0.cairo b/tests/src/nodes/and_fp16x16/input_0.cairo new file mode 100644 index 000000000..05fb9a94d --- /dev/null +++ b/tests/src/nodes/and_fp16x16/input_0.cairo @@ -0,0 +1,42 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: 
false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp16x16/input_1.cairo b/tests/src/nodes/and_fp16x16/input_1.cairo new file mode 100644 index 000000000..991a8f688 --- /dev/null +++ b/tests/src/nodes/and_fp16x16/input_1.cairo @@ -0,0 +1,42 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 131072, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp16x16/output_0.cairo b/tests/src/nodes/and_fp16x16/output_0.cairo new file mode 100644 index 000000000..75d7aed70 --- /dev/null +++ b/tests/src/nodes/and_fp16x16/output_0.cairo @@ -0,0 +1,40 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + 
shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + data.append(0); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp16x16_broadcast.cairo b/tests/src/nodes/and_fp16x16_broadcast.cairo new file mode 100644 index 000000000..3abd8c250 --- /dev/null +++ b/tests/src/nodes/and_fp16x16_broadcast.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_fp16x16_broadcast() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp16x16_broadcast/input_0.cairo b/tests/src/nodes/and_fp16x16_broadcast/input_0.cairo new file mode 100644 index 000000000..a461be36c --- /dev/null +++ b/tests/src/nodes/and_fp16x16_broadcast/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp16x16_broadcast/input_1.cairo b/tests/src/nodes/and_fp16x16_broadcast/input_1.cairo new file mode 100644 index 000000000..0bc098ab8 --- /dev/null +++ b/tests/src/nodes/and_fp16x16_broadcast/input_1.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp16x16_broadcast/output_0.cairo b/tests/src/nodes/and_fp16x16_broadcast/output_0.cairo new file mode 100644 index 000000000..2094003c4 --- /dev/null +++ b/tests/src/nodes/and_fp16x16_broadcast/output_0.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + 
shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23.cairo b/tests/src/nodes/and_fp8x23.cairo new file mode 100644 index 000000000..17e9bdad0 --- /dev/null +++ b/tests/src/nodes/and_fp8x23.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_fp8x23() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23/input_0.cairo b/tests/src/nodes/and_fp8x23/input_0.cairo new file mode 100644 index 000000000..d43898dd1 --- /dev/null +++ b/tests/src/nodes/and_fp8x23/input_0.cairo @@ -0,0 +1,42 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 25165824, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23/input_1.cairo b/tests/src/nodes/and_fp8x23/input_1.cairo new file mode 100644 index 000000000..d2d59fda5 --- /dev/null +++ b/tests/src/nodes/and_fp8x23/input_1.cairo @@ -0,0 +1,42 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = 
ArrayTrait::new(); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: true }); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23/output_0.cairo b/tests/src/nodes/and_fp8x23/output_0.cairo new file mode 100644 index 000000000..e2e18825c --- /dev/null +++ b/tests/src/nodes/and_fp8x23/output_0.cairo @@ -0,0 +1,40 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23_broadcast.cairo b/tests/src/nodes/and_fp8x23_broadcast.cairo new file mode 100644 index 000000000..d54fd4afe --- /dev/null +++ b/tests/src/nodes/and_fp8x23_broadcast.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_fp8x23_broadcast() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23_broadcast/input_0.cairo b/tests/src/nodes/and_fp8x23_broadcast/input_0.cairo new file mode 100644 index 000000000..4bcaf93cd --- /dev/null +++ 
b/tests/src/nodes/and_fp8x23_broadcast/input_0.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23_broadcast/input_1.cairo b/tests/src/nodes/and_fp8x23_broadcast/input_1.cairo new file mode 100644 index 000000000..538d00e70 --- /dev/null +++ b/tests/src/nodes/and_fp8x23_broadcast/input_1.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP8x23; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_fp8x23_broadcast/output_0.cairo b/tests/src/nodes/and_fp8x23_broadcast/output_0.cairo new file mode 100644 index 000000000..e7b1da4b1 --- /dev/null +++ b/tests/src/nodes/and_fp8x23_broadcast/output_0.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32.cairo b/tests/src/nodes/and_i32.cairo new file mode 100644 index 000000000..6c85c1433 --- /dev/null +++ b/tests/src/nodes/and_i32.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_i32() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32/input_0.cairo b/tests/src/nodes/and_i32/input_0.cairo new file mode 100644 index 000000000..246a30bdf --- /dev/null +++ b/tests/src/nodes/and_i32/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 3, sign: true }); + data.append(i32 
{ mag: 2, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 3, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 3, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: true }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32/input_1.cairo b/tests/src/nodes/and_i32/input_1.cairo new file mode 100644 index 000000000..2798d8bf5 --- /dev/null +++ b/tests/src/nodes/and_i32/input_1.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 3, sign: true }); + data.append(i32 { mag: 3, sign: true }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: true }); + data.append(i32 { mag: 1, sign: true }); + data.append(i32 { mag: 0, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32/output_0.cairo b/tests/src/nodes/and_i32/output_0.cairo new file mode 100644 index 000000000..5cfd192c9 --- /dev/null +++ b/tests/src/nodes/and_i32/output_0.cairo @@ -0,0 +1,40 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + 
data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32_broadcast.cairo b/tests/src/nodes/and_i32_broadcast.cairo new file mode 100644 index 000000000..33f7abd02 --- /dev/null +++ b/tests/src/nodes/and_i32_broadcast.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_i32_broadcast() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32_broadcast/input_0.cairo b/tests/src/nodes/and_i32_broadcast/input_0.cairo new file mode 100644 index 000000000..8893b6d3b --- /dev/null +++ b/tests/src/nodes/and_i32_broadcast/input_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 0, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32_broadcast/input_1.cairo b/tests/src/nodes/and_i32_broadcast/input_1.cairo new file mode 100644 index 000000000..0e3bae4c3 --- /dev/null +++ b/tests/src/nodes/and_i32_broadcast/input_1.cairo @@ -0,0 +1,15 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i32_broadcast/output_0.cairo b/tests/src/nodes/and_i32_broadcast/output_0.cairo new file mode 100644 index 000000000..386b678ef --- /dev/null +++ b/tests/src/nodes/and_i32_broadcast/output_0.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8.cairo b/tests/src/nodes/and_i8.cairo new file mode 100644 index 000000000..7d2e67c2d --- /dev/null +++ b/tests/src/nodes/and_i8.cairo @@ 
-0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_i8() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8/input_0.cairo b/tests/src/nodes/and_i8/input_0.cairo new file mode 100644 index 000000000..4271563b5 --- /dev/null +++ b/tests/src/nodes/and_i8/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 3, sign: true }); + data.append(i8 { mag: 3, sign: true }); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 3, sign: true }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8/input_1.cairo b/tests/src/nodes/and_i8/input_1.cairo new file mode 100644 index 000000000..ba15f89b4 --- /dev/null +++ b/tests/src/nodes/and_i8/input_1.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 3, sign: true }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 2, sign: false 
}); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 2, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8/output_0.cairo b/tests/src/nodes/and_i8/output_0.cairo new file mode 100644 index 000000000..3e2f5f3f3 --- /dev/null +++ b/tests/src/nodes/and_i8/output_0.cairo @@ -0,0 +1,40 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(0); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(0); + data.append(1); + data.append(0); + data.append(0); + data.append(1); + data.append(0); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8_broadcast.cairo b/tests/src/nodes/and_i8_broadcast.cairo new file mode 100644 index 000000000..12d5f13e7 --- /dev/null +++ b/tests/src/nodes/and_i8_broadcast.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_i8_broadcast() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8_broadcast/input_0.cairo b/tests/src/nodes/and_i8_broadcast/input_0.cairo new file mode 100644 index 000000000..6de8ea3a8 --- /dev/null +++ b/tests/src/nodes/and_i8_broadcast/input_0.cairo @@ -0,0 +1,17 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 1, sign: true }); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 2, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8_broadcast/input_1.cairo b/tests/src/nodes/and_i8_broadcast/input_1.cairo new file mode 100644 index 000000000..1bc1a2f2b --- /dev/null +++ b/tests/src/nodes/and_i8_broadcast/input_1.cairo @@ -0,0 +1,15 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_1() -> Tensor { 
+ let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_i8_broadcast/output_0.cairo b/tests/src/nodes/and_i8_broadcast/output_0.cairo new file mode 100644 index 000000000..328774a25 --- /dev/null +++ b/tests/src/nodes/and_i8_broadcast/output_0.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(0); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32.cairo b/tests/src/nodes/and_u32.cairo new file mode 100644 index 000000000..477f59981 --- /dev/null +++ b/tests/src/nodes/and_u32.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_u32() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32/input_0.cairo b/tests/src/nodes/and_u32/input_0.cairo new file mode 100644 index 000000000..a1d822fcd --- /dev/null +++ b/tests/src/nodes/and_u32/input_0.cairo @@ -0,0 +1,40 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(5); + data.append(2); + data.append(5); + data.append(2); + data.append(4); + data.append(3); + data.append(1); + data.append(0); + data.append(2); + data.append(0); + data.append(4); + data.append(3); + data.append(0); + data.append(2); + data.append(5); + data.append(1); + data.append(3); + data.append(5); + data.append(0); + data.append(5); + data.append(1); + data.append(5); + data.append(0); + data.append(2); + data.append(0); + data.append(0); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32/input_1.cairo b/tests/src/nodes/and_u32/input_1.cairo new file mode 100644 index 000000000..37ce51b01 --- /dev/null +++ b/tests/src/nodes/and_u32/input_1.cairo @@ -0,0 +1,40 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(0); + data.append(3); + data.append(0); + data.append(4); + data.append(5); + data.append(2); + data.append(1); + data.append(3); + data.append(5); + data.append(0); + data.append(0); + data.append(0); + data.append(5); + data.append(2); + data.append(1); + data.append(2); + data.append(5); + 
data.append(1); + data.append(4); + data.append(1); + data.append(4); + data.append(4); + data.append(1); + data.append(1); + data.append(3); + data.append(4); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32/output_0.cairo b/tests/src/nodes/and_u32/output_0.cairo new file mode 100644 index 000000000..0eced5867 --- /dev/null +++ b/tests/src/nodes/and_u32/output_0.cairo @@ -0,0 +1,40 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(0); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(0); + data.append(0); + data.append(0); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + data.append(0); + data.append(0); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32_broadcast.cairo b/tests/src/nodes/and_u32_broadcast.cairo new file mode 100644 index 000000000..a98e806d9 --- /dev/null +++ b/tests/src/nodes/and_u32_broadcast.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_and_u32_broadcast() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.and(@input_1); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32_broadcast/input_0.cairo b/tests/src/nodes/and_u32_broadcast/input_0.cairo new file mode 100644 index 000000000..afd5ea46c --- /dev/null +++ b/tests/src/nodes/and_u32_broadcast/input_0.cairo @@ -0,0 +1,16 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(2); + data.append(5); + data.append(4); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32_broadcast/input_1.cairo b/tests/src/nodes/and_u32_broadcast/input_1.cairo new file mode 100644 index 000000000..ea3e50244 --- /dev/null +++ b/tests/src/nodes/and_u32_broadcast/input_1.cairo @@ -0,0 +1,14 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(5); + data.append(5); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/and_u32_broadcast/output_0.cairo b/tests/src/nodes/and_u32_broadcast/output_0.cairo new file mode 100644 index 000000000..e7b1da4b1 --- /dev/null +++ 
b/tests/src/nodes/and_u32_broadcast/output_0.cairo
@@ -0,0 +1,16 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::U32Tensor;
+
+fn output_0() -> Tensor<u32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(1);
+    data.append(1);
+    data.append(1);
+    data.append(1);
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
From 3081f293cf57e41cba32504944716fbf5afced94 Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Tue, 24 Oct 2023 16:28:43 +0300
Subject: [PATCH 69/78] implement gemm function

---
 src/operators/nn/functional.cairo | 1 +
 src/operators/nn/functional/gemm.cairo | 57 +++++++++
 src/operators/tensor/math/arithmetic.cairo | 129 +++++++++++++++
 3 files changed, 187 insertions(+)
 create mode 100644 src/operators/nn/functional/gemm.cairo

diff --git a/src/operators/nn/functional.cairo b/src/operators/nn/functional.cairo
index 704948fec..d130433bc 100644
--- a/src/operators/nn/functional.cairo
+++ b/src/operators/nn/functional.cairo
@@ -8,3 +8,4 @@ mod linear;
 mod logsoftmax;
 mod thresholded_relu;
 mod hard_sigmoid;
+mod gemm;
\ No newline at end of file
diff --git a/src/operators/nn/functional/gemm.cairo b/src/operators/nn/functional/gemm.cairo
new file mode 100644
index 000000000..8d4498cc6
--- /dev/null
+++ b/src/operators/nn/functional/gemm.cairo
@@ -0,0 +1,57 @@
+use array::SpanTrait;
+
+use orion::numbers::NumberTrait;
+use orion::operators::tensor::{core::{Tensor, TensorTrait}, math::arithmetic::mul_by_val};
+
+/// Cf: NNTrait::gemm docstring
+fn gemm<
+    T,
+    MAG,
+    impl TTensor: TensorTrait<T>,
+    impl TAddTensor: Add<Tensor<T>>,
+    impl TNumberTrait: NumberTrait<T, MAG>,
+    impl TPartialEq: PartialEq<T>,
+    impl TMul: Mul<T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>
+>(
+    A: Tensor<T>,
+    B: Tensor<T>,
+    C: Option<Tensor<T>>,
+    alpha: Option<T>,
+    beta: Option<T>,
+    transA: bool,
+    transB: bool
+) -> Tensor<T> {
+    let alpha: T = if alpha.is_some() {
+        alpha.unwrap()
+    } else {
+        NumberTrait::one()
+    };
+
+    let beta: T = if beta.is_some() {
+        beta.unwrap()
+    } else {
+        NumberTrait::one()
+    };
+
+    // Rebind so the transposed tensors outlive the `if` blocks.
+    let mut A = A;
+    let mut B = B;
+    if transA {
+        A = A.transpose(array![1, 0].span());
+    }
+
+    if transB {
+        B = B.transpose(array![1, 0].span());
+    }
+
+    match C {
+        Option::Some(c) => {
+            return mul_by_val(@A.matmul(@B), alpha) + mul_by_val(@c, beta);
+        },
+        Option::None(_) => {
+            return mul_by_val(@A.matmul(@B), alpha);
+        }
+    }
+}
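The function above follows the ONNX Gemm contract: Y = alpha * op(A) * op(B) + beta * C, where op is an optional transpose and alpha and beta default to one when absent. The `let mut A = A;` rebinding matters: a plain `let` inside the `if` block would only shadow until the block ends, silently discarding the transpose before `matmul` runs. A NumPy sketch of the intended behaviour (illustration only, not the library API):

    # Reference behaviour for the gemm above, via NumPy. Assumes 2-D inputs;
    # alpha/beta default to 1, matching the Option handling in the Cairo code.
    import numpy as np

    def gemm_ref(A, B, C=None, alpha=1.0, beta=1.0, transA=False, transB=False):
        A = A.T if transA else A
        B = B.T if transB else B
        Y = alpha * (A @ B)
        return Y + beta * C if C is not None else Y

    A = np.array([[1.0, 2.0], [3.0, 4.0]])
    B = np.array([[5.0, 6.0], [7.0, 8.0]])
    C = np.ones((2, 2))
    assert np.allclose(gemm_ref(A, B, C, alpha=2.0), 2.0 * (A @ B) + C)
    assert np.allclose(gemm_ref(A, B, transA=True), A.T @ B)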
diff --git a/src/operators/tensor/math/arithmetic.cairo b/src/operators/tensor/math/arithmetic.cairo
index 06879a4af..6c1df4c28 100644
--- a/src/operators/tensor/math/arithmetic.cairo
+++ b/src/operators/tensor/math/arithmetic.cairo
@@ -5,6 +5,7 @@ use array::SpanTrait;
 
 use orion::operators::tensor::helpers::broadcast_shape;
 
+use orion::numbers::NumberTrait;
 use orion::operators::tensor::core::{Tensor, TensorTrait, unravel_index,};
 use orion::operators::tensor::helpers::{broadcast_index_mapping, len_from_shape,};
 use orion::utils::saturate;
@@ -37,6 +38,38 @@ fn add<
     return TensorTrait::<T>::new(broadcasted_shape, result.span());
 }
 
+fn add_by_val<
+    T,
+    MAG,
+    impl TTensor: TensorTrait<T>,
+    impl TAdd: Add<T>,
+    impl TNumber: NumberTrait<T, MAG>,
+    impl TPartialEq: PartialEq<T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>
+>(
+    self: @Tensor<T>, val: T
+) -> Tensor<T> {
+    if val == NumberTrait::zero() {
+        return *self;
+    }
+
+    let mut input_data = *self.data;
+    let mut data_result = ArrayTrait::<T>::new();
+    loop {
+        match input_data.pop_front() {
+            Option::Some(ele) => {
+                data_result.append(*ele + val);
+            },
+            Option::None(_) => {
+                break;
+            }
+        };
+    };
+
+    return TensorTrait::<T>::new(*self.shape, data_result.span());
+}
+
 fn saturated_add<
     T,
     Q,
@@ -111,6 +144,38 @@ fn sub<
     return TensorTrait::<T>::new(broadcasted_shape, result.span());
 }
 
+fn sub_by_val<
+    T,
+    MAG,
+    impl TTensor: TensorTrait<T>,
+    impl TSub: Sub<T>,
+    impl TNumber: NumberTrait<T, MAG>,
+    impl TPartialEq: PartialEq<T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>
+>(
+    self: @Tensor<T>, val: T
+) -> Tensor<T> {
+    if val == NumberTrait::zero() {
+        return *self;
+    }
+
+    let mut input_data = *self.data;
+    let mut data_result = ArrayTrait::<T>::new();
+    loop {
+        match input_data.pop_front() {
+            Option::Some(ele) => {
+                data_result.append(*ele - val);
+            },
+            Option::None(_) => {
+                break;
+            }
+        };
+    };
+
+    return TensorTrait::<T>::new(*self.shape, data_result.span());
+}
+
 fn saturated_sub<
     T,
     Q,
@@ -185,6 +250,38 @@ fn mul<
     return TensorTrait::<T>::new(broadcasted_shape, result.span());
 }
 
+fn mul_by_val<
+    T,
+    MAG,
+    impl TTensor: TensorTrait<T>,
+    impl TMul: Mul<T>,
+    impl TNumber: NumberTrait<T, MAG>,
+    impl TPartialEq: PartialEq<T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>
+>(
+    self: @Tensor<T>, val: T
+) -> Tensor<T> {
+    if val == NumberTrait::one() {
+        return *self;
+    }
+
+    let mut input_data = *self.data;
+    let mut data_result = ArrayTrait::<T>::new();
+    loop {
+        match input_data.pop_front() {
+            Option::Some(ele) => {
+                data_result.append(*ele * val);
+            },
+            Option::None(_) => {
+                break;
+            }
+        };
+    };
+
+    return TensorTrait::<T>::new(*self.shape, data_result.span());
+}
+
 fn saturated_mul<
     T,
     Q,
@@ -259,6 +356,38 @@ fn div<
     return TensorTrait::<T>::new(broadcasted_shape, result.span());
 }
 
+fn div_by_val<
+    T,
+    MAG,
+    impl TTensor: TensorTrait<T>,
+    impl TDiv: Div<T>,
+    impl TNumber: NumberTrait<T, MAG>,
+    impl TPartialEq: PartialEq<T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>
+>(
+    self: @Tensor<T>, val: T
+) -> Tensor<T> {
+    if val == NumberTrait::one() {
+        return *self;
+    }
+
+    let mut input_data = *self.data;
+    let mut data_result = ArrayTrait::<T>::new();
+    loop {
+        match input_data.pop_front() {
+            Option::Some(ele) => {
+                data_result.append(*ele / val);
+            },
+            Option::None(_) => {
+                break;
+            }
+        };
+    };
+
+    return TensorTrait::<T>::new(*self.shape, data_result.span());
+}
+
 fn saturated_div<
     T,
     Q,
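These `*_by_val` helpers apply a single scalar across the flattened data, short-circuiting on the identity element: adding zero or scaling by one returns `self` untouched. A compact Python equivalent (illustration only):

    # Python equivalent of add_by_val / sub_by_val / mul_by_val / div_by_val:
    # map one scalar over the flattened data, skipping work on the identity.
    import operator

    def op_by_val(data, val, op, identity):
        if val == identity:
            return list(data)  # mirrors `return *self;`
        return [op(ele, val) for ele in data]

    assert op_by_val([1, 2, 3], 0, operator.add, 0) == [1, 2, 3]  # no-op
    assert op_by_val([1, 2, 3], 2, operator.mul, 1) == [2, 4, 6]

`gemm` in the previous file relies on `mul_by_val` for the alpha and beta scalings.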
a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -56,4 +56,16 @@ impl FP16x16NN of NNTrait { fn hard_sigmoid(tensor: @Tensor, alpha: @FP16x16, beta: @FP16x16) -> Tensor { functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) } + + fn gemm( + A: Tensor, + B: Tensor, + C: Option>, + alpha: Option, + beta: Option, + transA: bool, + transB: bool + ) -> Tensor { + functional::gemm::gemm(A, B, C, alpha, beta, transA, transB) + } } diff --git a/src/operators/nn/implementations/nn_fp32x32.cairo b/src/operators/nn/implementations/nn_fp32x32.cairo index 67a6d86db..832c26dcf 100644 --- a/src/operators/nn/implementations/nn_fp32x32.cairo +++ b/src/operators/nn/implementations/nn_fp32x32.cairo @@ -50,4 +50,16 @@ impl FP32x32NN of NNTrait { fn hard_sigmoid(tensor: @Tensor, alpha: @FP32x32, beta: @FP32x32) -> Tensor { functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) } + + fn gemm( + A: Tensor, + B: Tensor, + C: Option>, + alpha: Option, + beta: Option, + transA: bool, + transB: bool + ) -> Tensor { + functional::gemm::gemm(A, B, C, alpha, beta, transA, transB) + } } diff --git a/src/operators/nn/implementations/nn_fp64x64.cairo b/src/operators/nn/implementations/nn_fp64x64.cairo index 86facdc1a..0a674fe47 100644 --- a/src/operators/nn/implementations/nn_fp64x64.cairo +++ b/src/operators/nn/implementations/nn_fp64x64.cairo @@ -50,4 +50,16 @@ impl FP64x64NN of NNTrait { fn hard_sigmoid(tensor: @Tensor, alpha: @FP64x64, beta: @FP64x64) -> Tensor { functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) } + + fn gemm( + A: Tensor, + B: Tensor, + C: Option>, + alpha: Option, + beta: Option, + transA: bool, + transB: bool + ) -> Tensor { + functional::gemm::gemm(A, B, C, alpha, beta, transA, transB) + } } diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index 9530eefc4..d246bf2cc 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -54,4 +54,16 @@ impl FP8x23NN of NNTrait { fn hard_sigmoid(tensor: @Tensor, alpha: @FP8x23, beta: @FP8x23) -> Tensor { functional::hard_sigmoid::hard_sigmoid(*tensor, alpha, beta) } + + fn gemm( + A: Tensor, + B: Tensor, + C: Option>, + alpha: Option, + beta: Option, + transA: bool, + transB: bool + ) -> Tensor { + functional::gemm::gemm(A, B, C, alpha, beta, transA, transB) + } } diff --git a/src/operators/nn/implementations/nn_i32.cairo b/src/operators/nn/implementations/nn_i32.cairo index a032ac1dc..dee95ec9f 100644 --- a/src/operators/nn/implementations/nn_i32.cairo +++ b/src/operators/nn/implementations/nn_i32.cairo @@ -46,4 +46,16 @@ impl I32NN of NNTrait { fn hard_sigmoid(tensor: @Tensor, alpha: @i32, beta: @i32) -> Tensor { panic(array!['not supported!']) } + + fn gemm( + A: Tensor, + B: Tensor, + C: Option>, + alpha: Option, + beta: Option, + transA: bool, + transB: bool + ) -> Tensor { + functional::gemm::gemm(A, B, C, alpha, beta, transA, transB) + } } diff --git a/src/operators/nn/implementations/nn_i8.cairo b/src/operators/nn/implementations/nn_i8.cairo index aaba87fd8..c9057fcce 100644 --- a/src/operators/nn/implementations/nn_i8.cairo +++ b/src/operators/nn/implementations/nn_i8.cairo @@ -46,4 +46,16 @@ impl I8NN of NNTrait { fn hard_sigmoid(tensor: @Tensor, alpha: @i8, beta: @i8) -> Tensor { panic(array!['not supported!']) } + + fn gemm( + A: Tensor, + B: Tensor, + C: Option>, + alpha: Option, + beta: Option, + transA: bool, + transB: bool + ) -> Tensor 
{ + functional::gemm::gemm(A, B, C, alpha, beta, transA, transB) + } } diff --git a/src/operators/nn/implementations/nn_u32.cairo b/src/operators/nn/implementations/nn_u32.cairo index 74e6beb37..1a2883a16 100644 --- a/src/operators/nn/implementations/nn_u32.cairo +++ b/src/operators/nn/implementations/nn_u32.cairo @@ -45,4 +45,16 @@ impl U32NN of NNTrait { fn hard_sigmoid(tensor: @Tensor, alpha: @u32, beta: @u32) -> Tensor { panic(array!['not supported!']) } + + fn gemm( + A: Tensor, + B: Tensor, + C: Option>, + alpha: Option, + beta: Option, + transA: bool, + transB: bool + ) -> Tensor { + functional::gemm::gemm(A, B, C, alpha, beta, transA, transB) + } } From eebe7345df48f289322a3452e3a7a74a4a9c2667 Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Tue, 24 Oct 2023 16:12:54 +0200 Subject: [PATCH 71/78] Add docstring --- docs/framework/operators/tensor/README.md | 1 + docs/framework/operators/tensor/tensor.and.md | 67 ++++++++++++++++++ src/operators/tensor/core.cairo | 69 +++++++++++++++++++ src/operators/tensor/math/and.cairo | 1 + 4 files changed, 138 insertions(+) create mode 100644 docs/framework/operators/tensor/tensor.and.md diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index b33f330c7..7017e4056 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -82,6 +82,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.unsqueeze`](tensor.unsqueeze.md) | Inserts single-dimensional entries to the shape of an input tensor. | | [`tensor.sign`](tensor.sign.md) | Calculates the sign of the given input tensor element-wise. | | [`tensor.clip`](tensor.clip.md) | Clip operator limits the given input within an interval. | +| [`tensor.and`](tensor.and.md) | Computes the logical AND of two tensors element-wise. | ## Arithmetic Operations diff --git a/docs/framework/operators/tensor/tensor.and.md b/docs/framework/operators/tensor/tensor.and.md new file mode 100644 index 000000000..94556ce0f --- /dev/null +++ b/docs/framework/operators/tensor/tensor.and.md @@ -0,0 +1,67 @@ +#tensor.and + +```rust + fn and(self: @Tensor, other: @Tensor) -> Tensor; +``` + +Computes the logical AND of two tensors element-wise. +The input tensors must have either: +* Exactly the same shape +* The same number of dimensions and the length of each dimension is either a common length or 1. + +## Args + +* `self`(`@Tensor`) - The first tensor to be compared +* `other`(`@Tensor`) - The second tensor to be compared + +## Panics + +* Panics if the shapes are not equal or broadcastable + +## Returns + +A new `Tensor` of booleans (0 or 1) with the same shape as the broadcasted inputs. 
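Element-wise, an operand is treated as `true` when it is non-zero, so the output element is `1` only when both corresponding elements are non-zero and `0` otherwise.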
+ +## Examples + +Case 1: Compare tensors with same shape + +```rust +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + +fn and_example() -> Tensor { + let tensor_1 = TensorTrait::::new( + shape: array![3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(), + ); + + let tensor_2 = TensorTrait::::new( + shape: array![3, 3].span(), data: array![0, 1, 2, 0, 1, 2, 0, 1, 2].span(), + ); + + return tensor_1.and(@tensor_2); +} +>>> [0,1,1,0,1,1,0,1,1] +``` + +Case 2: Compare tensors with different shapes + +```rust +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + +fn and_example() -> Tensor { + let tensor_1 = TensorTrait::::new( + shape: array![3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(), + ); + + let tensor_2 = TensorTrait::::new( + shape: array![1, 3].span(), data: array![0, 1, 2].span(), + ); + + return tensor_1.and(@tensor_2); +} +>>> [0,1,1,0,1,1,0,1,1] +``` diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index a161087db..b8f5a7fd0 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -78,6 +78,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde { /// # tensor.new @@ -2669,6 +2670,74 @@ trait TensorTrait { /// ``` /// fn sign(self: @Tensor) -> Tensor; + /// #tensor.and + /// + /// ```rust + /// fn and(self: @Tensor, other: @Tensor) -> Tensor; + /// ``` + /// + /// Computes the logical AND of two tensors element-wise. + /// The input tensors must have either: + /// * Exactly the same shape + /// * The same number of dimensions and the length of each dimension is either a common length or 1. + /// + /// ## Args + /// + /// * `self`(`@Tensor`) - The first tensor to be compared + /// * `other`(`@Tensor`) - The second tensor to be compared + /// + /// ## Panics + /// + /// * Panics if the shapes are not equal or broadcastable + /// + /// ## Returns + /// + /// A new `Tensor` of booleans (0 or 1) with the same shape as the broadcasted inputs. 
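/// Element-wise, an operand is treated as `true` when it is non-zero, so the output element is `1` only when both corresponding elements are non-zero and `0` otherwise.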
+ /// + /// ## Examples + /// + /// Case 1: Compare tensors with same shape + /// + /// ```rust + /// use array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + /// + /// fn and_example() -> Tensor { + /// let tensor_1 = TensorTrait::::new( + /// shape: array![3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(), + /// ); + /// + /// let tensor_2 = TensorTrait::::new( + /// shape: array![3, 3].span(), data: array![0, 1, 2, 0, 1, 2, 0, 1, 2].span(), + /// ); + /// + /// return tensor_1.and(@tensor_2); + /// } + /// >>> [0,1,1,0,1,1,0,1,1] + /// ``` + /// + /// Case 2: Compare tensors with different shapes + /// + /// ```rust + /// use array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + /// + /// fn and_example() -> Tensor { + /// let tensor_1 = TensorTrait::::new( + /// shape: array![3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(), + /// ); + /// + /// let tensor_2 = TensorTrait::::new( + /// shape: array![1, 3].span(), data: array![0, 1, 2].span(), + /// ); + /// + /// return tensor_1.and(@tensor_2); + /// } + /// >>> [0,1,1,0,1,1,0,1,1] + /// ``` + /// fn and(self: @Tensor, other: @Tensor) -> Tensor; } diff --git a/src/operators/tensor/math/and.cairo b/src/operators/tensor/math/and.cairo index dba0b0b1d..80872185e 100644 --- a/src/operators/tensor/math/and.cairo +++ b/src/operators/tensor/math/and.cairo @@ -8,6 +8,7 @@ use orion::operators::tensor::helpers::{ broadcast_shape, broadcast_index_mapping, len_from_shape, check_compatibility }; +/// Cf: TensorTrait::and docstring fn and< T, MAG, From 1245027800bf39a45fc81583bf3c747b809e5f51 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Wed, 25 Oct 2023 11:01:10 +0300 Subject: [PATCH 72/78] add AND in fp16x16w and fp8x23w + missing doc --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 115 +++++++++--------- src/numbers.cairo | 8 ++ .../fp16x16wide/math/comp.cairo | 9 ++ .../fp8x23wide/math/comp.cairo | 9 ++ .../implementations/tensor_fp16x16wide.cairo | 4 + .../implementations/tensor_fp8x23wide.cairo | 4 + 7 files changed, 93 insertions(+), 57 deletions(-) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index cc8fe91b1..f1543c5c6 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -88,6 +88,7 @@ * [tensor.sign](framework/operators/tensor/tensor.sign.md) * [tensor.clip](framework/operators/tensor/tensor.clip.md) * [tensor.identity](framework/operators/tensor/tensor.identity.md) + * [tensor.and](framework/operators/tensor/tensor.and.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index 35e2daf55..c2d49c2be 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -4,62 +4,63 @@ To see the full list of available ONNX Operators refer to [this table](https://g You can see below the list of current supported ONNX Operators: -| Operator | Implemented | -| :-------------------------------------------------------------: | :------------------: | -| [MatMul](operators/tensor/tensor.matmul.md) | :white\_check\_mark: | -| [MatMulInteger](operators/tensor/tensor.matmul.md) | :white\_check\_mark: | -| [Add](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | -| [Sub](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | -| 
[Mul](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | -| [Div](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | -| [Equal](operators/tensor/tensor.equal.md) | :white\_check\_mark: | -| [Greater](operators/tensor/tensor.greater.md) | :white\_check\_mark: | -| [GreaterOrEqual](operators/tensor/tensor.greater\_equal.md) | :white\_check\_mark: | -| [Less](operators/tensor/tensor.less.md) | :white\_check\_mark: | -| [LessOrEqual](operators/tensor/tensor.less\_equal.md) | :white\_check\_mark: | -| [Abs](operators/tensor/tensor.abs.md) | :white\_check\_mark: | -| [Neg](operators/tensor/tensor.neg.md) | :white\_check\_mark: | -| [Ceil](operators/tensor/tensor.ceil.md) | :white\_check\_mark: | -| [Exp](operators/tensor/tensor.exp.md) | :white\_check\_mark: | -| [Ln](operators/tensor/tensor.log.md) | :white\_check\_mark: | -| [Reshape](operators/tensor/tensor.reshape.md) | :white\_check\_mark: | -| [Transpose](operators/tensor/tensor.transpose.md) | :white\_check\_mark: | -| [ArgMax](operators/tensor/tensor.argmax.md) | :white\_check\_mark: | -| [ArgMin](operators/tensor/tensor.argmin.md) | :white\_check\_mark: | -| [ReduceSum](operators/tensor/tensor.reduce\_sum.md) | :white\_check\_mark: | -| [CumSum](operators/tensor/tensor.cumsum.md) | :white\_check\_mark: | -| [Cos](operators/tensor/tensor.cos.md) | :white\_check\_mark: | -| [Sin](operators/tensor/tensor.sin.md) | :white\_check\_mark: | -| [Asin](operators/tensor/tensor.asin.md) | :white\_check\_mark: | -| [Flatten](operators/tensor/tensor.flatten.md) | :white\_check\_mark: | -| [Relu](operators/neural-network/nn.relu.md) | :white\_check\_mark: | -| [LeakyRelu](operators/neural-network/nn.leaky\_relu.md) | :white\_check\_mark: | -|[ThresholdedRelu](operators/neural-network/nn.thresholded\_relu.md)| :white\_check\_mark: | -| [Sigmoid](operators/neural-network/nn.sigmoid.md) | :white\_check\_mark: | -| [Softmax](operators/neural-network/nn.softmax.md) | :white\_check\_mark: | -| [LogSoftmax](operators/neural-network/nn.logsoftmax.md) | :white\_check\_mark: | -| [Softsign](operators/neural-network/nn.softsign.md) | :white\_check\_mark: | -| [Softplus](operators/neural-network/nn.softplus.md) | :white\_check\_mark: | -| [Linear](operators/neural-network/nn.linear.md) | :white\_check\_mark: | -| [HardSigmoid](operators/neural-network/nn.hard\_sigmoid.md) | :white\_check\_mark: | -| [Sinh](operators/tensor/tensor.sinh.md) | :white\_check\_mark: | -| [Asinh](operators/tensor/tensor.asinh.md) | :white\_check\_mark: | -| [Cosh](operators/tensor/tensor.cosh.md) | :white\_check\_mark: | -| [ACosh](operators/tensor/tensor.acosh.md) | :white\_check\_mark: | -| [Tanh](operators/tensor/tensor.tanh.md) | :white\_check\_mark: | -| [Acos](operators/tensor/tensor.acos.md) | :white\_check\_mark: | -| [Sqrt](operators/tensor/tensor.sqrt.md) | :white\_check\_mark: | -| [Onehot](operators/tensor/tensor.onehot.md) | :white\_check\_mark: | -| [Slice](operators/tensor/tensor.slice.md) | :white\_check\_mark: | -| [Concat](operators/tensor/tensor.concat.md) | :white\_check\_mark: | -| [Gather](operators/tensor/tensor.gather.md) | :white\_check\_mark: | -| [QuantizeLinear](operators/tensor/tensor.quantize\_linear.md) | :white\_check\_mark: | -| [DequantizeLinear](operators/tensor/tensor.quantize\_linear.md) | :white\_check\_mark: | -| [Nonzero](operators/tensor/tensor.nonzero.md) | :white\_check\_mark: | -| [Squeeze](operators/tensor/tensor.squeeze.md) | :white\_check\_mark: | -| [Unsqueeze](operators/tensor/tensor.unsqueeze.md) | 
:white\_check\_mark: | -| [Sign](operators/tensor/tensor.sign.md) | :white\_check\_mark: | -| [Clip](operators/tensor/tensor.clip.md) | :white\_check\_mark: | -| [Identity](operators/tensor/tensor.identity.md) | :white\_check\_mark: | +| Operator | Implemented | +| :-----------------------------------------------------------------: | :------------------: | +| [MatMul](operators/tensor/tensor.matmul.md) | :white\_check\_mark: | +| [MatMulInteger](operators/tensor/tensor.matmul.md) | :white\_check\_mark: | +| [Add](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | +| [Sub](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | +| [Mul](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | +| [Div](operators/tensor/#arithmetic-operations) | :white\_check\_mark: | +| [Equal](operators/tensor/tensor.equal.md) | :white\_check\_mark: | +| [Greater](operators/tensor/tensor.greater.md) | :white\_check\_mark: | +| [GreaterOrEqual](operators/tensor/tensor.greater\_equal.md) | :white\_check\_mark: | +| [Less](operators/tensor/tensor.less.md) | :white\_check\_mark: | +| [LessOrEqual](operators/tensor/tensor.less\_equal.md) | :white\_check\_mark: | +| [Abs](operators/tensor/tensor.abs.md) | :white\_check\_mark: | +| [Neg](operators/tensor/tensor.neg.md) | :white\_check\_mark: | +| [Ceil](operators/tensor/tensor.ceil.md) | :white\_check\_mark: | +| [Exp](operators/tensor/tensor.exp.md) | :white\_check\_mark: | +| [Ln](operators/tensor/tensor.log.md) | :white\_check\_mark: | +| [Reshape](operators/tensor/tensor.reshape.md) | :white\_check\_mark: | +| [Transpose](operators/tensor/tensor.transpose.md) | :white\_check\_mark: | +| [ArgMax](operators/tensor/tensor.argmax.md) | :white\_check\_mark: | +| [ArgMin](operators/tensor/tensor.argmin.md) | :white\_check\_mark: | +| [ReduceSum](operators/tensor/tensor.reduce\_sum.md) | :white\_check\_mark: | +| [CumSum](operators/tensor/tensor.cumsum.md) | :white\_check\_mark: | +| [Cos](operators/tensor/tensor.cos.md) | :white\_check\_mark: | +| [Sin](operators/tensor/tensor.sin.md) | :white\_check\_mark: | +| [Asin](operators/tensor/tensor.asin.md) | :white\_check\_mark: | +| [Flatten](operators/tensor/tensor.flatten.md) | :white\_check\_mark: | +| [Relu](operators/neural-network/nn.relu.md) | :white\_check\_mark: | +| [LeakyRelu](operators/neural-network/nn.leaky\_relu.md) | :white\_check\_mark: | +| [ThresholdedRelu](operators/neural-network/nn.thresholded\_relu.md) | :white\_check\_mark: | +| [Sigmoid](operators/neural-network/nn.sigmoid.md) | :white\_check\_mark: | +| [Softmax](operators/neural-network/nn.softmax.md) | :white\_check\_mark: | +| [LogSoftmax](operators/neural-network/nn.logsoftmax.md) | :white\_check\_mark: | +| [Softsign](operators/neural-network/nn.softsign.md) | :white\_check\_mark: | +| [Softplus](operators/neural-network/nn.softplus.md) | :white\_check\_mark: | +| [Linear](operators/neural-network/nn.linear.md) | :white\_check\_mark: | +| [HardSigmoid](operators/neural-network/nn.hard\_sigmoid.md) | :white\_check\_mark: | +| [Sinh](operators/tensor/tensor.sinh.md) | :white\_check\_mark: | +| [Asinh](operators/tensor/tensor.asinh.md) | :white\_check\_mark: | +| [Cosh](operators/tensor/tensor.cosh.md) | :white\_check\_mark: | +| [ACosh](operators/tensor/tensor.acosh.md) | :white\_check\_mark: | +| [Tanh](operators/tensor/tensor.tanh.md) | :white\_check\_mark: | +| [Acos](operators/tensor/tensor.acos.md) | :white\_check\_mark: | +| [Sqrt](operators/tensor/tensor.sqrt.md) | :white\_check\_mark: | +| 
[Onehot](operators/tensor/tensor.onehot.md) | :white\_check\_mark: | +| [Slice](operators/tensor/tensor.slice.md) | :white\_check\_mark: | +| [Concat](operators/tensor/tensor.concat.md) | :white\_check\_mark: | +| [Gather](operators/tensor/tensor.gather.md) | :white\_check\_mark: | +| [QuantizeLinear](operators/tensor/tensor.quantize\_linear.md) | :white\_check\_mark: | +| [DequantizeLinear](operators/tensor/tensor.quantize\_linear.md) | :white\_check\_mark: | +| [Nonzero](operators/tensor/tensor.nonzero.md) | :white\_check\_mark: | +| [Squeeze](operators/tensor/tensor.squeeze.md) | :white\_check\_mark: | +| [Unsqueeze](operators/tensor/tensor.unsqueeze.md) | :white\_check\_mark: | +| [Sign](operators/tensor/tensor.sign.md) | :white\_check\_mark: | +| [Clip](operators/tensor/tensor.clip.md) | :white\_check\_mark: | +| [Identity](operators/tensor/tensor.identity.md) | :white\_check\_mark: | +| [And](operators/tensor/tensor.and.md) | :white\_check\_mark: | Current Operators support: **51/156 (33%)** diff --git a/src/numbers.cairo b/src/numbers.cairo index 821e68812..5432d67f8 100644 --- a/src/numbers.cairo +++ b/src/numbers.cairo @@ -390,6 +390,10 @@ impl FP8x23WNumber of NumberTrait { fn sign(self: FP8x23W) -> FP8x23W { core_fp8x23wide::sign(self) } + + fn and(lhs: FP8x23W, rhs: FP8x23W) -> bool { + comp_fp8x23wide::and(lhs, rhs) + } } use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16Impl, FP16x16}; @@ -732,6 +736,10 @@ impl FP16x16WNumber of NumberTrait { fn sign(self: FP16x16W) -> FP16x16W { core_fp16x16wide::sign(self) } + + fn and(lhs: FP16x16W, rhs: FP16x16W) -> bool { + comp_fp16x16wide::and(lhs, rhs) + } } use orion::numbers::fixed_point::implementations::fp64x64::core::{FP64x64Impl, FP64x64}; diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo index 63a3e4855..4f815a124 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo @@ -35,6 +35,15 @@ fn or(a: FP16x16W, b: FP16x16W) -> bool { } } +fn and(a: FP16x16W, b: FP16x16W) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero || b == zero { + return false; + } else { + return true; + } +} + // Tests -------------------------------------------------------------------------------------------------------------- #[test] diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo index 95b329109..5cacb7f0e 100644 --- a/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo @@ -35,6 +35,15 @@ fn or(a: FP8x23W, b: FP8x23W) -> bool { } } +fn and(a: FP8x23W, b: FP8x23W) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero || b == zero { + return false; + } else { + return true; + } +} + // Tests -------------------------------------------------------------------------------------------------------------- #[test] diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 2669c91e5..101ad4d21 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -242,6 +242,10 @@ impl FP16x16WTensor of TensorTrait { core::clip(self, min, max) } + fn and(self: @Tensor, other: 
@Tensor) -> Tensor { + math::and::and(self, other) + } + fn identity(self: @Tensor) -> Tensor { core::identity(self) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a085c1f8b..e81beb028 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -234,6 +234,10 @@ impl FP8x23WTensor of TensorTrait { core::clip(self, min, max) } + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } + fn identity(self: @Tensor) -> Tensor { core::identity(self) } From 7b522b80e680c1bd16d7da0e6017b9bd554c60ad Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Wed, 25 Oct 2023 08:03:39 +0000 Subject: [PATCH 73/78] docs: update README.md [skip ci] --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f4ca1efa..d6bc4cdb3 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ # Orion: An Open-source Framework for Validity and ZK ML ✨ -[![All Contributors](https://img.shields.io/badge/all_contributors-15-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-16-orange.svg?style=flat-square)](#contributors-) Orion is an open-source, community-driven framework dedicated to Provable Machine Learning. It provides essential components and a new ONNX runtime for building verifiable Machine Learning models using [STARKs](https://starkware.co/stark/). @@ -82,10 +82,13 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d BemTG
(💻 📖), danilowhk (💻), Falco R (💻), +dincerguner (💻), Rich Warner (💻), Daniel Bejarano (📖), vikkydataseo (📖), -dincerguner (💻), +Daniel
💻 From 0f1779b6c11d6d436ea625654ef1dfa42f918acf Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Wed, 25 Oct 2023 08:03:40 +0000 Subject: [PATCH 74/78] docs: update .all-contributorsrc [skip ci] --- .all-contributorsrc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index fc389d5f5..fd19cc69c 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -143,6 +143,15 @@ "contributions": [ "doc" ] + }, + { + "login": "Dl-Vv", + "name": "Daniel", + "avatar_url": "https://avatars.githubusercontent.com/u/83556514?v=4", + "profile": "https://www.brilliantblocks.io/", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, @@ -152,4 +161,4 @@ "projectName": "orion", "projectOwner": "gizatechxyz", "commitType": "docs" -} \ No newline at end of file +} From d5b84eac7b88249363dfdea652b778c35145737c Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Wed, 25 Oct 2023 12:41:18 +0300 Subject: [PATCH 75/78] write tests --- nodegen/node/gemm.py | 105 +++++++++++++++++++++ src/operators/nn/functional/gemm.cairo | 6 +- src/operators/tensor/math/arithmetic.cairo | 8 +- 3 files changed, 112 insertions(+), 7 deletions(-) create mode 100644 nodegen/node/gemm.py diff --git a/nodegen/node/gemm.py b/nodegen/node/gemm.py new file mode 100644 index 000000000..d21b70039 --- /dev/null +++ b/nodegen/node/gemm.py @@ -0,0 +1,105 @@ +from typing import Optional + +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl, Trait + + +def gemm_reference_implementation( + A: np.ndarray, + B: np.ndarray, + C: Optional[np.ndarray] = None, + alpha: float = 1.0, + beta: float = 1.0, + transA: int = 0, + transB: int = 0, +) -> np.ndarray: + A = A if transA == 0 else A.T + B = B if transB == 0 else B.T + C = C if C is not None else np.array(0) + + Y = alpha * np.dot(A, B) + beta * C + return Y + + +class Gemm(RunAll): + + @staticmethod + def gemm_default_zero_bias(): + a = np.random.ranf([3, 5]).astype(np.float32) + b = np.random.ranf([5, 4]).astype(np.float32) + c = np.zeros([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_default_no_bias" + make_node([a, b], [y], name) + make_test( + [a, b], y, "NNTrait::gemm(input_0, input_1, Option::None(()), Option::None(()), Option::None(()), false, false)", name, Trait.NN) + + @staticmethod + def gemm_default_vector_bias(): + a = np.random.ranf([2, 7]).astype(np.float32) + b = np.random.ranf([7, 4]).astype(np.float32) + c = np.random.ranf([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + c = Tensor(Dtype.FP16x16, c.shape, to_fp( + c.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_default_vector_bias" + make_node([a, b, c], [y], name) + make_test( + [a, b, c], y, "NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::None(()), Option::None(()), false, false)", name, Trait.NN) + + @staticmethod + def gemm_default_matrix_bias(): + a = 
np.random.ranf([3, 6]).astype(np.float32) + b = np.random.ranf([6, 4]).astype(np.float32) + c = np.random.ranf([3, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + c = Tensor(Dtype.FP16x16, c.shape, to_fp( + c.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_default_matrix_bias" + make_node([a, b, c], [y], name) + make_test( + [a, b, c], y, "NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::None(()), Option::None(()), false, false)", name, Trait.NN) + + @staticmethod + def gemm_transposeA(): + a = np.random.ranf([6, 3]).astype(np.float32) + b = np.random.ranf([6, 4]).astype(np.float32) + c = np.zeros([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c, transA=1) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_transposeA" + make_node([a, b], [y], name) + make_test( + [a, b], y, "NNTrait::gemm(input_0, input_1, Option::None(()), Option::None(()), Option::None(()), true, false)", name, Trait.NN) diff --git a/src/operators/nn/functional/gemm.cairo b/src/operators/nn/functional/gemm.cairo index 8d4498cc6..d6ec1451a 100644 --- a/src/operators/nn/functional/gemm.cairo +++ b/src/operators/nn/functional/gemm.cairo @@ -1,7 +1,7 @@ use array::SpanTrait; use orion::numbers::NumberTrait; -use orion::operators::tensor::{core::{Tensor, TensorTrait}, math::arithmetic::mul_by_val}; +use orion::operators::tensor::{core::{Tensor, TensorTrait}, math::arithmetic::mul_by_scalar}; /// Cf: NNTrait::gemm docstring fn gemm< @@ -45,10 +45,10 @@ fn gemm< match C { Option::Some(c) => { - return mul_by_val(@A.matmul(@B), alpha) + mul_by_val(@c, beta); + return mul_by_scalar(@A.matmul(@B), alpha) + mul_by_scalar(@c, beta); }, Option::None(_) => { - return mul_by_val(@A.matmul(@B), alpha); + return mul_by_scalar(@A.matmul(@B), alpha); } } } diff --git a/src/operators/tensor/math/arithmetic.cairo b/src/operators/tensor/math/arithmetic.cairo index 6c1df4c28..12fbff8c2 100644 --- a/src/operators/tensor/math/arithmetic.cairo +++ b/src/operators/tensor/math/arithmetic.cairo @@ -38,7 +38,7 @@ fn add< return TensorTrait::::new(broadcasted_shape, result.span()); } -fn add_by_val< +fn add_by_scalar< T, MAG, impl TTensor: TensorTrait, @@ -144,7 +144,7 @@ fn sub< return TensorTrait::::new(broadcasted_shape, result.span()); } -fn sub_by_val< +fn sub_by_scalar< T, MAG, impl TTensor: TensorTrait, @@ -250,7 +250,7 @@ fn mul< return TensorTrait::::new(broadcasted_shape, result.span()); } -fn mul_by_val< +fn mul_by_scalar< T, MAG, impl TTensor: TensorTrait, @@ -356,7 +356,7 @@ fn div< return TensorTrait::::new(broadcasted_shape, result.span()); } -fn div_by_val< +fn div_by_scalar< T, MAG, impl TTensor: TensorTrait, From 8069f8870ca0747ecddf2eb4e5ae86c45b2385a4 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Wed, 25 Oct 2023 14:36:33 +0300 Subject: [PATCH 76/78] write all tests --- nodegen/node/gemm.py | 83 +++++++++++++++++++ src/operators/nn/functional/gemm.cairo | 7 +- tests/src/nodes.cairo | 8 ++ tests/src/nodes/gemm_all_attributes.cairo | 24 ++++++ .../nodes/gemm_all_attributes/input_0.cairo | 26 ++++++ 
.../nodes/gemm_all_attributes/input_1.cairo | 34 ++++++++ .../nodes/gemm_all_attributes/input_2.cairo | 19 +++++ .../nodes/gemm_all_attributes/output_0.cairo | 29 +++++++ tests/src/nodes/gemm_alpha.cairo | 22 +++++ tests/src/nodes/gemm_alpha/input_0.cairo | 29 +++++++ tests/src/nodes/gemm_alpha/input_1.cairo | 34 ++++++++ tests/src/nodes/gemm_alpha/output_0.cairo | 26 ++++++ tests/src/nodes/gemm_beta.cairo | 24 ++++++ tests/src/nodes/gemm_beta/input_0.cairo | 28 +++++++ tests/src/nodes/gemm_beta/input_1.cairo | 42 ++++++++++ tests/src/nodes/gemm_beta/input_2.cairo | 18 ++++ tests/src/nodes/gemm_beta/output_0.cairo | 22 +++++ .../src/nodes/gemm_default_matrix_bias.cairo | 24 ++++++ .../gemm_default_matrix_bias/input_0.cairo | 32 +++++++ .../gemm_default_matrix_bias/input_1.cairo | 38 +++++++++ .../gemm_default_matrix_bias/input_2.cairo | 26 ++++++ .../gemm_default_matrix_bias/output_0.cairo | 26 ++++++ tests/src/nodes/gemm_default_no_bias.cairo | 22 +++++ .../nodes/gemm_default_no_bias/input_0.cairo | 29 +++++++ .../nodes/gemm_default_no_bias/input_1.cairo | 34 ++++++++ .../nodes/gemm_default_no_bias/output_0.cairo | 26 ++++++ .../src/nodes/gemm_default_vector_bias.cairo | 24 ++++++ .../gemm_default_vector_bias/input_0.cairo | 28 +++++++ .../gemm_default_vector_bias/input_1.cairo | 42 ++++++++++ .../gemm_default_vector_bias/input_2.cairo | 18 ++++ .../gemm_default_vector_bias/output_0.cairo | 22 +++++ tests/src/nodes/gemm_transposeA.cairo | 22 +++++ tests/src/nodes/gemm_transposeA/input_0.cairo | 32 +++++++ tests/src/nodes/gemm_transposeA/input_1.cairo | 38 +++++++++ .../src/nodes/gemm_transposeA/output_0.cairo | 26 ++++++ tests/src/nodes/gemm_transposeB.cairo | 22 +++++ tests/src/nodes/gemm_transposeB/input_0.cairo | 32 +++++++ tests/src/nodes/gemm_transposeB/input_1.cairo | 38 +++++++++ .../src/nodes/gemm_transposeB/output_0.cairo | 26 ++++++ 39 files changed, 1100 insertions(+), 2 deletions(-) create mode 100644 tests/src/nodes/gemm_all_attributes.cairo create mode 100644 tests/src/nodes/gemm_all_attributes/input_0.cairo create mode 100644 tests/src/nodes/gemm_all_attributes/input_1.cairo create mode 100644 tests/src/nodes/gemm_all_attributes/input_2.cairo create mode 100644 tests/src/nodes/gemm_all_attributes/output_0.cairo create mode 100644 tests/src/nodes/gemm_alpha.cairo create mode 100644 tests/src/nodes/gemm_alpha/input_0.cairo create mode 100644 tests/src/nodes/gemm_alpha/input_1.cairo create mode 100644 tests/src/nodes/gemm_alpha/output_0.cairo create mode 100644 tests/src/nodes/gemm_beta.cairo create mode 100644 tests/src/nodes/gemm_beta/input_0.cairo create mode 100644 tests/src/nodes/gemm_beta/input_1.cairo create mode 100644 tests/src/nodes/gemm_beta/input_2.cairo create mode 100644 tests/src/nodes/gemm_beta/output_0.cairo create mode 100644 tests/src/nodes/gemm_default_matrix_bias.cairo create mode 100644 tests/src/nodes/gemm_default_matrix_bias/input_0.cairo create mode 100644 tests/src/nodes/gemm_default_matrix_bias/input_1.cairo create mode 100644 tests/src/nodes/gemm_default_matrix_bias/input_2.cairo create mode 100644 tests/src/nodes/gemm_default_matrix_bias/output_0.cairo create mode 100644 tests/src/nodes/gemm_default_no_bias.cairo create mode 100644 tests/src/nodes/gemm_default_no_bias/input_0.cairo create mode 100644 tests/src/nodes/gemm_default_no_bias/input_1.cairo create mode 100644 tests/src/nodes/gemm_default_no_bias/output_0.cairo create mode 100644 tests/src/nodes/gemm_default_vector_bias.cairo create mode 100644 
tests/src/nodes/gemm_default_vector_bias/input_0.cairo create mode 100644 tests/src/nodes/gemm_default_vector_bias/input_1.cairo create mode 100644 tests/src/nodes/gemm_default_vector_bias/input_2.cairo create mode 100644 tests/src/nodes/gemm_default_vector_bias/output_0.cairo create mode 100644 tests/src/nodes/gemm_transposeA.cairo create mode 100644 tests/src/nodes/gemm_transposeA/input_0.cairo create mode 100644 tests/src/nodes/gemm_transposeA/input_1.cairo create mode 100644 tests/src/nodes/gemm_transposeA/output_0.cairo create mode 100644 tests/src/nodes/gemm_transposeB.cairo create mode 100644 tests/src/nodes/gemm_transposeB/input_0.cairo create mode 100644 tests/src/nodes/gemm_transposeB/input_1.cairo create mode 100644 tests/src/nodes/gemm_transposeB/output_0.cairo diff --git a/nodegen/node/gemm.py b/nodegen/node/gemm.py index d21b70039..f4d7ed13d 100644 --- a/nodegen/node/gemm.py +++ b/nodegen/node/gemm.py @@ -103,3 +103,86 @@ def gemm_transposeA(): make_node([a, b], [y], name) make_test( [a, b], y, "NNTrait::gemm(input_0, input_1, Option::None(()), Option::None(()), Option::None(()), true, false)", name, Trait.NN) + + @staticmethod + def gemm_transposeB(): + a = np.random.ranf([3, 6]).astype(np.float32) + b = np.random.ranf([4, 6]).astype(np.float32) + c = np.zeros([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c, transB=1) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_transposeB" + make_node([a, b], [y], name) + make_test( + [a, b], y, "NNTrait::gemm(input_0, input_1, Option::None(()), Option::None(()), Option::None(()), false, true)", name, Trait.NN) + + @staticmethod + def gemm_alpha(): + a = np.random.ranf([3, 5]).astype(np.float32) + b = np.random.ranf([5, 4]).astype(np.float32) + c = np.zeros([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c, alpha=0.5) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_alpha" + make_node([a, b], [y], name) + make_test( + [a, b], y, "NNTrait::gemm(input_0, input_1, Option::None(()), Option::Some(FixedTrait::new(32768, false)), Option::None(()), false, false)", name, Trait.NN) + + @staticmethod + def gemm_beta(): + a = np.random.ranf([2, 7]).astype(np.float32) + b = np.random.ranf([7, 4]).astype(np.float32) + c = np.random.ranf([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c, beta=0.5) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + c = Tensor(Dtype.FP16x16, c.shape, to_fp( + c.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_beta" + make_node([a, b, c], [y], name) + make_test( + [a, b, c], y, "NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::None(()), Option::Some(FixedTrait::new(32768, false)), false, false)", name, Trait.NN) + + @staticmethod + def gemm_all_attributes(): + a = np.random.ranf([4, 3]).astype(np.float32) + b = np.random.ranf([5, 4]).astype(np.float32) + c = np.random.ranf([1, 5]).astype(np.float32) + y = gemm_reference_implementation( + a, b, c, 
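# per the reference implementation above: Y = 0.25 * np.dot(A.T, B.T) + 0.35 * C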
transA=1, transB=1, alpha=0.25, beta=0.35 + ) + + a = Tensor(Dtype.FP16x16, a.shape, to_fp( + a.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, b.shape, to_fp( + b.flatten(), FixedImpl.FP16x16)) + c = Tensor(Dtype.FP16x16, c.shape, to_fp( + c.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "gemm_all_attributes" + make_node([a, b, c], [y], name) + make_test( + [a, b, c], y, "NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::Some(FixedTrait::new(16384, false)), Option::Some(FixedTrait::new(22938, false)), true, true)", name, Trait.NN) + diff --git a/src/operators/nn/functional/gemm.cairo b/src/operators/nn/functional/gemm.cairo index d6ec1451a..7f63990a1 100644 --- a/src/operators/nn/functional/gemm.cairo +++ b/src/operators/nn/functional/gemm.cairo @@ -23,6 +23,9 @@ fn gemm< transA: bool, transB: bool ) -> Tensor { + let mut A = A; + let mut B = B; + let alpha: T = if alpha.is_some() { alpha.unwrap() } else { @@ -36,11 +39,11 @@ fn gemm< }; if transA == true { - let A = A.transpose(array![1, 0].span()); + A = A.transpose(array![1, 0].span()); } if transB == true { - let B = B.transpose(array![1, 0].span()); + B = B.transpose(array![1, 0].span()); } match C { diff --git a/tests/src/nodes.cairo b/tests/src/nodes.cairo index 6d9192dc1..466b0e664 100644 --- a/tests/src/nodes.cairo +++ b/tests/src/nodes.cairo @@ -456,3 +456,11 @@ mod neg_fp16x16; mod neg_fp8x23; mod neg_i32; mod neg_i8; +mod gemm_all_attributes; +mod gemm_alpha; +mod gemm_beta; +mod gemm_default_matrix_bias; +mod gemm_default_vector_bias; +mod gemm_default_no_bias; +mod gemm_transposeA; +mod gemm_transposeB; diff --git a/tests/src/nodes/gemm_all_attributes.cairo b/tests/src/nodes/gemm_all_attributes.cairo new file mode 100644 index 000000000..84dda3884 --- /dev/null +++ b/tests/src/nodes/gemm_all_attributes.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP16x16NN; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gemm_all_attributes() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z = output_0::output_0(); + + let y = NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::Some(FixedTrait::new(16384, false)), Option::Some(FixedTrait::new(22938, false)), true, true); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_all_attributes/input_0.cairo b/tests/src/nodes/gemm_all_attributes/input_0.cairo new file mode 100644 index 000000000..8f6e44c42 --- /dev/null +++ b/tests/src/nodes/gemm_all_attributes/input_0.cairo @@ -0,0 +1,26 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 25442, sign: false }); + data.append(FP16x16 { mag: 21621, sign: false }); + data.append(FP16x16 { mag: 20558, sign: false }); + data.append(FP16x16 { mag: 63086, sign: false }); + data.append(FP16x16 { mag: 42888, sign: false }); + data.append(FP16x16 { mag: 5836, sign: false }); + data.append(FP16x16 { mag: 36243, sign: 
false }); + data.append(FP16x16 { mag: 31967, sign: false }); + data.append(FP16x16 { mag: 64085, sign: false }); + data.append(FP16x16 { mag: 26601, sign: false }); + data.append(FP16x16 { mag: 40779, sign: false }); + data.append(FP16x16 { mag: 41935, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_all_attributes/input_1.cairo b/tests/src/nodes/gemm_all_attributes/input_1.cairo new file mode 100644 index 000000000..1ac4d61df --- /dev/null +++ b/tests/src/nodes/gemm_all_attributes/input_1.cairo @@ -0,0 +1,34 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 951, sign: false }); + data.append(FP16x16 { mag: 60848, sign: false }); + data.append(FP16x16 { mag: 51199, sign: false }); + data.append(FP16x16 { mag: 16691, sign: false }); + data.append(FP16x16 { mag: 14621, sign: false }); + data.append(FP16x16 { mag: 51626, sign: false }); + data.append(FP16x16 { mag: 33242, sign: false }); + data.append(FP16x16 { mag: 36152, sign: false }); + data.append(FP16x16 { mag: 41495, sign: false }); + data.append(FP16x16 { mag: 21214, sign: false }); + data.append(FP16x16 { mag: 63748, sign: false }); + data.append(FP16x16 { mag: 9058, sign: false }); + data.append(FP16x16 { mag: 38129, sign: false }); + data.append(FP16x16 { mag: 32448, sign: false }); + data.append(FP16x16 { mag: 34299, sign: false }); + data.append(FP16x16 { mag: 28592, sign: false }); + data.append(FP16x16 { mag: 60878, sign: false }); + data.append(FP16x16 { mag: 1143, sign: false }); + data.append(FP16x16 { mag: 2602, sign: false }); + data.append(FP16x16 { mag: 12136, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_all_attributes/input_2.cairo b/tests/src/nodes/gemm_all_attributes/input_2.cairo new file mode 100644 index 000000000..fcf62d4c2 --- /dev/null +++ b/tests/src/nodes/gemm_all_attributes/input_2.cairo @@ -0,0 +1,19 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 10671, sign: false }); + data.append(FP16x16 { mag: 42014, sign: false }); + data.append(FP16x16 { mag: 54635, sign: false }); + data.append(FP16x16 { mag: 20143, sign: false }); + data.append(FP16x16 { mag: 23206, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_all_attributes/output_0.cairo b/tests/src/nodes/gemm_all_attributes/output_0.cairo new file mode 100644 index 000000000..3b5c49233 --- /dev/null +++ b/tests/src/nodes/gemm_all_attributes/output_0.cairo @@ -0,0 +1,29 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(5); + + let mut 
data = ArrayTrait::new(); + data.append(FP16x16 { mag: 27243, sign: false }); + data.append(FP16x16 { mag: 36813, sign: false }); + data.append(FP16x16 { mag: 37987, sign: false }); + data.append(FP16x16 { mag: 26203, sign: false }); + data.append(FP16x16 { mag: 15897, sign: false }); + data.append(FP16x16 { mag: 22608, sign: false }); + data.append(FP16x16 { mag: 34035, sign: false }); + data.append(FP16x16 { mag: 35198, sign: false }); + data.append(FP16x16 { mag: 24134, sign: false }); + data.append(FP16x16 { mag: 15535, sign: false }); + data.append(FP16x16 { mag: 20351, sign: false }); + data.append(FP16x16 { mag: 30911, sign: false }); + data.append(FP16x16 { mag: 39882, sign: false }); + data.append(FP16x16 { mag: 23721, sign: false }); + data.append(FP16x16 { mag: 15499, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_alpha.cairo b/tests/src/nodes/gemm_alpha.cairo new file mode 100644 index 000000000..bc10ace4e --- /dev/null +++ b/tests/src/nodes/gemm_alpha.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP16x16NN; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gemm_alpha() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = NNTrait::gemm(input_0, input_1, Option::None(()), Option::Some(FixedTrait::new(32768, false)), Option::None(()), false, false); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_alpha/input_0.cairo b/tests/src/nodes/gemm_alpha/input_0.cairo new file mode 100644 index 000000000..dbe226af7 --- /dev/null +++ b/tests/src/nodes/gemm_alpha/input_0.cairo @@ -0,0 +1,29 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 25149, sign: false }); + data.append(FP16x16 { mag: 57333, sign: false }); + data.append(FP16x16 { mag: 4965, sign: false }); + data.append(FP16x16 { mag: 43218, sign: false }); + data.append(FP16x16 { mag: 49951, sign: false }); + data.append(FP16x16 { mag: 61057, sign: false }); + data.append(FP16x16 { mag: 50263, sign: false }); + data.append(FP16x16 { mag: 29479, sign: false }); + data.append(FP16x16 { mag: 3849, sign: false }); + data.append(FP16x16 { mag: 38336, sign: false }); + data.append(FP16x16 { mag: 27897, sign: false }); + data.append(FP16x16 { mag: 9815, sign: false }); + data.append(FP16x16 { mag: 10500, sign: false }); + data.append(FP16x16 { mag: 46201, sign: false }); + data.append(FP16x16 { mag: 51565, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_alpha/input_1.cairo b/tests/src/nodes/gemm_alpha/input_1.cairo new file mode 100644 index 000000000..d170ea9c4 --- /dev/null +++ b/tests/src/nodes/gemm_alpha/input_1.cairo @@ -0,0 +1,34 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_1() -> Tensor { + 
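// operand B for the alpha test: a 5x4 matrix of random FP16x16 values generated by nodegen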
let mut shape = ArrayTrait::::new(); + shape.append(5); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7870, sign: false }); + data.append(FP16x16 { mag: 11258, sign: false }); + data.append(FP16x16 { mag: 34213, sign: false }); + data.append(FP16x16 { mag: 31148, sign: false }); + data.append(FP16x16 { mag: 29977, sign: false }); + data.append(FP16x16 { mag: 56430, sign: false }); + data.append(FP16x16 { mag: 43116, sign: false }); + data.append(FP16x16 { mag: 22990, sign: false }); + data.append(FP16x16 { mag: 3089, sign: false }); + data.append(FP16x16 { mag: 47936, sign: false }); + data.append(FP16x16 { mag: 13186, sign: false }); + data.append(FP16x16 { mag: 14386, sign: false }); + data.append(FP16x16 { mag: 63802, sign: false }); + data.append(FP16x16 { mag: 19313, sign: false }); + data.append(FP16x16 { mag: 40436, sign: false }); + data.append(FP16x16 { mag: 31890, sign: false }); + data.append(FP16x16 { mag: 34370, sign: false }); + data.append(FP16x16 { mag: 8853, sign: false }); + data.append(FP16x16 { mag: 59520, sign: false }); + data.append(FP16x16 { mag: 40977, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_alpha/output_0.cairo b/tests/src/nodes/gemm_alpha/output_0.cairo new file mode 100644 index 000000000..0ccff1b3a --- /dev/null +++ b/tests/src/nodes/gemm_alpha/output_0.cairo @@ -0,0 +1,26 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 48876, sign: false }); + data.append(FP16x16 { mag: 38402, sign: false }); + data.append(FP16x16 { mag: 61940, sign: false }); + data.append(FP16x16 { mag: 42709, sign: false }); + data.append(FP16x16 { mag: 27783, sign: false }); + data.append(FP16x16 { mag: 40822, sign: false }); + data.append(FP16x16 { mag: 54034, sign: false }); + data.append(FP16x16 { mag: 39483, sign: false }); + data.append(FP16x16 { mag: 40179, sign: false }); + data.append(FP16x16 { mag: 20753, sign: false }); + data.append(FP16x16 { mag: 49237, sign: false }); + data.append(FP16x16 { mag: 36865, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_beta.cairo b/tests/src/nodes/gemm_beta.cairo new file mode 100644 index 000000000..e12b6a0fc --- /dev/null +++ b/tests/src/nodes/gemm_beta.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP16x16NN; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gemm_beta() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z = output_0::output_0(); + + let y = NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::None(()), Option::Some(FixedTrait::new(32768, false)), false, false); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_beta/input_0.cairo b/tests/src/nodes/gemm_beta/input_0.cairo new file mode 100644 index 000000000..c90783e32 --- /dev/null +++ b/tests/src/nodes/gemm_beta/input_0.cairo @@ 
-0,0 +1,28 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(7); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 14966, sign: false }); + data.append(FP16x16 { mag: 36896, sign: false }); + data.append(FP16x16 { mag: 4679, sign: false }); + data.append(FP16x16 { mag: 36625, sign: false }); + data.append(FP16x16 { mag: 48874, sign: false }); + data.append(FP16x16 { mag: 35563, sign: false }); + data.append(FP16x16 { mag: 40736, sign: false }); + data.append(FP16x16 { mag: 12321, sign: false }); + data.append(FP16x16 { mag: 42458, sign: false }); + data.append(FP16x16 { mag: 65341, sign: false }); + data.append(FP16x16 { mag: 43716, sign: false }); + data.append(FP16x16 { mag: 43328, sign: false }); + data.append(FP16x16 { mag: 7074, sign: false }); + data.append(FP16x16 { mag: 45946, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_beta/input_1.cairo b/tests/src/nodes/gemm_beta/input_1.cairo new file mode 100644 index 000000000..adef0b2c0 --- /dev/null +++ b/tests/src/nodes/gemm_beta/input_1.cairo @@ -0,0 +1,42 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::FixedTrait; +use orion::numbers::FP16x16; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(7); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 51391, sign: false }); + data.append(FP16x16 { mag: 22014, sign: false }); + data.append(FP16x16 { mag: 33442, sign: false }); + data.append(FP16x16 { mag: 24116, sign: false }); + data.append(FP16x16 { mag: 49410, sign: false }); + data.append(FP16x16 { mag: 60215, sign: false }); + data.append(FP16x16 { mag: 9310, sign: false }); + data.append(FP16x16 { mag: 20950, sign: false }); + data.append(FP16x16 { mag: 20541, sign: false }); + data.append(FP16x16 { mag: 21583, sign: false }); + data.append(FP16x16 { mag: 28565, sign: false }); + data.append(FP16x16 { mag: 41677, sign: false }); + data.append(FP16x16 { mag: 18308, sign: false }); + data.append(FP16x16 { mag: 25095, sign: false }); + data.append(FP16x16 { mag: 44238, sign: false }); + data.append(FP16x16 { mag: 27465, sign: false }); + data.append(FP16x16 { mag: 30581, sign: false }); + data.append(FP16x16 { mag: 41045, sign: false }); + data.append(FP16x16 { mag: 46018, sign: false }); + data.append(FP16x16 { mag: 17358, sign: false }); + data.append(FP16x16 { mag: 50102, sign: false }); + data.append(FP16x16 { mag: 16577, sign: false }); + data.append(FP16x16 { mag: 16374, sign: false }); + data.append(FP16x16 { mag: 54251, sign: false }); + data.append(FP16x16 { mag: 46337, sign: false }); + data.append(FP16x16 { mag: 15187, sign: false }); + data.append(FP16x16 { mag: 25652, sign: false }); + data.append(FP16x16 { mag: 20892, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/tests/src/nodes/gemm_beta/input_2.cairo b/tests/src/nodes/gemm_beta/input_2.cairo new file mode 100644 index 000000000..1a31b973d --- /dev/null +++ b/tests/src/nodes/gemm_beta/input_2.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use 
orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_2() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(1);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 42564, sign: false });
+    data.append(FP16x16 { mag: 18018, sign: false });
+    data.append(FP16x16 { mag: 28175, sign: false });
+    data.append(FP16x16 { mag: 36784, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_beta/output_0.cairo b/tests/src/nodes/gemm_beta/output_0.cairo
new file mode 100644
index 000000000..c8270eb8e
--- /dev/null
+++ b/tests/src/nodes/gemm_beta/output_0.cairo
@@ -0,0 +1,22 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn output_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 151333, sign: false });
+    data.append(FP16x16 { mag: 112550, sign: false });
+    data.append(FP16x16 { mag: 112879, sign: false });
+    data.append(FP16x16 { mag: 109391, sign: false });
+    data.append(FP16x16 { mag: 153762, sign: false });
+    data.append(FP16x16 { mag: 129994, sign: false });
+    data.append(FP16x16 { mag: 134574, sign: false });
+    data.append(FP16x16 { mag: 128355, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_matrix_bias.cairo b/tests/src/nodes/gemm_default_matrix_bias.cairo
new file mode 100644
index 000000000..f2172e0e8
--- /dev/null
+++ b/tests/src/nodes/gemm_default_matrix_bias.cairo
@@ -0,0 +1,24 @@
+mod input_0;
+mod input_1;
+mod input_2;
+mod output_0;
+
+
+use orion::operators::nn::NNTrait;
+use orion::numbers::FixedTrait;
+use orion::operators::nn::FP16x16NN;
+use orion::operators::tensor::FP16x16TensorPartialEq;
+use orion::utils::assert_eq;
+
+#[test]
+#[available_gas(2000000000)]
+fn test_gemm_default_matrix_bias() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let input_2 = input_2::input_2();
+    let z = output_0::output_0();
+
+    let y = NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::None(()), Option::None(()), false, false);
+
+    assert_eq(y, z);
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_matrix_bias/input_0.cairo b/tests/src/nodes/gemm_default_matrix_bias/input_0.cairo
new file mode 100644
index 000000000..7c87dd7f2
--- /dev/null
+++ b/tests/src/nodes/gemm_default_matrix_bias/input_0.cairo
@@ -0,0 +1,32 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(6);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 54171, sign: false });
+    data.append(FP16x16 { mag: 576, sign: false });
+    data.append(FP16x16 { mag: 51387, sign: false });
+    data.append(FP16x16 { mag: 37774, sign: false });
+    data.append(FP16x16 { mag: 47415, sign: false });
+    data.append(FP16x16 { mag: 30278, sign: false });
+    data.append(FP16x16 { mag: 35329, sign: false });
+    data.append(FP16x16 { mag: 56770, sign: false });
+    data.append(FP16x16 { mag: 29001, sign: false });
+    data.append(FP16x16 { mag: 19387, sign: false });
+    data.append(FP16x16 { mag: 16747, sign: false });
+    data.append(FP16x16 { mag: 42410, sign: false });
+    data.append(FP16x16 { mag: 53192, sign: false });
+    data.append(FP16x16 { mag: 30490, sign: false });
+    data.append(FP16x16 { mag: 55512, sign: false });
+    data.append(FP16x16 { mag: 63983, sign: false });
+    data.append(FP16x16 { mag: 45579, sign: false });
+    data.append(FP16x16 { mag: 12475, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_matrix_bias/input_1.cairo b/tests/src/nodes/gemm_default_matrix_bias/input_1.cairo
new file mode 100644
index 000000000..5aee549de
--- /dev/null
+++ b/tests/src/nodes/gemm_default_matrix_bias/input_1.cairo
@@ -0,0 +1,38 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_1() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(6);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 47776, sign: false });
+    data.append(FP16x16 { mag: 8702, sign: false });
+    data.append(FP16x16 { mag: 39764, sign: false });
+    data.append(FP16x16 { mag: 21672, sign: false });
+    data.append(FP16x16 { mag: 34121, sign: false });
+    data.append(FP16x16 { mag: 60787, sign: false });
+    data.append(FP16x16 { mag: 50462, sign: false });
+    data.append(FP16x16 { mag: 61510, sign: false });
+    data.append(FP16x16 { mag: 39048, sign: false });
+    data.append(FP16x16 { mag: 32834, sign: false });
+    data.append(FP16x16 { mag: 57152, sign: false });
+    data.append(FP16x16 { mag: 4001, sign: false });
+    data.append(FP16x16 { mag: 37122, sign: false });
+    data.append(FP16x16 { mag: 45910, sign: false });
+    data.append(FP16x16 { mag: 22021, sign: false });
+    data.append(FP16x16 { mag: 10298, sign: false });
+    data.append(FP16x16 { mag: 33089, sign: false });
+    data.append(FP16x16 { mag: 35378, sign: false });
+    data.append(FP16x16 { mag: 1834, sign: false });
+    data.append(FP16x16 { mag: 22627, sign: false });
+    data.append(FP16x16 { mag: 37576, sign: false });
+    data.append(FP16x16 { mag: 57351, sign: false });
+    data.append(FP16x16 { mag: 22814, sign: false });
+    data.append(FP16x16 { mag: 60423, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_matrix_bias/input_2.cairo b/tests/src/nodes/gemm_default_matrix_bias/input_2.cairo
new file mode 100644
index 000000000..59be7972a
--- /dev/null
+++ b/tests/src/nodes/gemm_default_matrix_bias/input_2.cairo
@@ -0,0 +1,26 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_2() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 17180, sign: false });
+    data.append(FP16x16 { mag: 16229, sign: false });
+    data.append(FP16x16 { mag: 15872, sign: false });
+    data.append(FP16x16 { mag: 41908, sign: false });
+    data.append(FP16x16 { mag: 19343, sign: false });
+    data.append(FP16x16 { mag: 23171, sign: false });
+    data.append(FP16x16 { mag: 40127, sign: false });
+    data.append(FP16x16 { mag: 21249, sign: false });
+    data.append(FP16x16 { mag: 14046, sign: false });
+    data.append(FP16x16 { mag: 17154, sign: false });
+    data.append(FP16x16 { mag: 31592, sign: false });
+    data.append(FP16x16 { mag: 38555, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_matrix_bias/output_0.cairo b/tests/src/nodes/gemm_default_matrix_bias/output_0.cairo
new file mode 100644
index 000000000..4fa3d81b8
--- /dev/null
+++ b/tests/src/nodes/gemm_default_matrix_bias/output_0.cairo
@@ -0,0 +1,26 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn output_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 150288, sign: false });
+    data.append(FP16x16 { mag: 128258, sign: false });
+    data.append(FP16x16 { mag: 118559, sign: false });
+    data.append(FP16x16 { mag: 113724, sign: false });
+    data.append(FP16x16 { mag: 135691, sign: false });
+    data.append(FP16x16 { mag: 154786, sign: false });
+    data.append(FP16x16 { mag: 152316, sign: false });
+    data.append(FP16x16 { mag: 135919, sign: false });
+    data.append(FP16x16 { mag: 168184, sign: false });
+    data.append(FP16x16 { mag: 160656, sign: false });
+    data.append(FP16x16 { mag: 162874, sign: false });
+    data.append(FP16x16 { mag: 125446, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_no_bias.cairo b/tests/src/nodes/gemm_default_no_bias.cairo
new file mode 100644
index 000000000..747ddc2c9
--- /dev/null
+++ b/tests/src/nodes/gemm_default_no_bias.cairo
@@ -0,0 +1,22 @@
+mod input_0;
+mod input_1;
+mod output_0;
+
+
+use orion::operators::nn::NNTrait;
+use orion::numbers::FixedTrait;
+use orion::operators::nn::FP16x16NN;
+use orion::operators::tensor::FP16x16TensorPartialEq;
+use orion::utils::assert_eq;
+
+#[test]
+#[available_gas(2000000000)]
+fn test_gemm_default_no_bias() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let z = output_0::output_0();
+
+    let y = NNTrait::gemm(input_0, input_1, Option::None(()), Option::None(()), Option::None(()), false, false);
+
+    assert_eq(y, z);
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_no_bias/input_0.cairo b/tests/src/nodes/gemm_default_no_bias/input_0.cairo
new file mode 100644
index 000000000..46bc3b52f
--- /dev/null
+++ b/tests/src/nodes/gemm_default_no_bias/input_0.cairo
@@ -0,0 +1,29 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(5);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 48671, sign: false });
+    data.append(FP16x16 { mag: 53291, sign: false });
+    data.append(FP16x16 { mag: 61962, sign: false });
+    data.append(FP16x16 { mag: 23548, sign: false });
+    data.append(FP16x16 { mag: 12042, sign: false });
+    data.append(FP16x16 { mag: 198, sign: false });
+    data.append(FP16x16 { mag: 26605, sign: false });
+    data.append(FP16x16 { mag: 42749, sign: false });
+    data.append(FP16x16 { mag: 42426, sign: false });
+    data.append(FP16x16 { mag: 16917, sign: false });
+    data.append(FP16x16 { mag: 50488, sign: false });
+    data.append(FP16x16 { mag: 10785, sign: false });
+    data.append(FP16x16 { mag: 63703, sign: false });
+    data.append(FP16x16 { mag: 16964, sign: false });
+    data.append(FP16x16 { mag: 24102, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_no_bias/input_1.cairo b/tests/src/nodes/gemm_default_no_bias/input_1.cairo
new file mode 100644
index 000000000..9c0d880c9
--- /dev/null
+++ b/tests/src/nodes/gemm_default_no_bias/input_1.cairo
@@ -0,0 +1,34 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_1() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(5);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 50500, sign: false });
+    data.append(FP16x16 { mag: 17886, sign: false });
+    data.append(FP16x16 { mag: 46985, sign: false });
+    data.append(FP16x16 { mag: 55588, sign: false });
+    data.append(FP16x16 { mag: 13076, sign: false });
+    data.append(FP16x16 { mag: 60436, sign: false });
+    data.append(FP16x16 { mag: 39821, sign: false });
+    data.append(FP16x16 { mag: 26415, sign: false });
+    data.append(FP16x16 { mag: 21305, sign: false });
+    data.append(FP16x16 { mag: 14320, sign: false });
+    data.append(FP16x16 { mag: 28448, sign: false });
+    data.append(FP16x16 { mag: 25828, sign: false });
+    data.append(FP16x16 { mag: 47472, sign: false });
+    data.append(FP16x16 { mag: 52266, sign: false });
+    data.append(FP16x16 { mag: 7390, sign: false });
+    data.append(FP16x16 { mag: 56380, sign: false });
+    data.append(FP16x16 { mag: 13296, sign: false });
+    data.append(FP16x16 { mag: 59748, sign: false });
+    data.append(FP16x16 { mag: 8798, sign: false });
+    data.append(FP16x16 { mag: 32105, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_no_bias/output_0.cairo b/tests/src/nodes/gemm_default_no_bias/output_0.cairo
new file mode 100644
index 000000000..c01e5ab6c
--- /dev/null
+++ b/tests/src/nodes/gemm_default_no_bias/output_0.cairo
@@ -0,0 +1,26 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn output_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 87783, sign: false });
+    data.append(FP16x16 { mag: 105727, sign: false });
+    data.append(FP16x16 { mag: 98446, sign: false });
+    data.append(FP16x16 { mag: 113342, sign: false });
+    data.append(FP16x16 { mag: 53524, sign: false });
+    data.append(FP16x16 { mag: 83190, sign: false });
+    data.append(FP16x16 { mag: 41921, sign: false });
+    data.append(FP16x16 { mag: 72526, sign: false });
+    data.append(FP16x16 { mag: 78945, sign: false });
+    data.append(FP16x16 { mag: 73149, sign: false });
+    data.append(FP16x16 { mag: 75553, sign: false });
+    data.append(FP16x16 { mag: 98680, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_vector_bias.cairo b/tests/src/nodes/gemm_default_vector_bias.cairo
new file mode 100644
index 000000000..5892179d9
--- /dev/null
+++ b/tests/src/nodes/gemm_default_vector_bias.cairo
@@ -0,0 +1,24 @@
+mod input_0;
+mod input_1;
+mod input_2;
+mod output_0;
+
+
+use orion::operators::nn::NNTrait;
+use orion::numbers::FixedTrait;
+use orion::operators::nn::FP16x16NN;
+use orion::operators::tensor::FP16x16TensorPartialEq;
+use orion::utils::assert_eq;
+
+#[test]
+#[available_gas(2000000000)]
+fn test_gemm_default_vector_bias() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let input_2 = input_2::input_2();
+    let z = output_0::output_0();
+
+    let y = NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::None(()), Option::None(()), false, false);
+
+    assert_eq(y, z);
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_vector_bias/input_0.cairo b/tests/src/nodes/gemm_default_vector_bias/input_0.cairo
new file mode 100644
index 000000000..604574282
--- /dev/null
+++ b/tests/src/nodes/gemm_default_vector_bias/input_0.cairo
@@ -0,0 +1,28 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(7);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 42416, sign: false });
+    data.append(FP16x16 { mag: 877, sign: false });
+    data.append(FP16x16 { mag: 21463, sign: false });
+    data.append(FP16x16 { mag: 55531, sign: false });
+    data.append(FP16x16 { mag: 62444, sign: false });
+    data.append(FP16x16 { mag: 30762, sign: false });
+    data.append(FP16x16 { mag: 15704, sign: false });
+    data.append(FP16x16 { mag: 36007, sign: false });
+    data.append(FP16x16 { mag: 18900, sign: false });
+    data.append(FP16x16 { mag: 3784, sign: false });
+    data.append(FP16x16 { mag: 356, sign: false });
+    data.append(FP16x16 { mag: 51406, sign: false });
+    data.append(FP16x16 { mag: 57856, sign: false });
+    data.append(FP16x16 { mag: 27283, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_vector_bias/input_1.cairo b/tests/src/nodes/gemm_default_vector_bias/input_1.cairo
new file mode 100644
index 000000000..835e8631a
--- /dev/null
+++ b/tests/src/nodes/gemm_default_vector_bias/input_1.cairo
@@ -0,0 +1,42 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_1() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(7);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 40337, sign: false });
+    data.append(FP16x16 { mag: 29183, sign: false });
+    data.append(FP16x16 { mag: 2662, sign: false });
+    data.append(FP16x16 { mag: 26364, sign: false });
+    data.append(FP16x16 { mag: 42934, sign: false });
+    data.append(FP16x16 { mag: 65150, sign: false });
+    data.append(FP16x16 { mag: 19395, sign: false });
+    data.append(FP16x16 { mag: 39868, sign: false });
+    data.append(FP16x16 { mag: 12023, sign: false });
+    data.append(FP16x16 { mag: 28456, sign: false });
+    data.append(FP16x16 { mag: 20310, sign: false });
+    data.append(FP16x16 { mag: 33530, sign: false });
+    data.append(FP16x16 { mag: 15549, sign: false });
+    data.append(FP16x16 { mag: 37265, sign: false });
+    data.append(FP16x16 { mag: 64596, sign: false });
+    data.append(FP16x16 { mag: 58778, sign: false });
+    data.append(FP16x16 { mag: 41122, sign: false });
+    data.append(FP16x16 { mag: 29826, sign: false });
+    data.append(FP16x16 { mag: 43424, sign: false });
+    data.append(FP16x16 { mag: 47301, sign: false });
+    data.append(FP16x16 { mag: 5420, sign: false });
+    data.append(FP16x16 { mag: 54233, sign: false });
+    data.append(FP16x16 { mag: 28313, sign: false });
+    data.append(FP16x16 { mag: 12356, sign: false });
+    data.append(FP16x16 { mag: 54540, sign: false });
+    data.append(FP16x16 { mag: 42851, sign: false });
+    data.append(FP16x16 { mag: 28457, sign: false });
+    data.append(FP16x16 { mag: 16731, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_vector_bias/input_2.cairo b/tests/src/nodes/gemm_default_vector_bias/input_2.cairo
new file mode 100644
index 000000000..bef8eccaf
--- /dev/null
+++ b/tests/src/nodes/gemm_default_vector_bias/input_2.cairo
@@ -0,0 +1,18 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_2() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(1);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 42958, sign: false });
+    data.append(FP16x16 { mag: 57252, sign: false });
+    data.append(FP16x16 { mag: 44948, sign: false });
+    data.append(FP16x16 { mag: 18261, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_default_vector_bias/output_0.cairo b/tests/src/nodes/gemm_default_vector_bias/output_0.cairo
new file mode 100644
index 000000000..9dcd30c63
--- /dev/null
+++ b/tests/src/nodes/gemm_default_vector_bias/output_0.cairo
@@ -0,0 +1,22 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn output_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 141552, sign: false });
+    data.append(FP16x16 { mag: 182056, sign: false });
+    data.append(FP16x16 { mag: 169805, sign: false });
+    data.append(FP16x16 { mag: 151525, sign: false });
+    data.append(FP16x16 { mag: 138029, sign: false });
+    data.append(FP16x16 { mag: 183035, sign: false });
+    data.append(FP16x16 { mag: 124434, sign: false });
+    data.append(FP16x16 { mag: 101476, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeA.cairo b/tests/src/nodes/gemm_transposeA.cairo
new file mode 100644
index 000000000..6863e6327
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeA.cairo
@@ -0,0 +1,22 @@
+mod input_0;
+mod input_1;
+mod output_0;
+
+
+use orion::operators::nn::NNTrait;
+use orion::numbers::FixedTrait;
+use orion::operators::nn::FP16x16NN;
+use orion::operators::tensor::FP16x16TensorPartialEq;
+use orion::utils::assert_eq;
+
+#[test]
+#[available_gas(2000000000)]
+fn test_gemm_transposeA() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let z = output_0::output_0();
+
+    let y = NNTrait::gemm(input_0, input_1, Option::None(()), Option::None(()), Option::None(()), true, false);
+
+    assert_eq(y, z);
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeA/input_0.cairo b/tests/src/nodes/gemm_transposeA/input_0.cairo
new file mode 100644
index 000000000..6bfd7ea69
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeA/input_0.cairo
@@ -0,0 +1,32 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(6);
+    shape.append(3);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 60981, sign: false });
+    data.append(FP16x16 { mag: 58843, sign: false });
+    data.append(FP16x16 { mag: 19404, sign: false });
+    data.append(FP16x16 { mag: 56768, sign: false });
+    data.append(FP16x16 { mag: 2442, sign: false });
+    data.append(FP16x16 { mag: 45529, sign: false });
+    data.append(FP16x16 { mag: 1800, sign: false });
+    data.append(FP16x16 { mag: 38751, sign: false });
+    data.append(FP16x16 { mag: 29332, sign: false });
+    data.append(FP16x16 { mag: 17874, sign: false });
+    data.append(FP16x16 { mag: 39405, sign: false });
+    data.append(FP16x16 { mag: 7286, sign: false });
+    data.append(FP16x16 { mag: 23687, sign: false });
+    data.append(FP16x16 { mag: 7092, sign: false });
+    data.append(FP16x16 { mag: 20015, sign: false });
+    data.append(FP16x16 { mag: 26356, sign: false });
+    data.append(FP16x16 { mag: 49636, sign: false });
+    data.append(FP16x16 { mag: 54933, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeA/input_1.cairo b/tests/src/nodes/gemm_transposeA/input_1.cairo
new file mode 100644
index 000000000..73cc1030c
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeA/input_1.cairo
@@ -0,0 +1,38 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_1() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(6);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 14140, sign: false });
+    data.append(FP16x16 { mag: 4427, sign: false });
+    data.append(FP16x16 { mag: 61152, sign: false });
+    data.append(FP16x16 { mag: 63478, sign: false });
+    data.append(FP16x16 { mag: 46719, sign: false });
+    data.append(FP16x16 { mag: 40516, sign: false });
+    data.append(FP16x16 { mag: 5299, sign: false });
+    data.append(FP16x16 { mag: 27500, sign: false });
+    data.append(FP16x16 { mag: 22968, sign: false });
+    data.append(FP16x16 { mag: 16628, sign: false });
+    data.append(FP16x16 { mag: 14772, sign: false });
+    data.append(FP16x16 { mag: 37261, sign: false });
+    data.append(FP16x16 { mag: 62578, sign: false });
+    data.append(FP16x16 { mag: 20150, sign: false });
+    data.append(FP16x16 { mag: 38069, sign: false });
+    data.append(FP16x16 { mag: 9702, sign: false });
+    data.append(FP16x16 { mag: 12410, sign: false });
+    data.append(FP16x16 { mag: 30336, sign: false });
+    data.append(FP16x16 { mag: 65424, sign: false });
+    data.append(FP16x16 { mag: 37187, sign: false });
+    data.append(FP16x16 { mag: 28867, sign: false });
+    data.append(FP16x16 { mag: 1671, sign: false });
+    data.append(FP16x16 { mag: 57203, sign: false });
+    data.append(FP16x16 { mag: 17320, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeA/output_0.cairo b/tests/src/nodes/gemm_transposeA/output_0.cairo
new file mode 100644
index 000000000..ed2738343
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeA/output_0.cairo
@@ -0,0 +1,26 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn output_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 87420, sign: false });
+    data.append(FP16x16 { mag: 56805, sign: false });
+    data.append(FP16x16 { mag: 118934, sign: false });
+    data.append(FP16x16 { mag: 106965, sign: false });
+    data.append(FP16x16 { mag: 88852, sign: false });
+    data.append(FP16x16 { mag: 31983, sign: false });
+    data.append(FP16x16 { mag: 137135, sign: false });
+    data.append(FP16x16 { mag: 103030, sign: false });
+    data.append(FP16x16 { mag: 81868, sign: false });
+    data.append(FP16x16 { mag: 49808, sign: false });
+    data.append(FP16x16 { mag: 100563, sign: false });
+    data.append(FP16x16 { mag: 81533, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeB.cairo b/tests/src/nodes/gemm_transposeB.cairo
new file mode 100644
index 000000000..2146957cc
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeB.cairo
@@ -0,0 +1,22 @@
+mod input_0;
+mod input_1;
+mod output_0;
+
+
+use orion::operators::nn::NNTrait;
+use orion::numbers::FixedTrait;
+use orion::operators::nn::FP16x16NN;
+use orion::operators::tensor::FP16x16TensorPartialEq;
+use orion::utils::assert_eq;
+
+#[test]
+#[available_gas(2000000000)]
+fn test_gemm_transposeB() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let z = output_0::output_0();
+
+    let y = NNTrait::gemm(input_0, input_1, Option::None(()), Option::None(()), Option::None(()), false, true);
+
+    assert_eq(y, z);
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeB/input_0.cairo b/tests/src/nodes/gemm_transposeB/input_0.cairo
new file mode 100644
index 000000000..3d609ad0f
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeB/input_0.cairo
@@ -0,0 +1,32 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(6);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 44827, sign: false });
+    data.append(FP16x16 { mag: 53565, sign: false });
+    data.append(FP16x16 { mag: 1198, sign: false });
+    data.append(FP16x16 { mag: 31917, sign: false });
+    data.append(FP16x16 { mag: 38005, sign: false });
+    data.append(FP16x16 { mag: 22276, sign: false });
+    data.append(FP16x16 { mag: 34928, sign: false });
+    data.append(FP16x16 { mag: 28767, sign: false });
+    data.append(FP16x16 { mag: 18918, sign: false });
+    data.append(FP16x16 { mag: 40664, sign: false });
+    data.append(FP16x16 { mag: 26252, sign: false });
+    data.append(FP16x16 { mag: 26596, sign: false });
+    data.append(FP16x16 { mag: 8752, sign: false });
+    data.append(FP16x16 { mag: 15994, sign: false });
+    data.append(FP16x16 { mag: 483, sign: false });
+    data.append(FP16x16 { mag: 27831, sign: false });
+    data.append(FP16x16 { mag: 28378, sign: false });
+    data.append(FP16x16 { mag: 25924, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeB/input_1.cairo b/tests/src/nodes/gemm_transposeB/input_1.cairo
new file mode 100644
index 000000000..a726da8d5
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeB/input_1.cairo
@@ -0,0 +1,38 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn input_1() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(4);
+    shape.append(6);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 28356, sign: false });
+    data.append(FP16x16 { mag: 19354, sign: false });
+    data.append(FP16x16 { mag: 2749, sign: false });
+    data.append(FP16x16 { mag: 24026, sign: false });
+    data.append(FP16x16 { mag: 11157, sign: false });
+    data.append(FP16x16 { mag: 62112, sign: false });
+    data.append(FP16x16 { mag: 8802, sign: false });
+    data.append(FP16x16 { mag: 40701, sign: false });
+    data.append(FP16x16 { mag: 11492, sign: false });
+    data.append(FP16x16 { mag: 56717, sign: false });
+    data.append(FP16x16 { mag: 12174, sign: false });
+    data.append(FP16x16 { mag: 37607, sign: false });
+    data.append(FP16x16 { mag: 18568, sign: false });
+    data.append(FP16x16 { mag: 56759, sign: false });
+    data.append(FP16x16 { mag: 17097, sign: false });
+    data.append(FP16x16 { mag: 39335, sign: false });
+    data.append(FP16x16 { mag: 50570, sign: false });
+    data.append(FP16x16 { mag: 54411, sign: false });
+    data.append(FP16x16 { mag: 25640, sign: false });
+    data.append(FP16x16 { mag: 55921, sign: false });
+    data.append(FP16x16 { mag: 37203, sign: false });
+    data.append(FP16x16 { mag: 10548, sign: false });
+    data.append(FP16x16 { mag: 48030, sign: false });
+    data.append(FP16x16 { mag: 37338, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file
diff --git a/tests/src/nodes/gemm_transposeB/output_0.cairo b/tests/src/nodes/gemm_transposeB/output_0.cairo
new file mode 100644
index 000000000..c543f1f7f
--- /dev/null
+++ b/tests/src/nodes/gemm_transposeB/output_0.cairo
@@ -0,0 +1,26 @@
+use array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::FP16x16Tensor;
+use orion::numbers::FixedTrait;
+use orion::numbers::FP16x16;
+
+fn output_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(3);
+    shape.append(4);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 74549, sign: false });
+    data.append(FP16x16 { mag: 86964, sign: false });
+    data.append(FP16x16 { mag: 126384, sign: false });
+    data.append(FP16x16 { mag: 109608, sign: false });
+    data.append(FP16x16 { mag: 68987, sign: false });
+    data.append(FP16x16 { mag: 81207, sign: false });
+    data.append(FP16x16 { mag: 106493, sign: false });
+    data.append(FP16x16 { mag: 89890, sign: false });
+    data.append(FP16x16 { mag: 48135, sign: false });
+    data.append(FP16x16 { mag: 55429, sign: false });
+    data.append(FP16x16 { mag: 76585, sign: false });
+    data.append(FP16x16 { mag: 57394, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
\ No newline at end of file

From 755556d7738faac7b79db57ef899831a6cba5391 Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Wed, 25 Oct 2023 15:14:25 +0300
Subject: [PATCH 77/78] doc

---
 docs/SUMMARY.md | 1 +
 docs/framework/compatibility.md | 1 +
 .../operators/neural-network/README.md | 3 +-
 .../operators/neural-network/nn.gemm.md | 67 +++++++++++++++++
 src/operators/nn/core.cairo | 71 ++++++++++++++++++-
 tests/src/nodes/gemm_all_attributes.cairo | 21 ++++--
 6 files changed, 155 insertions(+), 9 deletions(-)
 create mode 100644 docs/framework/operators/neural-network/nn.gemm.md

diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index f1543c5c6..98b1b414b 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -100,6 +100,7 @@
   * [nn.linear](framework/operators/neural-network/nn.linear.md)
   * [nn.hard\_sigmoid](framework/operators/neural-network/nn.hard\_sigmoid.md)
   * [nn.thresholded\_relu](framework/operators/neural-network/nn.thresholded_relu.md)
+  * [nn.gemm](framework/operators/neural-network/nn.gemm.md)
 * [Machine Learning](framework/operators/machine-learning/README.md)
   * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md)
     * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md)
diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md
index c2d49c2be..309892e1f 100644
--- a/docs/framework/compatibility.md
+++ b/docs/framework/compatibility.md
@@ -62,5 +62,6 @@
 | [Clip](operators/tensor/tensor.clip.md) | :white\_check\_mark: |
 | [Identity](operators/tensor/tensor.identity.md) | :white\_check\_mark: |
 | [And](operators/tensor/tensor.and.md) | :white\_check\_mark: |
+| [Gemm](operators/neural-network/nn.gemm.md) | :white\_check\_mark: |
 
 Current Operators support: **51/156 (33%)**
diff --git a/docs/framework/operators/neural-network/README.md b/docs/framework/operators/neural-network/README.md
index b6ec74fd0..cd1c92f8d 100644
--- a/docs/framework/operators/neural-network/README.md
+++ b/docs/framework/operators/neural-network/README.md
@@ -32,5 +32,6 @@ Orion supports currently these `NN` types.
 | [`nn.softplus`](nn.softplus.md) | Applies the Softplus function element-wise. |
 | [`nn.linear`](nn.linear.md) | Performs a linear transformation of the input tensor using the provided weights and bias. |
 | [`nn.hard_sigmoid`](nn.hard\_sigmoid.md) | Applies the Hard Sigmoid function to an n-dimensional input tensor. |
-| [`nn.thresholded_relu`](nn.thresholded\_relu.md) | performs the thresholded relu activation function element-wise. |
+| [`nn.thresholded_relu`](nn.thresholded\_relu.md) | Performs the thresholded relu activation function element-wise. |
+| [`nn.gemm`](nn.gemm.md) | Performs General Matrix multiplication. |
 
diff --git a/docs/framework/operators/neural-network/nn.gemm.md b/docs/framework/operators/neural-network/nn.gemm.md
new file mode 100644
index 000000000..b89d884fc
--- /dev/null
+++ b/docs/framework/operators/neural-network/nn.gemm.md
@@ -0,0 +1,67 @@
+# NNTrait::gemm
+
+```rust
+fn gemm(
+    A: Tensor<T>,
+    B: Tensor<T>,
+    C: Option<Tensor<T>>,
+    alpha: Option<T>,
+    beta: Option<T>,
+    transA: bool,
+    transB: bool
+) -> Tensor<T>;
+```
+
+Performs General Matrix multiplication: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
+
+* A' = transpose(A) if transA else A
+* B' = transpose(B) if transB else B
+
+Compute `Y = alpha * A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
+`A` will be transposed before doing the computation if attribute `transA` is `true`, same for `B` and `transB`.
+
+## Args
+
+* `A`(`Tensor<T>`) - Input tensor A. The shape of `A` should be (M, K) if `transA` is `false`, or (K, M) if `transA` is `true`.
+* `B`(`Tensor<T>`) - Input tensor B. The shape of `B` should be (K, N) if `transB` is `false`, or (N, K) if `transB` is `true`.
+* `C`(`Option<Tensor<T>>`) - Optional input tensor C. The shape of C should be unidirectional broadcastable to (M, N).
+* `alpha`(`Option<T>`) - Optional scalar multiplier for the product of input tensors `A * B`.
+* `beta`(`Option<T>`) - Optional scalar multiplier for input tensor `C`.
+* `transA`(`bool`) - Whether `A` should be transposed.
+* `transB`(`bool`) - Whether `B` should be transposed.
+
+## Returns
+
+A `Tensor<T>` of shape (M, N).
+
+## Examples
+
+```rust
+mod input_0;
+mod input_1;
+mod input_2;
+
+use orion::operators::nn::NNTrait;
+use orion::numbers::FixedTrait;
+use orion::operators::nn::FP16x16NN;
+use orion::operators::tensor::FP16x16TensorPartialEq;
+
+fn gemm_all_attributes_example() -> Tensor<FP16x16> {
+    let input_0 = input_0::input_0(); // shape [4;3]
+    let input_1 = input_1::input_1(); // shape [5;4]
+    let input_2 = input_2::input_2(); // shape [1;5]
+
+    let y = NNTrait::gemm(
+        input_0,
+        input_1,
+        Option::Some(input_2),
+        Option::Some(FixedTrait::new(16384, false)), // 0.25
+        Option::Some(FixedTrait::new(22938, false)), // 0.35
+        true,
+        true
+    );
+
+    return y;
+}
+>>> tensor of shape [3;5]
+```
diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo
index 060c30ee3..c0e391dd2 100644
--- a/src/operators/nn/core.cairo
+++ b/src/operators/nn/core.cairo
@@ -11,7 +11,8 @@ use orion::operators::tensor::core::Tensor;
 /// softplus - Applies the Softplus function element-wise.
 /// linear - Performs a linear transformation of the input tensor using the provided weights and bias.
 /// hard_sigmoid - Applies the Hard Sigmoid function to an n-dimensional input tensor.
-/// thresholded_relu - performs the thresholded relu activation function element-wise.
+/// thresholded_relu - Performs the thresholded relu activation function element-wise.
+/// gemm - Performs General Matrix multiplication.
 trait NNTrait<T> {
     /// # NNTrait::relu
     ///
@@ -557,6 +558,74 @@ trait NNTrait<T> {
     /// ```
     ///
     fn thresholded_relu(tensor: @Tensor<T>, alpha: @T) -> Tensor<T>;
+    /// # NNTrait::gemm
+    ///
+    /// ```rust
+    /// fn gemm(
+    ///     A: Tensor<T>,
+    ///     B: Tensor<T>,
+    ///     C: Option<Tensor<T>>,
+    ///     alpha: Option<T>,
+    ///     beta: Option<T>,
+    ///     transA: bool,
+    ///     transB: bool
+    /// ) -> Tensor<T>;
+    /// ```
+    ///
+    /// Performs General Matrix multiplication: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
+    ///
+    /// * A' = transpose(A) if transA else A
+    /// * B' = transpose(B) if transB else B
+    ///
+    /// Compute `Y = alpha * A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
+    /// `A` will be transposed before doing the computation if attribute `transA` is `true`, same for `B` and `transB`.
+    ///
+    /// ## Args
+    ///
+    /// * `A`(`Tensor<T>`) - Input tensor A. The shape of `A` should be (M, K) if `transA` is `false`, or (K, M) if `transA` is `true`.
+    /// * `B`(`Tensor<T>`) - Input tensor B. The shape of `B` should be (K, N) if `transB` is `false`, or (N, K) if `transB` is `true`.
+    /// * `C`(`Option<Tensor<T>>`) - Optional input tensor C. The shape of C should be unidirectional broadcastable to (M, N).
+    /// * `alpha`(`Option<T>`) - Optional scalar multiplier for the product of input tensors `A * B`.
+    /// * `beta`(`Option<T>`) - Optional scalar multiplier for input tensor `C`.
+    /// * `transA`(`bool`) - Whether `A` should be transposed.
+    /// * `transB`(`bool`) - Whether `B` should be transposed.
+    ///
+    /// ## Returns
+    ///
+    /// A `Tensor<T>` of shape (M, N).
+    ///
+    /// ## Examples
+    ///
+    /// ```rust
+    /// mod input_0;
+    /// mod input_1;
+    /// mod input_2;
+    ///
+    /// use orion::operators::nn::NNTrait;
+    /// use orion::numbers::FixedTrait;
+    /// use orion::operators::nn::FP16x16NN;
+    /// use orion::operators::tensor::FP16x16TensorPartialEq;
+    ///
+    /// fn gemm_all_attributes_example() -> Tensor<FP16x16> {
+    ///     let input_0 = input_0::input_0(); // shape [4;3]
+    ///     let input_1 = input_1::input_1(); // shape [5;4]
+    ///     let input_2 = input_2::input_2(); // shape [1;5]
+    ///
+    ///     let y = NNTrait::gemm(
+    ///         input_0,
+    ///         input_1,
+    ///         Option::Some(input_2),
+    ///         Option::Some(FixedTrait::new(16384, false)), // 0.25
+    ///         Option::Some(FixedTrait::new(22938, false)), // 0.35
+    ///         true,
+    ///         true
+    ///     );
+    ///
+    ///     return y;
+    /// }
+    /// >>> tensor of shape [3;5]
+    /// ```
+    ///
     fn gemm(
         A: Tensor<T>,
         B: Tensor<T>,
diff --git a/tests/src/nodes/gemm_all_attributes.cairo b/tests/src/nodes/gemm_all_attributes.cairo
index 84dda3884..30949c209 100644
--- a/tests/src/nodes/gemm_all_attributes.cairo
+++ b/tests/src/nodes/gemm_all_attributes.cairo
@@ -1,8 +1,7 @@
-mod input_0; 
-mod input_1; 
-mod input_2; 
-mod output_0; 
-
+mod input_0;
+mod input_1;
+mod input_2;
+mod output_0;
 
 use orion::operators::nn::NNTrait;
 use orion::numbers::FixedTrait;
@@ -18,7 +17,15 @@ fn test_gemm_all_attributes() {
     let input_2 = input_2::input_2();
     let z = output_0::output_0();
 
-    let y = NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::Some(FixedTrait::new(16384, false)), Option::Some(FixedTrait::new(22938, false)), true, true);
+    let y = NNTrait::gemm(
+        input_0,
+        input_1,
+        Option::Some(input_2),
+        Option::Some(FixedTrait::new(16384, false)),
+        Option::Some(FixedTrait::new(22938, false)),
+        true,
+        true
+    );
 
     assert_eq(y, z);
-}
\ No newline at end of file
+}

From 8393df17afe8c2d794738eb666d6b245d0f4bdc6 Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Wed, 25 Oct 2023 15:45:41 +0300
Subject: [PATCH 78/78] Update compatibility.md

---
 docs/framework/compatibility.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md
index 309892e1f..91457e120 100644
--- a/docs/framework/compatibility.md
+++ b/docs/framework/compatibility.md
@@ -44,6 +44,7 @@
 | [HardSigmoid](operators/neural-network/nn.hard\_sigmoid.md) | :white\_check\_mark: |
 | [Sinh](operators/tensor/tensor.sinh.md) | :white\_check\_mark: |
 | [Asinh](operators/tensor/tensor.asinh.md) | :white\_check\_mark: |
+| [Atanh](operators/tensor/tensor.atanh.md) | :white\_check\_mark: |
 | [Cosh](operators/tensor/tensor.cosh.md) | :white\_check\_mark: |
 | [ACosh](operators/tensor/tensor.acosh.md) | :white\_check\_mark: |
 | [Tanh](operators/tensor/tensor.tanh.md) | :white\_check\_mark: |
@@ -62,6 +63,8 @@
 | [Clip](operators/tensor/tensor.clip.md) | :white\_check\_mark: |
 | [Identity](operators/tensor/tensor.identity.md) | :white\_check\_mark: |
 | [And](operators/tensor/tensor.and.md) | :white\_check\_mark: |
+| [Xor](operators/tensor/tensor.xor.md) | :white\_check\_mark: |
+| [Or](operators/tensor/tensor.or.md) | :white\_check\_mark: |
 | [Gemm](operators/neural-network/nn.gemm.md) | :white\_check\_mark: |
 
-Current Operators support: **51/156 (33%)**
+Current Operators support: **60/156 (38%)**
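
The generated fixtures above all instantiate the same template: build a shape array, fill a data array with raw `FP16x16` values, and wrap both in a tensor with `TensorTrait::new`. As a quick, hand-written illustration of the computation those tests exercise, the sketch below calls `NNTrait::gemm` on a 2x2 identity matrix. It is illustrative only and not part of the patch series; it assumes `FP16x16` is a 16.16 fixed-point type (so `mag: 65536` encodes 1.0, consistent with the `FixedTrait::new(16384, false) // 0.25` comments in the docs above) and that an omitted `alpha`/`beta` defaults to 1.0, as in the ONNX Gemm specification.

```rust
// Illustrative sketch, not part of the patch series.
// Assumption: FP16x16 is 16.16 fixed point, i.e. the encoded value is
// mag / 2^16, so mag: 65536 is 1.0 and mag: 32768 is 0.5.
use array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::operators::tensor::FP16x16Tensor;
use orion::operators::nn::NNTrait;
use orion::operators::nn::FP16x16NN;
use orion::numbers::FP16x16;

fn gemm_identity_sketch() -> Tensor<FP16x16> {
    // A = 2x2 identity matrix.
    let mut shape_a = ArrayTrait::<usize>::new();
    shape_a.append(2);
    shape_a.append(2);
    let mut data_a = ArrayTrait::new();
    data_a.append(FP16x16 { mag: 65536, sign: false }); // 1.0
    data_a.append(FP16x16 { mag: 0, sign: false });
    data_a.append(FP16x16 { mag: 0, sign: false });
    data_a.append(FP16x16 { mag: 65536, sign: false }); // 1.0
    let a = TensorTrait::new(shape_a.span(), data_a.span());

    // B = arbitrary 2x2 matrix.
    let mut shape_b = ArrayTrait::<usize>::new();
    shape_b.append(2);
    shape_b.append(2);
    let mut data_b = ArrayTrait::new();
    data_b.append(FP16x16 { mag: 13107, sign: false }); // ~0.2
    data_b.append(FP16x16 { mag: 32768, sign: false }); // 0.5
    data_b.append(FP16x16 { mag: 49152, sign: false }); // 0.75
    data_b.append(FP16x16 { mag: 6554, sign: false }); // ~0.1
    let b = TensorTrait::new(shape_b.span(), data_b.span());

    // No bias, no transposition, alpha/beta omitted (assumed 1.0 defaults):
    // Y = A * B, and with A = I the result should reproduce B exactly.
    NNTrait::gemm(a, b, Option::None(()), Option::None(()), Option::None(()), false, false)
}
```

Recovering a float from a fixture magnitude is a division by 65536; for example `22938 / 65536 ≈ 0.35`, which is where the `beta` comment in the `gemm_all_attributes` example comes from.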