Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Implement Gemm #403

Merged
merged 6 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
doc
  • Loading branch information
raphaelDkhn committed Oct 25, 2023
commit 755556d7738faac7b79db57ef899831a6cba5391
1 change: 1 addition & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
* [nn.linear](framework/operators/neural-network/nn.linear.md)
* [nn.hard\_sigmoid](framework/operators/neural-network/nn.hard\_sigmoid.md)
* [nn.thresholded\_relu](framework/operators/neural-network/nn.thresholded_relu.md)
* [nn.gemm](framework/operators/neural-network/nn.gemm.md)
* [Machine Learning](framework/operators/machine-learning/README.md)
* [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md)
* [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md)
Expand Down
1 change: 1 addition & 0 deletions docs/framework/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,6 @@ You can see below the list of current supported ONNX Operators:
| [Clip](operators/tensor/tensor.clip.md) | :white\_check\_mark: |
| [Identity](operators/tensor/tensor.identity.md) | :white\_check\_mark: |
| [And](operators/tensor/tensor.and.md) | :white\_check\_mark: |
| [Gemm](operators/neural-network/nn.gemm.md) | :white\_check\_mark: |

Current Operators support: **51/156 (33%)**
3 changes: 2 additions & 1 deletion docs/framework/operators/neural-network/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ Orion supports currently these `NN` types.
| [`nn.softplus`](nn.softplus.md) | Applies the Softplus function element-wise. |
| [`nn.linear`](nn.linear.md) | Performs a linear transformation of the input tensor using the provided weights and bias. |
| [`nn.hard_sigmoid`](nn.hard\_sigmoid.md) | Applies the Hard Sigmoid function to an n-dimensional input tensor. |
| [`nn.thresholded_relu`](nn.thresholded\_relu.md) | performs the thresholded relu activation function element-wise. |
| [`nn.thresholded_relu`](nn.thresholded\_relu.md) | Performs the thresholded relu activation function element-wise. |
| [`nn.gemm`](nn.gemm.md) | Performs General Matrix multiplication. |

67 changes: 67 additions & 0 deletions docs/framework/operators/neural-network/nn.gemm.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# NNTrait::gemm

```rust
fn gemm(
A: Tensor<T>,
B: Tensor<T>,
C: Option<Tensor<T>>,
alpha: Option<T>,
beta: Option<T>,
transA: bool,
transB: bool
) -> Tensor<T>;
```

Performs General Matrix multiplication: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3

* A' = transpose(A) if transA else A
* B' = transpose(B) if transB else B

Compute `Y = alpha * A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
`A` will be transposed before doing the computation if attribute `transA` is `true`, same for `B` and `transB`.

## Args

* `A`(`Tensor<T>`) - Input tensor A. The shape of `A` should be (M, K) if `transA` is `false`, or (K, M) if `transA` is `true`.
* `B`(`Tensor<T>`) - Input tensor B. The shape of `B` should be (K, N) if `transB` is `false`, or (N, K) if `transB` is `true`.
* `C`(`Option<Tensor<T>>`) - Optional input tensor C. The shape of C should be unidirectional broadcastable to (M, N).
* `alpha`(`Option<T>`) - Optional scalar multiplier for the product of input tensors `A * B`.
* `beta`(`Option<T>`) - Optional scalar multiplier for input tensor `C`.
* `transA`(`bool`) - Whether `A` should be transposed.
* `transB`(`bool`) - Whether `B` should be transposed.

## Returns

A `Tensor<T>` of shape (M, N).

## Examples

```rust
mod input_0;
mod input_1;
mod input_2;

use orion::operators::nn::NNTrait;
use orion::numbers::FixedTrait;
use orion::operators::nn::FP16x16NN;
use orion::operators::tensor::FP16x16TensorPartialEq;

