Skip to content

Commit

Permalink
add randomness in best_split function
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Sep 28, 2023
1 parent c619d9f commit 3d751ed
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
25 changes: 18 additions & 7 deletions src/operators/ml/tree_regressor/core.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use cubit::f64::procgen::rand::u64_between;

use orion::numbers::{FixedTrait};

#[derive(Copy, Drop)]
Expand Down Expand Up @@ -197,6 +199,7 @@ fn best_split<
MAG,
impl FFixedTrait: FixedTrait<T, MAG>,
impl TPartialOrd: PartialOrd<T>,
impl TPartialEq: PartialEq<T>,
impl TAddEq: AddEq<T>,
impl TAdd: Add<T>,
impl TSub: Sub<T>,
Expand All @@ -211,8 +214,7 @@ fn best_split<
) -> (usize, T, T) {
let mut best_mse = FixedTrait::MAX();
let mut best_split_feature = 0;
let mut best_split_value = FixedTrait::ZERO();
let mut best_prediction = FixedTrait::ZERO();
let mut best_splits: Array<(usize, T, T)> = ArrayTrait::new();

let n_features: u32 = (*data[0]).len();

Expand Down Expand Up @@ -298,10 +300,11 @@ fn best_split<
let current_mse = (left_target_as_fp * mse(left_target.span(), left_pred))
+ (right_target_as_fp * mse(right_target.span(), right_pred));

if current_mse < best_mse {
best_mse = current_mse;
best_split_feature = feature;
best_split_value = *value;
if !(current_mse > best_mse) {
if current_mse < best_mse {
best_mse = current_mse;
best_splits = array![];
}

let mut total_sum = FixedTrait::ZERO();
let mut target_copy = target;
Expand All @@ -316,8 +319,10 @@ fn best_split<
};
};

best_prediction = total_sum
let prediction = total_sum
/ FixedTrait::new_unscaled(target.len().into(), false);

best_splits.append((feature, *value, prediction));
}
}
},
Expand All @@ -330,6 +335,11 @@ fn best_split<
feature += 1;
};

let random_idx: usize = u64_between(42, 0, best_splits.len().into()) // TODO: add seed
.try_into()
.unwrap();
let (best_split_feature, best_split_value, best_prediction) = *best_splits.at(random_idx);

(best_split_feature, best_split_value, best_prediction)
}

Expand All @@ -338,6 +348,7 @@ fn build_tree<
MAG,
impl FFixedTrait: FixedTrait<T, MAG>,
impl TPartialOrd: PartialOrd<T>,
impl TPartialEq: PartialEq<T>,
impl TAddEq: AddEq<T>,
impl TAdd: Add<T>,
impl TSub: Sub<T>,
Expand Down
2 changes: 1 addition & 1 deletion src/tests.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod numbers;
mod performance;
mod tensor_core;
mod nodes;
// mod nodes;
mod helpers;
mod ml;

0 comments on commit 3d751ed

Please sign in to comment.