Sparse Quantized Neural Network torch module.
Global variables:
- MAX_BITWIDTH_BACKWARD_COMPATIBLE
Sparse Quantized Neural Network.
This class implements an MLP that is compatible with FHE constraints. The weights and activations are quantized to low bit-width, and pruning is used to ensure accumulators do not surpass a user-provided accumulator bit-width. The number of classes, the number of layers, and the breadth of the network are specified by the user.
__init__(
input_dim,
n_layers,
n_outputs,
n_hidden_neurons_multiplier=4,
n_w_bits=3,
n_a_bits=3,
n_accum_bits=8,
n_prune_neurons_percentage=0.0,
activation_function=torch.nn.ReLU,
quant_narrow=False,
quant_signed=True
)
Sparse Quantized Neural Network constructor.
Args:
- `input_dim`: Number of dimensions of the input data
- `n_layers`: Number of linear layers for this network
- `n_outputs`: Number of output classes or regression targets
- `n_w_bits`: Number of weight bits
- `n_a_bits`: Number of activation and input bits
- `n_accum_bits`: Maximal allowed bit-width of intermediate accumulators
- `n_hidden_neurons_multiplier`: The number of neurons in the hidden layers will be the number of dimensions of the input multiplied by `n_hidden_neurons_multiplier`. Note that pruning is used to adjust the accumulator size in an attempt to keep the maximum accumulator bit-width to `n_accum_bits`, meaning that not all hidden-layer neurons will be active. The default value for `n_hidden_neurons_multiplier` is chosen for inputs with a small number of dimensions. Reducing this value decreases the FHE inference time considerably but also decreases the robustness and accuracy of model training.
- `n_prune_neurons_percentage`: Percentage of hidden-layer neurons to prune. This should mostly be used through the dedicated `.prune()` mechanism. It can be used when `n_hidden_neurons_multiplier` is set high (3-4), once good accuracy is obtained, to speed up the model in FHE.
- `activation_function`: A torch class used to construct the activation functions in the network (e.g. `torch.nn.ReLU`, `torch.nn.SELU`, `torch.nn.Sigmoid`)
- `quant_narrow`: Whether this network should use narrow-range quantized integer values
- `quant_signed`: Whether to use signed quantized integer values
Raises:
- `ValueError`: if the parameters have invalid values or the computed accumulator bit-width is zero
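To make the parameter list concrete, here is a minimal construction sketch. The class name `SparseQuantNeuralNetwork` and the import path are assumptions (they are not stated in this section); only the keyword arguments are taken from the documentation above.

```python
import torch.nn

# Assumed class name and import path; the keyword arguments below are the
# documented constructor parameters.
from concrete.ml.sklearn.qnn_module import SparseQuantNeuralNetwork

model = SparseQuantNeuralNetwork(
    input_dim=20,                   # 20 input features
    n_layers=3,                     # three linear layers
    n_outputs=2,                    # two output classes
    n_hidden_neurons_multiplier=4,  # hidden width = 4 * input_dim (before pruning)
    n_w_bits=3,                     # 3-bit weights
    n_a_bits=3,                     # 3-bit activations and inputs
    n_accum_bits=8,                 # cap intermediate accumulators at 8 bits
    activation_function=torch.nn.ReLU,
)
```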
enable_pruning() → None
Enable pruning in the network. Pruning must later be made permanent (see make_pruning_permanent) to recover the pruned weights as regular network weights.
Raises:
- `ValueError`: if the quantization parameters are invalid
forward(x: Tensor) → Tensor
Forward pass.
Args:
- `x` (torch.Tensor): network input
Returns:
- `x` (torch.Tensor): network prediction
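As a usage sketch, the forward pass is the standard torch module call. The batch size and feature count below are arbitrary and assume the `model` built in the constructor sketch above (`input_dim=20`, `n_outputs=2`).

```python
import torch

# `model` is the instance from the constructor sketch above (input_dim=20).
x = torch.randn(8, 20)   # a batch of 8 samples with 20 features
y = model(x)             # equivalent to model.forward(x)
print(y.shape)           # torch.Size([8, 2]) since n_outputs=2
```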
make_pruning_permanent() → None
Make the learned pruning permanent in the network.
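For context, here is a hedged sketch of how pruning is typically sequenced around training: pruning is enabled before training and made permanent afterwards. The optimizer, loss, and `data_loader` below are illustrative assumptions, not part of this API.

```python
import torch

model.enable_pruning()                       # attach pruning masks before training

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

for x_batch, y_batch in data_loader:         # `data_loader` is assumed to exist
    optimizer.zero_grad()
    loss = loss_fn(model(x_batch), y_batch)  # forward pass on the pruned, quantized MLP
    loss.backward()
    optimizer.step()

model.make_pruning_permanent()               # fold the pruning masks into the weights
```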
max_active_neurons() → int
Compute the maximum number of active (non-zero weight) neurons.
The computation is done using the quantization parameters passed to the constructor. Warning: with the current (asymmetric) quantization algorithm, the value returned by this function is not guaranteed to ensure FHE compatibility: for some weight distributions, weights that are 0 (i.e., pruned weights) are not quantized to 0, so the total number of active quantized neurons will not be equal to max_active_neurons.
Returns:
- `n` (int): maximum number of active neurons
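To illustrate the kind of bound involved, here is a hedged sketch of how an accumulator constraint translates into a maximum number of active connections per neuron: with unsigned `n_w_bits`-bit weights and `n_a_bits`-bit activations, each product contributes at most `(2**n_w_bits - 1) * (2**n_a_bits - 1)` to the accumulator. The exact formula used by the library may differ; this is only a worked estimate.

```python
import math

def max_active_neurons_estimate(n_w_bits=3, n_a_bits=3, n_accum_bits=8):
    # Largest number of (weight * activation) terms whose worst-case sum
    # still fits in an n_accum_bits unsigned accumulator.
    max_product = (2 ** n_w_bits - 1) * (2 ** n_a_bits - 1)
    return int(math.floor((2 ** n_accum_bits - 1) / max_product))

print(max_active_neurons_estimate())  # 5 for the default 3/3/8 bit-widths
```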