chore: fix correctness
andrei-stoian-zama committed Sep 7, 2023
1 parent 29caabb commit 5891aed
Showing 6 changed files with 33 additions and 20 deletions.
8 changes: 8 additions & 0 deletions src/concrete/ml/common/utils.py
@@ -42,6 +42,14 @@
# Indicate if the old simulation method should be used when simulating FHE executions
USE_OLD_VL = True

+# Debug option for testing round PBS optimization
+# Setting this option to true will make quantizers "round half up"
+# For example: 0.5 -> 1, 1.5 -> 2 instead of "round half to even"
+# When the option is set to false, Concrete ML uses numpy.rint
+# which has the same behavior as torch.round -> Brevitas nets
+# should be exact compared to their Concrete ML QuantizedModule
+QUANT_ROUND_LIKE_ROUND_PBS = False


class FheMode(str, enum.Enum):
"""Enum representing the execution mode.
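For reference, the difference that QUANT_ROUND_LIKE_ROUND_PBS toggles can be reproduced directly in NumPy. This is a minimal illustration, not part of the commit, using the same two expressions the diff introduces; only tie values (exact .5 fractions) are affected:

import numpy

# Values that sit exactly on a rounding tie
ties = numpy.array([-1.5, -0.5, 0.5, 1.5, 2.5])

# "Round half to even": numpy.rint, same behavior as torch.round
print(numpy.rint(ties))          # [-2. -0.  0.  2.  2.]

# "Round half up", emulating the rounded PBS: floor(x + 0.5)
print(numpy.floor(ties + 0.5))   # [-1.  0.  1.  2.  3.]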
6 changes: 5 additions & 1 deletion src/concrete/ml/onnx/ops_impl.py
@@ -14,6 +14,7 @@
from scipy import special
from typing_extensions import SupportsIndex

+from ..common import utils
from ..common.debugging import assert_false, assert_true
from .onnx_impl_utils import (
    compute_onnx_pool_padding,
@@ -1653,7 +1654,10 @@ def numpy_brevitas_quant(
y = numpy.clip(y, min_int_val, max_int_val)

# Quantize to produce integers representing the float quantized values
-y = numpy.rint(y)
+if utils.QUANT_ROUND_LIKE_ROUND_PBS:
+    y = numpy.floor(y + 0.5)
+else:
+    y = numpy.rint(y)

# Compute quantized floating point values
y = (y - zero_point) * scale
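To see where the flag sits inside numpy_brevitas_quant, here is a simplified, standalone sketch of the surrounding quantize/clip/round/dequantize sequence. The scale, zero point, clip bounds and input values are made-up examples; in the real operator they come from the Brevitas ONNX node, and the flag is read from utils.QUANT_ROUND_LIKE_ROUND_PBS:

import numpy

# Hypothetical quantization parameters (4-bit signed range)
scale, zero_point = 0.05, 0.0
min_int_val, max_int_val = -8, 7
QUANT_ROUND_LIKE_ROUND_PBS = False  # stand-in for utils.QUANT_ROUND_LIKE_ROUND_PBS

x = numpy.array([-0.41, 0.025, 0.075, 0.33])

# Scale to the integer grid and clip to the representable range
y = x / scale + zero_point
y = numpy.clip(y, min_int_val, max_int_val)

# Quantize to produce integers representing the float quantized values
if QUANT_ROUND_LIKE_ROUND_PBS:
    y = numpy.floor(y + 0.5)  # round half up, like the rounded PBS
else:
    y = numpy.rint(y)         # round half to even, like torch.round

# Compute quantized floating point values
# The 0.5 tie (input 0.025) dequantizes to 0.0 with rint and to 0.05 with floor(y + 0.5)
y = (y - zero_point) * scale
print(y)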
8 changes: 3 additions & 5 deletions src/concrete/ml/quantization/quantized_module_passes.py
@@ -295,11 +295,9 @@ def integer_log2(value: float) -> Tuple[int, bool]:
# number of lsbs to round is the negative of the sum of log2
# of the scale factors
lsbs_to_round = -(log2_input + log2_weights - log2_output)
-# log2_output - log2_input - log2_weights
-# TODO: check this part with Andrei
-# How is it possible to have like that?
-path_start_node.rounding_threshold_bits = lsbs_to_round
-path_start_node.lsbs_to_remove = lsbs_to_round
+if lsbs_to_round > 0:
+    path_start_node.rounding_threshold_bits = lsbs_to_round
+    path_start_node.lsbs_to_remove = lsbs_to_round
+else:
+    invalid_paths.append(path_start_node)

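The guard added above can be illustrated with a toy computation of lsbs_to_round from power-of-two scale factors. The numbers are examples only; the real pass reads the scales from the quantized graph:

import math

def toy_lsbs_to_round(input_scale: float, weight_scale: float, output_scale: float) -> int:
    """Number of LSBs that a rounded PBS could drop, given power-of-two scales."""
    log2_input = int(math.log2(input_scale))
    log2_weights = int(math.log2(weight_scale))
    log2_output = int(math.log2(output_scale))
    # Same formula as in the pass: negative of the sum of log2 scale factors
    return -(log2_input + log2_weights - log2_output)

# Accumulator scale much finer than the output scale: rounding is worthwhile
print(toy_lsbs_to_round(2**-4, 2**-4, 2**-3))   # 5 -> set rounding_threshold_bits
# Output scale finer than the accumulator scale: nothing to round away
print(toy_lsbs_to_round(2**-2, 2**-1, 2**-4))   # -1 -> path goes to invalid_paths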
6 changes: 5 additions & 1 deletion src/concrete/ml/quantization/quantizers.py
@@ -7,6 +7,7 @@

import numpy

+from ..common import utils
from ..common.debugging import assert_true
from ..common.serialization.dumpers import dump, dumps

@@ -745,7 +746,10 @@ def quant(self, values: numpy.ndarray) -> numpy.ndarray:
assert self.offset is not None
assert self.scale is not None

-qvalues = numpy.rint(values / self.scale + self.zero_point)
+if utils.QUANT_ROUND_LIKE_ROUND_PBS:
+    qvalues = numpy.floor(values / self.scale + self.zero_point + 0.5)
+else:
+    qvalues = numpy.rint(values / self.scale + self.zero_point)

# Clipping can be performed for PTQ and for precomputed (for now only Brevitas) QAT
# (where quantizer parameters are available in ONNX layers).
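The same toggle inside quant() only matters for inputs that land exactly halfway between two integer codes. A small sketch with hypothetical scale and zero-point values:

import numpy

scale, zero_point = 0.25, 3   # hypothetical quantizer parameters
values = numpy.array([-0.5, 0.125, 0.375])

# values / scale + zero_point == [1.0, 3.5, 4.5]: the last two are exact ties
print(numpy.rint(values / scale + zero_point))          # [1. 4. 4.]
print(numpy.floor(values / scale + zero_point + 0.5))   # [1. 4. 5.]

On tie values the two modes differ by one code, which is exactly the discrepancy the debug flag exposes when comparing a Brevitas net against its Concrete ML QuantizedModule.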
10 changes: 5 additions & 5 deletions src/concrete/ml/sklearn/qnn_module.py
@@ -9,7 +9,6 @@
from torch import nn

from ..common.debugging import assert_true
-from ..common.utils import MAX_BITWIDTH_BACKWARD_COMPATIBLE
from ..quantization.qat_quantizers import Int8ActPerTensorPoT, Int8WeightPerTensorPoT


@@ -29,14 +28,15 @@ def __init__(
n_layers: int,
n_outputs: int,
n_hidden_neurons_multiplier: int = 4,
-n_w_bits: int = 3,
-n_a_bits: int = 3,
-n_accum_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
+n_w_bits: int = 4,
+n_a_bits: int = 4,
+# No pruning by default as roundPBS keeps the PBS precision low
+n_accum_bits: int = 32,
n_prune_neurons_percentage: float = 0.0,
activation_function: Type = nn.ReLU,
quant_narrow: bool = False,
quant_signed: bool = True,
-power_of_two_scaling: bool = False,
+power_of_two_scaling: bool = True,  # Default to true: use roundPBS to speed up the NNs
):
"""Sparse Quantized Neural Network constructor.
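The new defaults change what an out-of-the-box QNN looks like. A hedged usage sketch follows; the class name SparseQuantNeuralNetwork and the input_dim argument are assumed from qnn_module.py, and only keywords visible in this diff are spelled out:

from torch import nn

from concrete.ml.sklearn.qnn_module import SparseQuantNeuralNetwork  # class name assumed

# Pre-commit behavior, made explicit: 3-bit weights/activations and no
# power-of-two scaling, so no roundPBS optimization is applied
legacy_like = SparseQuantNeuralNetwork(
    input_dim=20,            # hypothetical number of input features
    n_layers=2,
    n_outputs=2,
    n_w_bits=3,
    n_a_bits=3,
    power_of_two_scaling=False,
)

# Post-commit defaults: 4-bit weights/activations, n_accum_bits=32 (no pruning
# pressure) and power_of_two_scaling=True, enabling the roundPBS speed-up
new_default = SparseQuantNeuralNetwork(input_dim=20, n_layers=2, n_outputs=2)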
15 changes: 7 additions & 8 deletions tests/torch/test_brevitas_qat.py
@@ -15,6 +15,7 @@
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

+from concrete.ml.common import utils
from concrete.ml.common.utils import (
is_classifier_or_partial_classifier,
is_regressor_or_partial_regressor,
@@ -514,6 +515,8 @@ def test_brevitas_power_of_two(

net, x_all, _ = train_brevitas_network_tinymnist(is_cnn, n_bits, True, False, power_of_two)

+utils.QUANT_ROUND_LIKE_ROUND_PBS = True

# If rounding threshold is set -> nothing happens
# If Quantizer is not setup -> nothing happens
quantized_module = compile_brevitas_qat_model(
@@ -590,11 +593,7 @@
)

# # Compare the result with the optimized network and without
-# # they should be equal (allow 3 non-matching value out of 100)
-# TODO: actually verify correctness here, this is just a placeholder
-# https://github.com/zama-ai/concrete-ml-internal/issues/3946
-assert y_pred_sim_round.shape == y_pred_clear_round.shape
-assert y_pred_clear_round.shape == y_pred_clear_no_round.shape

-# assert numpy.sum(y_pred_sim_round != y_pred_clear_round) <= 3
-# assert numpy.sum(y_pred_clear_round != y_pred_clear_no_round) <= 3
+# # they should be equal

+assert numpy.sum(y_pred_sim_round != y_pred_clear_round) == 0
+assert numpy.sum(y_pred_clear_round != y_pred_clear_no_round) == 0
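One note on the test change: the flag is flipped by assigning utils.QUANT_ROUND_LIKE_ROUND_PBS directly at module level, which persists for the rest of the test session. If isolation ever becomes a concern, pytest's monkeypatch fixture gives the same effect and restores the value automatically; this is an alternative sketch, not what the commit does:

from concrete.ml.common import utils


def test_power_of_two_with_round_pbs_rounding(monkeypatch):
    # Enable round-half-up quantization for this test only; monkeypatch
    # restores utils.QUANT_ROUND_LIKE_ROUND_PBS when the test finishes
    monkeypatch.setattr(utils, "QUANT_ROUND_LIKE_ROUND_PBS", True)
    assert utils.QUANT_ROUND_LIKE_ROUND_PBS is True
    # ... train, compile and compare predictions as in test_brevitas_power_of_two ...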
