Skip to content

Commit

Permalink
docs: update operator list in torch support's documentation section
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Feb 7, 2024
1 parent 15a8340 commit b617740
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/built-in-models/neural-networks.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ The figure above right shows the Concrete ML neural network, trained with Quanti
### Architecture parameters

- `module__n_layers`: number of layers in the FCNN, must be at least 1. Note that this is the total number of layers. For a single, hidden layer NN model, set `module__n_layers=2`
- `module__activation_function`: can be one of the Torch activations (e.g., nn.ReLU, see the full list [here](../deep-learning/torch_support.md#activations)). Neural networks with `nn.ReLU` activation benefit from specific optimizations that make them around 10x faster than networks with other activation functions.
- `module__activation_function`: can be one of the Torch activations (e.g., nn.ReLU, see the full list [here](../deep-learning/torch_support.md#activation-functions)). Neural networks with `nn.ReLU` activation benefit from specific optimizations that make them around 10x faster than networks with other activation functions.

### Quantization parameters

Expand Down
63 changes: 47 additions & 16 deletions docs/deep-learning/torch_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,62 +155,93 @@ Concrete ML supports a variety of PyTorch operators that can be used to build fu

### Operators

#### univariate operators
#### Univariate operators

- [`torch.abs`](https://pytorch.org/docs/stable/generated/torch.abs.html)
- [`torch.nn.identity`](https://pytorch.org/docs/stable/generated/torch.nn.Identity.html)
- [`torch.clip`](https://pytorch.org/docs/stable/generated/torch.clip.html)
- [`torch.clamp`](https://pytorch.org/docs/stable/generated/torch.clamp.html)
- [`torch.round`](https://pytorch.org/docs/stable/generated/torch.round.html)
- [`torch.floor`](https://pytorch.org/docs/stable/generated/torch.floor.html)
- [`torch.min`](https://pytorch.org/docs/stable/generated/torch.min.html)
- [`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html)
- [`torch.abs`](https://pytorch.org/docs/stable/generated/torch.abs.html)
- [`torch.neg`](https://pytorch.org/docs/stable/generated/torch.neg.html)
- [`torch.sign`](https://pytorch.org/docs/stable/generated/torch.sign.html)
- [`torch.logical_or, torch.Tensor operator ||`](https://pytorch.org/docs/stable/generated/torch.logical_or.html)
- [`torch.logical_not`](https://pytorch.org/docs/stable/generated/torch.logical_not.html)
- [`torch.gt, torch.greater`](https://pytorch.org/docs/stable/generated/torch.gt.html)
- [`torch.ge, torch.greater_equal`](https://pytorch.org/docs/stable/generated/torch.ge.html)
- [`torch.lt, torch.less`](https://pytorch.org/docs/stable/generated/torch.lt.html)
- [`torch.le, torch.less_equal`](https://pytorch.org/docs/stable/generated/torch.le.html)
- [`torch.eq`](https://pytorch.org/docs/stable/generated/torch.eq.html)
- [`torch.where`](https://pytorch.org/docs/stable/generated/torch.where.html)
- [`torch.exp`](https://pytorch.org/docs/stable/generated/torch.exp.html)
- [`torch.log`](https://pytorch.org/docs/stable/generated/torch.log.html)
- [`torch.gt`](https://pytorch.org/docs/stable/generated/torch.gt.html)
- [`torch.clamp`](https://pytorch.org/docs/stable/generated/torch.clamp.html)
- [`torch.pow`](https://pytorch.org/docs/stable/generated/torch.pow.html)
- [`torch.sum`](https://pytorch.org/docs/stable/generated/torch.sum.html)
- [`torch.mul, torch.Tensor operator *`](https://pytorch.org/docs/stable/generated/torch.mul.html)
- [`torch.div, torch.Tensor operator /`](https://pytorch.org/docs/stable/generated/torch.div.html)
- [`torch.nn.identity`](https://pytorch.org/docs/stable/generated/torch.nn.Identity.html)
- [`torch.nn.BatchNorm2d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)
- [`torch.nn.BatchNorm3d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html)
- [`torch.erf, torch.special.erf`](https://pytorch.org/docs/stable/special.html#torch.special.erf)
- [`torch.nn.functional.pad`](https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html)

#### shape modifying operators
#### Shape modifying operators

- [`torch.reshape`](https://pytorch.org/docs/stable/generated/torch.reshape.html)
- [`torch.Tensor.view`](https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view)
- [`torch.flatten`](https://pytorch.org/docs/stable/generated/torch.flatten.html)
- [`torch.unsqueeze`](https://pytorch.org/docs/stable/generated/torch.unsqueeze.html)
- [`torch.squeeze`](https://pytorch.org/docs/stable/generated/torch.squeeze.html)
- [`torch.transpose`](https://pytorch.org/docs/stable/generated/torch.transpose.html)
- [`torch.concat, torch.cat`](https://pytorch.org/docs/stable/generated/torch.cat.html)
- [`torch.nn.Unfold`](https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html)

#### operators that take an encrypted input and unencrypted constants
#### Tensor operators

- [`torch.Tensor.expand`](https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html)
- [`torch.Tensor.to`](https://pytorch.org/docs/stable/generated/torch.Tensor.to.html) -- for casting to dtype

#### Multi-variate operators: encrypted input and unencrypted constants

- [`torch.conv2d`, `torch.nn.Conv2D`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
- [`torch.matmul`](https://pytorch.org/docs/stable/generated/torch.matmul.html)
- [`torch.nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
- [`torch.conv1d`, `torch.nn.Conv1D`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html)
- [`torch.conv2d`, `torch.nn.Conv2D`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
- [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html)
- [`torch.nn.MaxPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)

Concrete ML supports these operators but also the QAT equivalents from Brevitas.
Concrete ML also supports some of their QAT equivalents from Brevitas.

- `brevitas.nn.QuantLinear`
- `brevitas.nn.QuantConv1d`
- `brevitas.nn.QuantConv2d`

#### operators that can take both encrypted+unencrypted and encrypted+encrypted inputs
#### Multi-variate operators: encrypted+unencrypted or encrypted+encrypted inputs

- [`torch.add, torch.Tensor operator +`](https://pytorch.org/docs/stable/generated/torch.Tensor.add.html)
- [`torch.sub, torch.Tensor operator -`](https://pytorch.org/docs/stable/generated/torch.Tensor.sub.html)
- [`torch.matmul`](https://pytorch.org/docs/stable/generated/torch.matmul.html)

### Quantizers

- `brevitas.nn.QuantIdentity`

### Activations
### Activation functions

- [`torch.nn.Celu`](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html)
- [`torch.nn.Elu`](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html)
- [`torch.nn.CELU`](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html)
- [`torch.nn.ELU`](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html)
- [`torch.nn.GELU`](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html)
- [`torch.nn.Hardshrink`](https://pytorch.org/docs/stable/generated/torch.nn.Hardshrink.html)
- [`torch.nn.HardSigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html)
- [`torch.nn.Hardswish`](https://pytorch.org/docs/stable/generated/torch.nn.Hardswish)
- [`torch.nn.HardTanh`](https://pytorch.org/docs/stable/generated/torch.nn.Hardtanh.html)
- [`torch.nn.LeakyRelu`](https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html)
- [`torch.nn.LeakyReLU`](https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html)
- [`torch.nn.LogSigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.LogSigmoid.html)
- [`torch.nn.Mish`](https://pytorch.org/docs/stable/generated/torch.nn.Mish.html)
- [`torch.nn.PReLU`](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html)
- [`torch.nn.ReLU6`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU6.html)
- [`torch.nn.ReLU`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html)
- [`torch.nn.Selu`](https://pytorch.org/docs/stable/generated/torch.nn.SELU.html)
- [`torch.nn.SELU`](https://pytorch.org/docs/stable/generated/torch.nn.SELU.html)
- [`torch.nn.Sigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html)
- [`torch.nn.SiLU`](https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html)
- [`torch.nn.Softplus`](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html)
Expand Down

0 comments on commit b617740

Please sign in to comment.