-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "Merge branch 'main' into fix-doc-conflict"
- Loading branch information
1 parent
707d664
commit e10183f
Showing
5 changed files
with
131 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Machine Learning | ||
|
22 changes: 22 additions & 0 deletions
22
docs/framework/operators/machine-learning/tree-regressor/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Tree Regressor | ||
|
||
`TreeRegressorTrait` provides a trait definition for decision tree regression. This trait offers functionalities to build a decision tree and predict target values based on input features. | ||
|
||
```rust | ||
use orion::operators::ml::TreeRegressorTrait; | ||
``` | ||
|
||
### Data types | ||
|
||
Orion supports currently only fixed point data types for `TreeRegressorTrait`. | ||
|
||
| Data type | dtype | | ||
| -------------------- | ------------------------------------------------------------- | | ||
| Fixed point (signed) | `TreeRegressorTrait<FP8x23 \| FP16x16 \| FP64x64 \| FP32x32>` | | ||
|
||
*** | ||
|
||
| function | description | | ||
| --- | --- | | ||
| [`tree.fit`](tree.fit.md) | Constructs a decision tree regressor based on the provided data and target values. | | ||
| [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. | |
50 changes: 50 additions & 0 deletions
50
docs/framework/operators/machine-learning/tree-regressor/tree.fit.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# TreeRegressorTrait::fit | ||
|
||
```rust | ||
fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize, random_state: usize) -> TreeNode<T>; | ||
``` | ||
|
||
Builds a decision tree based on the provided data and target values up to a specified maximum depth. | ||
|
||
## Args | ||
|
||
* `data`: A span of spans representing rows of features in the dataset. | ||
* `target`: A span representing the target values corresponding to each row in the dataset. | ||
* `max_depth`: The maximum depth of the decision tree. The tree stops growing once this depth is reached. | ||
* `random_state`: It ensures that the tie-breaking is consistent across multiple runs, leading to reproducible results. | ||
|
||
## Returns | ||
|
||
A `TreeNode` representing the root of the constructed decision tree. | ||
|
||
## Type Constraints | ||
|
||
Constrain input and output types to fixed point tensors. | ||
|
||
## Examples | ||
|
||
```rust | ||
use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; | ||
use orion::numbers::{FP16x16, FixedTrait}; | ||
|
||
fn tree_regressor_example() { | ||
|
||
let data = array![ | ||
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), | ||
array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), | ||
array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), | ||
array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), | ||
] | ||
.span(); | ||
|
||
let target = array![ | ||
FixedTrait::new_unscaled(2, false), | ||
FixedTrait::new_unscaled(4, false), | ||
FixedTrait::new_unscaled(6, false), | ||
FixedTrait::new_unscaled(8, false) | ||
] | ||
.span(); | ||
|
||
TreeRegressorTrait::fit(data, target, 3, 42); | ||
} | ||
``` |
53 changes: 53 additions & 0 deletions
53
docs/framework/operators/machine-learning/tree-regressor/tree.predict.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# tree.predict | ||
|
||
```rust | ||
fn predict(ref self: TreeNode<T>, features: Span<T>) -> T; | ||
``` | ||
|
||
Predicts the target value for a set of features using the provided decision tree. | ||
|
||
## Args | ||
|
||
* `self`: A reference to the decision tree used for making the prediction. | ||
* `features`: A span representing the features for which the prediction is to be made. | ||
|
||
## Returns | ||
|
||
The predicted target value. | ||
|
||
## Type Constraints | ||
|
||
Constrain input and output types to fixed point tensors. | ||
|
||
## Examples | ||
|
||
```rust | ||
use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; | ||
use orion::numbers::{FP16x16, FixedTrait}; | ||
|
||
fn tree_regressor_example() { | ||
|
||
let data = array![ | ||
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), | ||
array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), | ||
array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), | ||
array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), | ||
] | ||
.span(); | ||
|
||
let target = array![ | ||
FixedTrait::new_unscaled(2, false), | ||
FixedTrait::new_unscaled(4, false), | ||
FixedTrait::new_unscaled(6, false), | ||
FixedTrait::new_unscaled(8, false) | ||
] | ||
.span(); | ||
|
||
let mut tree = TreeRegressorTrait::fit(data, target, 3); | ||
|
||
let prediction_1 = tree | ||
.predict( | ||
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() | ||
); | ||
} | ||
``` |