From b617740dbc4a045e24f94edca13a07a0cb8738c1 Mon Sep 17 00:00:00 2001 From: Roman Bredehoft Date: Tue, 6 Feb 2024 11:49:02 +0100 Subject: [PATCH] docs: update operator list in torch support's documentation section --- docs/built-in-models/neural-networks.md | 2 +- docs/deep-learning/torch_support.md | 63 ++++++++++++++++++------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/docs/built-in-models/neural-networks.md b/docs/built-in-models/neural-networks.md index d400d94a9..f88d3cc7e 100644 --- a/docs/built-in-models/neural-networks.md +++ b/docs/built-in-models/neural-networks.md @@ -48,7 +48,7 @@ The figure above right shows the Concrete ML neural network, trained with Quanti ### Architecture parameters - `module__n_layers`: number of layers in the FCNN, must be at least 1. Note that this is the total number of layers. For a single, hidden layer NN model, set `module__n_layers=2` -- `module__activation_function`: can be one of the Torch activations (e.g., nn.ReLU, see the full list [here](../deep-learning/torch_support.md#activations)). Neural networks with `nn.ReLU` activation benefit from specific optimizations that make them around 10x faster than networks with other activation functions. +- `module__activation_function`: can be one of the Torch activations (e.g., nn.ReLU, see the full list [here](../deep-learning/torch_support.md#activation-functions)). Neural networks with `nn.ReLU` activation benefit from specific optimizations that make them around 10x faster than networks with other activation functions. ### Quantization parameters diff --git a/docs/deep-learning/torch_support.md b/docs/deep-learning/torch_support.md index 578a8e9fc..ebcae62dd 100644 --- a/docs/deep-learning/torch_support.md +++ b/docs/deep-learning/torch_support.md @@ -155,62 +155,93 @@ Concrete ML supports a variety of PyTorch operators that can be used to build fu ### Operators -#### univariate operators +#### Univariate operators -- [`torch.abs`](https://pytorch.org/docs/stable/generated/torch.abs.html) +- [`torch.nn.identity`](https://pytorch.org/docs/stable/generated/torch.nn.Identity.html) - [`torch.clip`](https://pytorch.org/docs/stable/generated/torch.clip.html) +- [`torch.clamp`](https://pytorch.org/docs/stable/generated/torch.clamp.html) +- [`torch.round`](https://pytorch.org/docs/stable/generated/torch.round.html) +- [`torch.floor`](https://pytorch.org/docs/stable/generated/torch.floor.html) +- [`torch.min`](https://pytorch.org/docs/stable/generated/torch.min.html) +- [`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html) +- [`torch.abs`](https://pytorch.org/docs/stable/generated/torch.abs.html) +- [`torch.neg`](https://pytorch.org/docs/stable/generated/torch.neg.html) +- [`torch.sign`](https://pytorch.org/docs/stable/generated/torch.sign.html) +- [`torch.logical_or, torch.Tensor operator ||`](https://pytorch.org/docs/stable/generated/torch.logical_or.html) +- [`torch.logical_not`](https://pytorch.org/docs/stable/generated/torch.logical_not.html) +- [`torch.gt, torch.greater`](https://pytorch.org/docs/stable/generated/torch.gt.html) +- [`torch.ge, torch.greater_equal`](https://pytorch.org/docs/stable/generated/torch.ge.html) +- [`torch.lt, torch.less`](https://pytorch.org/docs/stable/generated/torch.lt.html) +- [`torch.le, torch.less_equal`](https://pytorch.org/docs/stable/generated/torch.le.html) +- [`torch.eq`](https://pytorch.org/docs/stable/generated/torch.eq.html) +- [`torch.where`](https://pytorch.org/docs/stable/generated/torch.where.html) - [`torch.exp`](https://pytorch.org/docs/stable/generated/torch.exp.html) - [`torch.log`](https://pytorch.org/docs/stable/generated/torch.log.html) -- [`torch.gt`](https://pytorch.org/docs/stable/generated/torch.gt.html) -- [`torch.clamp`](https://pytorch.org/docs/stable/generated/torch.clamp.html) +- [`torch.pow`](https://pytorch.org/docs/stable/generated/torch.pow.html) +- [`torch.sum`](https://pytorch.org/docs/stable/generated/torch.sum.html) - [`torch.mul, torch.Tensor operator *`](https://pytorch.org/docs/stable/generated/torch.mul.html) - [`torch.div, torch.Tensor operator /`](https://pytorch.org/docs/stable/generated/torch.div.html) -- [`torch.nn.identity`](https://pytorch.org/docs/stable/generated/torch.nn.Identity.html) - [`torch.nn.BatchNorm2d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html) +- [`torch.nn.BatchNorm3d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html) +- [`torch.erf, torch.special.erf`](https://pytorch.org/docs/stable/special.html#torch.special.erf) +- [`torch.nn.functional.pad`](https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html) -#### shape modifying operators +#### Shape modifying operators - [`torch.reshape`](https://pytorch.org/docs/stable/generated/torch.reshape.html) - [`torch.Tensor.view`](https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view) - [`torch.flatten`](https://pytorch.org/docs/stable/generated/torch.flatten.html) +- [`torch.unsqueeze`](https://pytorch.org/docs/stable/generated/torch.unsqueeze.html) +- [`torch.squeeze`](https://pytorch.org/docs/stable/generated/torch.squeeze.html) - [`torch.transpose`](https://pytorch.org/docs/stable/generated/torch.transpose.html) +- [`torch.concat, torch.cat`](https://pytorch.org/docs/stable/generated/torch.cat.html) +- [`torch.nn.Unfold`](https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html) -#### operators that take an encrypted input and unencrypted constants +#### Tensor operators + +- [`torch.Tensor.expand`](https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html) +- [`torch.Tensor.to`](https://pytorch.org/docs/stable/generated/torch.Tensor.to.html) -- for casting to dtype + +#### Multi-variate operators: encrypted input and unencrypted constants -- [`torch.conv2d`, `torch.nn.Conv2D`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) -- [`torch.matmul`](https://pytorch.org/docs/stable/generated/torch.matmul.html) - [`torch.nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) +- [`torch.conv1d`, `torch.nn.Conv1D`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) +- [`torch.conv2d`, `torch.nn.Conv2D`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) +- [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) +- [`torch.nn.MaxPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html) -Concrete ML supports these operators but also the QAT equivalents from Brevitas. +Concrete ML also supports some of their QAT equivalents from Brevitas. - `brevitas.nn.QuantLinear` +- `brevitas.nn.QuantConv1d` - `brevitas.nn.QuantConv2d` -#### operators that can take both encrypted+unencrypted and encrypted+encrypted inputs +#### Multi-variate operators: encrypted+unencrypted or encrypted+encrypted inputs - [`torch.add, torch.Tensor operator +`](https://pytorch.org/docs/stable/generated/torch.Tensor.add.html) - [`torch.sub, torch.Tensor operator -`](https://pytorch.org/docs/stable/generated/torch.Tensor.sub.html) +- [`torch.matmul`](https://pytorch.org/docs/stable/generated/torch.matmul.html) ### Quantizers - `brevitas.nn.QuantIdentity` -### Activations +### Activation functions -- [`torch.nn.Celu`](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html) -- [`torch.nn.Elu`](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html) +- [`torch.nn.CELU`](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html) +- [`torch.nn.ELU`](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html) - [`torch.nn.GELU`](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) - [`torch.nn.Hardshrink`](https://pytorch.org/docs/stable/generated/torch.nn.Hardshrink.html) - [`torch.nn.HardSigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html) - [`torch.nn.Hardswish`](https://pytorch.org/docs/stable/generated/torch.nn.Hardswish) - [`torch.nn.HardTanh`](https://pytorch.org/docs/stable/generated/torch.nn.Hardtanh.html) -- [`torch.nn.LeakyRelu`](https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html) +- [`torch.nn.LeakyReLU`](https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html) - [`torch.nn.LogSigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.LogSigmoid.html) - [`torch.nn.Mish`](https://pytorch.org/docs/stable/generated/torch.nn.Mish.html) - [`torch.nn.PReLU`](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html) - [`torch.nn.ReLU6`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU6.html) - [`torch.nn.ReLU`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) -- [`torch.nn.Selu`](https://pytorch.org/docs/stable/generated/torch.nn.SELU.html) +- [`torch.nn.SELU`](https://pytorch.org/docs/stable/generated/torch.nn.SELU.html) - [`torch.nn.Sigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html) - [`torch.nn.SiLU`](https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html) - [`torch.nn.Softplus`](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html)