From 3d751ed4578f2f39fd10d9e78560763369b37f9a Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 28 Sep 2023 15:52:01 +0300 Subject: [PATCH] add randomness in `best_split` function --- src/operators/ml/tree_regressor/core.cairo | 25 ++++++++++++++++------ src/tests.cairo | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/operators/ml/tree_regressor/core.cairo b/src/operators/ml/tree_regressor/core.cairo index e8311e39f..a0afba56a 100644 --- a/src/operators/ml/tree_regressor/core.cairo +++ b/src/operators/ml/tree_regressor/core.cairo @@ -1,3 +1,5 @@ +use cubit::f64::procgen::rand::u64_between; + use orion::numbers::{FixedTrait}; #[derive(Copy, Drop)] @@ -197,6 +199,7 @@ fn best_split< MAG, impl FFixedTrait: FixedTrait, impl TPartialOrd: PartialOrd, + impl TPartialEq: PartialEq, impl TAddEq: AddEq, impl TAdd: Add, impl TSub: Sub, @@ -211,8 +214,7 @@ fn best_split< ) -> (usize, T, T) { let mut best_mse = FixedTrait::MAX(); let mut best_split_feature = 0; - let mut best_split_value = FixedTrait::ZERO(); - let mut best_prediction = FixedTrait::ZERO(); + let mut best_splits: Array<(usize, T, T)> = ArrayTrait::new(); let n_features: u32 = (*data[0]).len(); @@ -298,10 +300,11 @@ fn best_split< let current_mse = (left_target_as_fp * mse(left_target.span(), left_pred)) + (right_target_as_fp * mse(right_target.span(), right_pred)); - if current_mse < best_mse { - best_mse = current_mse; - best_split_feature = feature; - best_split_value = *value; + if !(current_mse > best_mse) { + if current_mse < best_mse { + best_mse = current_mse; + best_splits = array![]; + } let mut total_sum = FixedTrait::ZERO(); let mut target_copy = target; @@ -316,8 +319,10 @@ fn best_split< }; }; - best_prediction = total_sum + let prediction = total_sum / FixedTrait::new_unscaled(target.len().into(), false); + + best_splits.append((feature, *value, prediction)); } } }, @@ -330,6 +335,11 @@ fn best_split< feature += 1; }; + let random_idx: usize = u64_between(42, 0, best_splits.len().into()) // TODO: add seed + .try_into() + .unwrap(); + let (best_split_feature, best_split_value, best_prediction) = *best_splits.at(random_idx); + (best_split_feature, best_split_value, best_prediction) } @@ -338,6 +348,7 @@ fn build_tree< MAG, impl FFixedTrait: FixedTrait, impl TPartialOrd: PartialOrd, + impl TPartialEq: PartialEq, impl TAddEq: AddEq, impl TAdd: Add, impl TSub: Sub, diff --git a/src/tests.cairo b/src/tests.cairo index dfaf9eda7..e1e2e893f 100644 --- a/src/tests.cairo +++ b/src/tests.cairo @@ -1,6 +1,6 @@ mod numbers; mod performance; mod tensor_core; -mod nodes; +// mod nodes; mod helpers; mod ml; \ No newline at end of file