Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement generic TreeRegressor #217

Merged
merged 11 commits into from
Sep 29, 2023
8 changes: 8 additions & 0 deletions docgen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ fn main() {
let trait_name: &str = "IntegerTrait";
doc_trait(trait_path, doc_path, label);
doc_functions(trait_path, doc_path, trait_name, label);

// TREE REGRESSOR DOC
let trait_path = "src/operators/ml/tree_regressor/core.cairo";
let doc_path = "docs/framework/operators/machine-learning/tree-regressor";
let label = "tree";
let trait_name: &str = "TreeRegressorTrait";
doc_trait(trait_path, doc_path, label);
doc_functions(trait_path, doc_path, trait_name, label);
}

fn doc_trait(trait_path: &str, doc_path: &str, label: &str) {
Expand Down
5 changes: 5 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased] - 2023-09-27

## Add
- Implement `TreeRegressor` trait for decision tree regression.

## [Unreleased] - 2023-09-03

## Changed
Expand Down
4 changes: 4 additions & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@
* [nn.softsign](framework/operators/neural-network/nn.softsign.md)
* [nn.softplus](framework/operators/neural-network/nn.softplus.md)
* [nn.linear](framework/operators/neural-network/nn.linear.md)
* [Machine Learning](framework/operators/machine-learning/README.md)
* [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md)
* [tree.build](framework/operators/machine-learning/tree-regressor/tree.fit.md)
* [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md)

## 🏛 Hub

Expand Down
Empty file.
24 changes: 24 additions & 0 deletions docs/framework/operators/machine-learning/tree-regressor/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Tree Regressor

`TreeRegressorTrait` provides a trait definition for decision tree regression.
This trait offers functionalities to build a decision tree and predict target values based on input features.

```rust
use orion::operators::ml::TreeRegressorTrait;
```

### Data types

Orion supports currently only fixed point data types for `TreeRegressorTrait`.

| Data type | dtype |
| -------------------- | ------------------------------------------------------------- |
| Fixed point (signed) | `TreeRegressorTrait<FP8x23 \| FP16x16 \| FP64x64 \| FP32x32>` |

***

| function | description |
| --- | --- |
| [`tree.fit`](tree.fit\_tree.md) | Constructs a decision tree regressor based on the provided data and target values. |
| [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. |

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# TreeRegressorTrait::fit

```rust
fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize) -> TreeNode<T>;
```

Builds a decision tree based on the provided data and target values up to a specified maximum depth.

## Args

* `data`: A span of spans representing rows of features in the dataset.
* `target`: A span representing the target values corresponding to each row in the dataset.
* `max_depth`: The maximum depth of the decision tree. The tree stops growing once this depth is reached.

## Returns

A `TreeNode` representing the root of the constructed decision tree.

## Type Constraints

Constrain input and output types to fixed point tensors.

## Examples

```rust
use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait};
use orion::numbers::{FP16x16, FixedTrait};

fn tree_regressor_example() {

let data = array![
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(),
array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(),
array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(),
array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(),
]
.span();

let target = array![
FixedTrait::new_unscaled(2, false),
FixedTrait::new_unscaled(4, false),
FixedTrait::new_unscaled(6, false),
FixedTrait::new_unscaled(8, false)
]
.span();

TreeRegressorTrait::fit(data, target, 3);
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# TreeRegressorTrait::predict

```rust
fn predict(ref self: TreeNode<T>, features: Span<T>) -> T;
```

Predicts the target value for a set of features using the provided decision tree.

## Args

* `self`: A reference to the decision tree used for making the prediction.
* `features`: A span representing the features for which the prediction is to be made.

## Returns

The predicted target value.

## Type Constraints

Constrain input and output types to fixed point tensors.

## Examples

```rust
use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait};
use orion::numbers::{FP16x16, FixedTrait};

fn tree_regressor_example() {

let data = array![
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(),
array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(),
array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(),
array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(),
]
.span();

let target = array![
FixedTrait::new_unscaled(2, false),
FixedTrait::new_unscaled(4, false),
FixedTrait::new_unscaled(6, false),
FixedTrait::new_unscaled(8, false)
]
.span();

let mut tree = TreeRegressorTrait::fit(data, target, 3);

let prediction_1 = tree
.predict(
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span()
);
}
```
1 change: 1 addition & 0 deletions src/numbers/fixed_point/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1069,4 +1069,5 @@ trait FixedTrait<T, MAG> {

fn ZERO() -> T;
fn ONE() -> T;
fn MAX() -> T;
}
4 changes: 4 additions & 0 deletions src/numbers/fixed_point/implementations/fp16x16/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ impl FP16x16Impl of FixedTrait<FP16x16, u32> {
return FP16x16 { mag: ONE, sign: false };
}

fn MAX() -> FP16x16 {
return FP16x16 { mag: MAX, sign: false };
}

fn new(mag: u32, sign: bool) -> FP16x16 {
return FP16x16 { mag: mag, sign: sign };
}
Expand Down
4 changes: 4 additions & 0 deletions src/numbers/fixed_point/implementations/fp32x32/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ impl FP32x32Impl of FixedTrait<FP32x32, u64> {
return FP32x32 { mag: ONE, sign: false };
}

fn MAX() -> FP32x32 {
return FP32x32 { mag: MAX, sign: false };
}

fn new(mag: u64, sign: bool) -> FP32x32 {
return FP32x32 { mag: mag, sign: sign };
}
Expand Down
4 changes: 4 additions & 0 deletions src/numbers/fixed_point/implementations/fp64x64/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ impl FP64x64Impl of FixedTrait<FP64x64, u128> {
return FP64x64 { mag: ONE, sign: false };
}

fn MAX() -> FP64x64 {
return FP64x64 { mag: MAX, sign: false };
}

fn new(mag: u128, sign: bool) -> FP64x64 {
return FP64x64 { mag: mag, sign: sign };
}
Expand Down
4 changes: 4 additions & 0 deletions src/numbers/fixed_point/implementations/fp8x23/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ impl FP8x23Impl of FixedTrait<FP8x23, u32> {
return FP8x23 { mag: ONE, sign: false };
}

fn MAX() -> FP8x23 {
return FP8x23 { mag: MAX, sign: false };
}

fn new(mag: u32, sign: bool) -> FP8x23 {
return FP8x23 { mag: mag, sign: sign };
}
Expand Down
2 changes: 1 addition & 1 deletion src/operators.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mod tensor;
mod nn;

mod ml;
7 changes: 7 additions & 0 deletions src/operators/ml.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod tree_regressor;

use orion::operators::ml::tree_regressor::core::TreeRegressorTrait;
use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp16x16::FP16x16TreeRegressor;
use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp8x23::FP8x23TreeRegressor;
use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp32x32::FP32x32TreeRegressor;
use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp64x64::FP64x64TreeRegressor;
2 changes: 2 additions & 0 deletions src/operators/ml/tree_regressor.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod implementations;
mod core;
Loading