Skip to content

Commit

Permalink
replace build_tree with fit
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Sep 28, 2023
1 parent 3d751ed commit c66ad63
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
* [nn.linear](framework/operators/neural-network/nn.linear.md)
* [Machine Learning](framework/operators/machine-learning/README.md)
* [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md)
* [tree.build](framework/operators/machine-learning/tree-regressor/tree.build_tree.md)
* [tree.build](framework/operators/machine-learning/tree-regressor/tree.fit.md)
* [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md)

## 🏛 Hub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ Orion supports currently only fixed point data types for `TreeRegressorTrait`.

| function | description |
| --- | --- |
| [`tree.build_tree`](tree.build\_tree.md) | Constructs a decision tree regressor based on the provided data and target values. |
| [`tree.fit`](tree.fit\_tree.md) | Constructs a decision tree regressor based on the provided data and target values. |
| [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. |

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TreeRegressorTrait::build_tree
# TreeRegressorTrait::fit

```rust
fn build_tree(data: Span<Span<T>>, target: Span<T>, max_depth: usize) -> TreeNode<T>;
fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize) -> TreeNode<T>;
```

Builds a decision tree based on the provided data and target values up to a specified maximum depth.
Expand Down Expand Up @@ -44,6 +44,6 @@ fn tree_regressor_example() {
]
.span();

TreeRegressorTrait::build_tree(data, target, 3);
TreeRegressorTrait::fit(data, target, 3);
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ fn tree_regressor_example() {
]
.span();

let mut tree = TreeRegressorTrait::build_tree(data, target, 3);
let mut tree = TreeRegressorTrait::fit(data, target, 3);

let prediction_1 = tree
.predict(
Expand Down
18 changes: 9 additions & 9 deletions src/operators/ml/tree_regressor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ struct TreeNode<T> {

/// Trait
///
/// build_tree - Constructs a decision tree regressor based on the provided data and target values.
/// fit - Constructs a decision tree regressor based on the provided data and target values.
/// predict - Given a set of features, predicts the target value using the constructed decision tree.
trait TreeRegressorTrait<T> {
/// # TreeRegressorTrait::build_tree
/// # TreeRegressorTrait::fit
///
/// ```rust
/// fn build_tree(data: Span<Span<T>>, target: Span<T>, max_depth: usize) -> TreeNode<T>;
/// fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize) -> TreeNode<T>;
/// ```
///
/// Builds a decision tree based on the provided data and target values up to a specified maximum depth.
Expand Down Expand Up @@ -62,11 +62,11 @@ trait TreeRegressorTrait<T> {
/// ]
/// .span();
///
/// TreeRegressorTrait::build_tree(data, target, 3);
/// TreeRegressorTrait::fit(data, target, 3);
/// }
/// ```
///
fn build_tree(data: Span<Span<T>>, target: Span<T>, max_depth: usize) -> TreeNode<T>;
fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize) -> TreeNode<T>;
/// # TreeRegressorTrait::predict
///
/// ```rust
Expand Down Expand Up @@ -112,7 +112,7 @@ trait TreeRegressorTrait<T> {
/// ]
/// .span();
///
/// let mut tree = TreeRegressorTrait::build_tree(data, target, 3);
/// let mut tree = TreeRegressorTrait::fit(data, target, 3);
///
/// let prediction_1 = tree
/// .predict(
Expand Down Expand Up @@ -343,7 +343,7 @@ fn best_split<
(best_split_feature, best_split_value, best_prediction)
}

fn build_tree<
fn fit<
T,
MAG,
impl FFixedTrait: FixedTrait<T, MAG>,
Expand Down Expand Up @@ -412,10 +412,10 @@ fn build_tree<

TreeNode {
left: Option::Some(
BoxTrait::new(build_tree(left_data.span(), left_target.span(), depth + 1, max_depth))
BoxTrait::new(fit(left_data.span(), left_target.span(), depth + 1, max_depth))
),
right: Option::Some(
BoxTrait::new(build_tree(right_data.span(), right_target.span(), depth + 1, max_depth))
BoxTrait::new(fit(right_data.span(), right_target.span(), depth + 1, max_depth))
),
split_feature,
split_value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use orion::operators::ml::tree_regressor::core;
use orion::numbers::FP16x16;

impl FP16x16TreeRegressor of TreeRegressorTrait<FP16x16> {
fn build_tree(
fn fit(
data: Span<Span<FP16x16>>, target: Span<FP16x16>, max_depth: usize
) -> TreeNode<FP16x16> {
core::build_tree(data, target, 0, max_depth)
core::fit(data, target, 0, max_depth)
}

fn predict(ref self: TreeNode<FP16x16>, features: Span<FP16x16>) -> FP16x16 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use orion::operators::ml::tree_regressor::core;
use orion::numbers::{FP32x32, FP32x32Impl};

impl FP32x32TreeRegressor of TreeRegressorTrait<FP32x32> {
fn build_tree(
fn fit(
data: Span<Span<FP32x32>>, target: Span<FP32x32>, max_depth: usize
) -> TreeNode<FP32x32> {
core::build_tree(data, target, 0, max_depth)
core::fit(data, target, 0, max_depth)
}

fn predict(ref self: TreeNode<FP32x32>, features: Span<FP32x32>) -> FP32x32 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use orion::operators::ml::tree_regressor::core;
use orion::numbers::{FP64x64, FP64x64Impl};

impl FP64x64TreeRegressor of TreeRegressorTrait<FP64x64> {
fn build_tree(
fn fit(
data: Span<Span<FP64x64>>, target: Span<FP64x64>, max_depth: usize
) -> TreeNode<FP64x64> {
core::build_tree(data, target, 0, max_depth)
core::fit(data, target, 0, max_depth)
}

fn predict(ref self: TreeNode<FP64x64>, features: Span<FP64x64>) -> FP64x64 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use orion::operators::ml::tree_regressor::core;
use orion::numbers::FP8x23;

impl FP8x23TreeRegressor of TreeRegressorTrait<FP8x23> {
fn build_tree(
fn fit(
data: Span<Span<FP8x23>>, target: Span<FP8x23>, max_depth: usize
) -> TreeNode<FP8x23> {
core::build_tree(data, target, 0, max_depth)
core::fit(data, target, 0, max_depth)
}

fn predict(ref self: TreeNode<FP8x23>, features: Span<FP8x23>) -> FP8x23 {
Expand Down
2 changes: 1 addition & 1 deletion src/tests/ml/tree_regressor.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn test_tree() {
]
.span();

let mut tree = TreeRegressorTrait::build_tree(data, target, 3);
let mut tree = TreeRegressorTrait::fit(data, target, 3);

let prediction_1 = tree
.predict(
Expand Down

0 comments on commit c66ad63

Please sign in to comment.