Skip to content

Commit

Permalink
generate doc
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Oct 16, 2023
1 parent 4e8d686 commit 3de6907
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 6 deletions.
8 changes: 8 additions & 0 deletions docgen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ fn main() {
let trait_name: &str = "TreeRegressorTrait";
doc_trait(trait_path, doc_path, label);
doc_functions(trait_path, doc_path, trait_name, label);

// TREE ClASSIFIER DOC
let trait_path = "src/operators/ml/tree_classifier/core.cairo";
let doc_path = "docs/framework/operators/machine-learning/tree-classifier";
let label = "tree";
let trait_name: &str = "TreeClassifierTrait";
doc_trait(trait_path, doc_path, label);
doc_functions(trait_path, doc_path, trait_name, label);
}

fn doc_trait(trait_path: &str, doc_path: &str, label: &str) {
Expand Down
5 changes: 5 additions & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@
* [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md)
* [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md)
* [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md)
* [Tree Classifier](framework/operators/machine-learning/tree-classifier/README.md)
* [tree.predict](framework/operators/machine-learning/tree-classifier/tree.predict.md)
* [tree.predict_proba](framework/operators/machine-learning/tree-classifier/tree.predict_proba.md)



## 🏛 Hub

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Tree Classifier

`TreeClassifierTrait` provides a trait definition for decision tree classifier. This trait offers functionalities to build a decision tree and predict target values based on input features.

```rust
use orion::operators::ml::TreeClassifierTrait;
```

### Data types

Orion supports currently only fixed point data types for `TreeClassifierTrait`.

| Data type | dtype |
| -------------------- | ------------------------------------------------------------- |
| Fixed point (signed) | `TreeClassifierTrait<FP8x23 \| FP16x16 \| FP64x64 \| FP32x32>` |

***

| function | description |
| --- | --- |
| [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. |
| [`tree.predict_proba`](tree.predict\_proba.md) | Predicts class probabilities based on feature data. |

Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# TreeClassifierTrait::predict

```rust
fn predict(ref self: TreeClassifier<T>, features: Span<T>) -> T;
```

Predicts the target value for a set of features using the provided decision tree.

## Args

* `self`: A reference to the decision tree used for making the prediction.
* `features`: A span representing the features for which the prediction is to be made.

## Returns

The predicted target value.

## Type Constraints

Constrain input and output types to fixed point.

## Examples

```rust
use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier};
use orion::numbers::{FP16x16, FixedTrait};

fn tree_classifier_example(tree: TreeClassifier<FP16x16>) {

tree.predict(
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span()
);

}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# TreeClassifierTrait::predict_proba

```rust
fn predict_proba(ref self: TreeClassifier<T>, features: Span<T>) -> Span<T>;
```

Given a set of features, this method traverses the decision tree
represented by `self` and returns the class distribution (probabilities)
found in the leaf node that matches the provided features. The traversal
stops once a leaf node is reached in the decision tree.

## Args

* `self`: A reference to the decision tree used for making the prediction.
* `features`: A span representing the features for which the prediction is to be made.

## Returns

Returns a `Span<T>` representing the class distribution at the leaf node.

## Type Constraints

Constrain input and output types to fixed points.

## Examples

```rust
use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier};
use orion::numbers::{FP16x16, FixedTrait};

fn tree_classifier_example(tree: TreeClassifier<FP16x16>) {

tree.predict_proba(
array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span()
);

}
```
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TreeRegressorTrait::fit

```rust
fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize, random_state: usize) -> TreeNode<T>;
fn fit(data: Span<Span<T>>, target: Span<T>, max_depth: usize, random_state: usize) -> TreeRegressor<T>;
```

Builds a decision tree based on the provided data and target values up to a specified maximum depth.
Expand All @@ -15,7 +15,7 @@ Builds a decision tree based on the provided data and target values up to a spec

## Returns

A `TreeNode` representing the root of the constructed decision tree.
A `TreeRegressor` representing the root of the constructed decision tree.

## Type Constraints

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TreeRegressorTrait::predict

```rust
fn predict(ref self: TreeNode<T>, features: Span<T>) -> T;
fn predict(ref self: TreeRegressor<T>, features: Span<T>) -> T;
```

Predicts the target value for a set of features using the provided decision tree.
Expand All @@ -17,7 +17,7 @@ The predicted target value.

## Type Constraints

Constrain input and output types to fixed point tensors.
Constrain input and output types to fixed point.

## Examples

Expand Down
77 changes: 76 additions & 1 deletion src/operators/ml/tree_classifier/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,84 @@ struct TreeClassifier<T> {
/// Trait
///
/// predict - Given a set of features, predicts the target value using the constructed decision tree.
/// predict_proba - Given a set of features, predicts the probability of each X example being of a given class..
/// predict_proba - Predicts class probabilities based on feature data.
trait TreeClassifierTrait<T> {
/// # TreeClassifierTrait::predict
///
/// ```rust
/// fn predict(ref self: TreeClassifier<T>, features: Span<T>) -> T;
/// ```
///
/// Predicts the target value for a set of features using the provided decision tree.
///
/// ## Args
///
/// * `self`: A reference to the decision tree used for making the prediction.
/// * `features`: A span representing the features for which the prediction is to be made.
///
/// ## Returns
///
/// The predicted target value.
///
/// ## Type Constraints
///
/// Constrain input and output types to fixed point.
///
/// ## Examples
///
/// ```rust
/// use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier};
/// use orion::numbers::{FP16x16, FixedTrait};
///
/// fn tree_classifier_example(tree: TreeClassifier<FP16x16>) {
///
/// tree.predict(
/// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span()
/// );
///
/// }
/// ```
///
fn predict(ref self: TreeClassifier<T>, features: Span<T>) -> T;
/// # TreeClassifierTrait::predict_proba
///
/// ```rust
/// fn predict_proba(ref self: TreeClassifier<T>, features: Span<T>) -> Span<T>;
/// ```
///
/// Given a set of features, this method traverses the decision tree
/// represented by `self` and returns the class distribution (probabilities)
/// found in the leaf node that matches the provided features. The traversal
/// stops once a leaf node is reached in the decision tree.
///
/// ## Args
///
/// * `self`: A reference to the decision tree used for making the prediction.
/// * `features`: A span representing the features for which the prediction is to be made.
///
/// ## Returns
///
/// Returns a `Span<T>` representing the class distribution at the leaf node.
///
/// ## Type Constraints
///
/// Constrain input and output types to fixed points.
///
/// ## Examples
///
/// ```rust
/// use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier};
/// use orion::numbers::{FP16x16, FixedTrait};
///
/// fn tree_classifier_example(tree: TreeClassifier<FP16x16>) {
///
/// tree.predict_proba(
/// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span()
/// );
///
/// }
/// ```
///
fn predict_proba(ref self: TreeClassifier<T>, features: Span<T>) -> Span<T>;
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/ml/tree_regressor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ trait TreeRegressorTrait<T> {
///
/// ## Type Constraints
///
/// Constrain input and output types to fixed point tensors.
/// Constrain input and output types to fixed point.
///
/// ## Examples
///
Expand Down

0 comments on commit 3de6907

Please sign in to comment.