fn gemm_all_attributes_example() -> Tensor<FP16x16> {
let input_0 = input_0::input_0(); // shape [4;3]
let input_1 = input_1::input_1(); // shape [5;4]
let input_2 = input_2::input_2(); // shape [1;5]

let y = NNTrait::gemm(
input_0,
input_1,
Option::Some(input_2),
Option::Some(FixedTrait::new(16384, false)), // 0.25
Option::Some(FixedTrait::new(22938, false)), // 0.35
true,
true
);

return y;
}
>>> tensor of shape [3;5]
````
71 changes: 70 additions & 1 deletion src/operators/nn/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use orion::operators::tensor::core::Tensor;
/// softplus - Applies the Softplus function element-wise.
/// linear - Performs a linear transformation of the input tensor using the provided weights and bias.
/// hard_sigmoid - Applies the Hard Sigmoid function to an n-dimensional input tensor.
/// thresholded_relu - performs the thresholded relu activation function element-wise.
/// thresholded_relu - Performs the thresholded relu activation function element-wise.
/// gemm - Performs General Matrix multiplication.
trait NNTrait<T> {
/// # NNTrait::relu
///
Expand Down Expand Up @@ -557,6 +558,74 @@ trait NNTrait<T> {
/// ```
///
fn thresholded_relu(tensor: @Tensor<T>, alpha: @T) -> Tensor<T>;
/// # NNTrait::gemm
///
/// ```rust
/// fn gemm(
/// A: Tensor<T>,
/// B: Tensor<T>,
/// C: Option<Tensor<T>>,
/// alpha: Option<T>,
/// beta: Option<T>,
/// transA: bool,
/// transB: bool
/// ) -> Tensor<T>;
/// ```
///
/// Performs General Matrix multiplication: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
///
/// * A' = transpose(A) if transA else A
/// * B' = transpose(B) if transB else B
///
/// Compute `Y = alpha * A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
/// `A` will be transposed before doing the computation if attribute `transA` is `true`, same for `B` and `transB`.
///
/// ## Args
///
/// * `A`(`Tensor<T>`) - Input tensor A. The shape of `A` should be (M, K) if `transA` is `false`, or (K, M) if `transA` is `true`.
/// * `B`(`Tensor<T>`) - Input tensor B. The shape of `B` should be (K, N) if `transB` is `false`, or (N, K) if `transB` is `true`.
/// * `C`(`Option<Tensor<T>>`) - Optional input tensor C. The shape of C should be unidirectional broadcastable to (M, N).
/// * `alpha`(`Option<T>`) - Optional scalar multiplier for the product of input tensors `A * B`.
/// * `beta`(`Option<T>`) - Optional scalar multiplier for input tensor `C`.
/// * `transA`(`bool`) - Whether `A` should be transposed.
/// * `transB`(`bool`) - Whether `B` should be transposed.
///
/// ## Returns
///
/// A `Tensor<T>` of shape (M, N).
///
/// ## Examples
///
/// ```rust
/// mod input_0;
/// mod input_1;
/// mod input_2;
///
/// use orion::operators::nn::NNTrait;
/// use orion::numbers::FixedTrait;
/// use orion::operators::nn::FP16x16NN;
/// use orion::operators::tensor::FP16x16TensorPartialEq;
///
/// fn gemm_all_attributes_example() -> Tensor<FP16x16> {
/// let input_0 = input_0::input_0(); // shape [4;3]
/// let input_1 = input_1::input_1(); // shape [5;4]
/// let input_2 = input_2::input_2(); // shape [1;5]
///
/// let y = NNTrait::gemm(
/// input_0,
/// input_1,
/// Option::Some(input_2),
/// Option::Some(FixedTrait::new(16384, false)), // 0.25
/// Option::Some(FixedTrait::new(22938, false)), // 0.35
/// true,
/// true
/// );
///
/// return y;
/// }
/// >>> tensor of shape [3;5]
/// ````
///
fn gemm(
A: Tensor<T>,
B: Tensor<T>,
Expand Down
21 changes: 14 additions & 7 deletions tests/src/nodes/gemm_all_attributes.cairo
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
mod input_0;
mod input_1;
mod input_2;
mod output_0;

mod input_0;
mod input_1;
mod input_2;
mod output_0;

use orion::operators::nn::NNTrait;
use orion::numbers::FixedTrait;
Expand All @@ -18,7 +17,15 @@ fn test_gemm_all_attributes() {
let input_2 = input_2::input_2();
let z = output_0::output_0();

let y = NNTrait::gemm(input_0, input_1, Option::Some(input_2), Option::Some(FixedTrait::new(16384, false)), Option::Some(FixedTrait::new(22938, false)), true, true);
let y = NNTrait::gemm(
input_0,
input_1,
Option::Some(input_2),
Option::Some(FixedTrait::new(16384, false)),
Option::Some(FixedTrait::new(22938, false)),
true,
true
);

assert_eq(y, z);
}
}