Skip to content

Commit

Permalink
Revert "Merge branch 'main' into fix-doc-conflict"
Browse files Browse the repository at this point in the history
This reverts commit 707d664, reversing
changes made to 020e102.
  • Loading branch information
raphaelDkhn committed Sep 29, 2023
1 parent 707d664 commit e10183f
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@
* [nn.softsign](framework/operators/neural-network/nn.softsign.md)
* [nn.softplus](framework/operators/neural-network/nn.softplus.md)
* [nn.linear](framework/operators/neural-network/nn.linear.md)
* [Machine Learning](framework/operators/machine-learning/README.md)
* [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md)
* [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md)
* [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md)

## 🏛 Hub

Expand Down
2 changes: 2 additions & 0 deletions docs/framework/operators/machine-learning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Machine Learning

22 changes: 22 additions & 0 deletions docs/framework/operators/machine-learning/tree-regressor/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Tree Regressor

`TreeRegressorTrait` provides a trait definition for decision tree regression. This trait offers functionalities to build a decision tree and predict target values based on input features.

```rust
use orion::operators::ml::TreeRegressorTrait;
```

### Data types

Orion supports currently only fixed point data types for `TreeRegressorTrait`.

| Data type | dtype |
| -------------------- | ------------------------------------------------------------- |
| Fixed point (signed) | `TreeRegressorTrait<FP8x23 \| FP16x16 \| FP64x64 \| FP32x32>` |

***

| function | description |
| --- | --- |
| [`tree.fit`](tree.fit.md) | Constructs a decision tree regressor based on the provided data and target values. |
| [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. |
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# TreeRegressorTrait::fit

```rust
fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize, random_state: usize) -> TreeNode<T>;
```

Builds a decision tree based on the provided data and target values up to a specified maximum depth.

## Args

* `data`: A span of spans representing rows of features in the dataset.
* `target`: A span representing the target values corresponding to each row in the dataset.
* `max_depth`: The maximum depth of the decision tree. The tree stops growing once this depth is reached.
* `random_state`: It ensures that the tie-breaking is consistent across multiple runs, leading to reproducible results.

## Returns

A `TreeNode` representing the root of the constructed decision tree.

## Type Constraints

Constrain input and output types to fixed point tensors.

## Examples

```rust
use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait};
use orion::numbers::{FP16x16, FixedTrait};

fn tree_regressor_example() {

let data = array![
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(),
array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(),
array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(),
array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(),
]
.span();

let target = array![
FixedTrait::new_unscaled(2, false),
FixedTrait::new_unscaled(4, false),
FixedTrait::new_unscaled(6, false),
FixedTrait::new_unscaled(8, false)
]
.span();

TreeRegressorTrait::fit(data, target, 3, 42);
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# tree.predict

```rust
fn predict(ref self: TreeNode<T>, features: Span<T>) -> T;
```

Predicts the target value for a set of features using the provided decision tree.

## Args

* `self`: A reference to the decision tree used for making the prediction.
* `features`: A span representing the features for which the prediction is to be made.

## Returns

The predicted target value.

## Type Constraints

Constrain input and output types to fixed point tensors.

## Examples

```rust
use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait};
use orion::numbers::{FP16x16, FixedTrait};

fn tree_regressor_example() {

let data = array![
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(),
array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(),
array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(),
array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(),
]
.span();

let target = array![
FixedTrait::new_unscaled(2, false),
FixedTrait::new_unscaled(4, false),
FixedTrait::new_unscaled(6, false),
FixedTrait::new_unscaled(8, false)
]
.span();

let mut tree = TreeRegressorTrait::fit(data, target, 3);

let prediction_1 = tree
.predict(
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span()
);
}
```

0 comments on commit e10183f

Please sign in to comment.