Implement tests, todo: check lsbs to remove with Andrei
fd0r committed Aug 24, 2023
1 parent 373e81b commit 379a5f8
Showing 2 changed files with 40 additions and 26 deletions.
14 changes: 8 additions & 6 deletions src/concrete/ml/quantization/quantized_module_passes.py
@@ -1,5 +1,5 @@
"""Optimization passes for QuantizedModules."""
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, DefaultDict
from collections import defaultdict

import numpy
@@ -17,7 +17,7 @@

 # A dictionary that contains for a quantized op a list of predecessor ops
 # Each predecessor op is stored along with its output tensor name
-PredecessorsType = Dict[QuantizedOp, List[Tuple[Optional[QuantizedOp], str]]]
+PredecessorsType = DefaultDict[Optional[QuantizedOp], List[Tuple[Optional[QuantizedOp], str]]]

 # A list of optimizable patterns. For a "Mixing" op that supports rounding accumulators
 # we store a list of ops which contain information that allows us to
@@ -194,9 +194,8 @@ def detect_patterns(self, predecessors: PredecessorsType) -> PatternDict:
                 # The Gemm/Conv op that produces this integer node is the one
                 # onto which we apply the roundPBS optimization
                 nodes_in_path.append(back_node)
-                assert back_node is not None
-                list_pred_of_path = predecessors.get(back_node, None)
-                if list_pred_of_path is not None and len(list_pred_of_path) == 1:
+                list_pred_of_path = predecessors[back_node]
+                if len(list_pred_of_path) == 1:
                     integer_node_input_quant = list_pred_of_path[0][0]

         assert isinstance(node_op, QuantizedMixingOp)
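
For context, a minimal standalone sketch of what this change buys; the string keys below are invented stand-ins for the real QuantizedOp instances. With a DefaultDict, a missing node yields an empty list, so the explicit .get(..., None) call, the None check, and the assert all become unnecessary:

from collections import defaultdict
from typing import DefaultDict, List, Optional, Tuple

# String stand-ins for QuantizedOp instances, for illustration only
predecessors: DefaultDict[Optional[str], List[Tuple[Optional[str], str]]] = defaultdict(list)
predecessors["gemm_1"].append(("input_quant", "tensor_0"))

# Before: a plain Dict forces a None check before inspecting the list
preds = predecessors.get("gemm_1", None)
if preds is not None and len(preds) == 1:
    pred_op, tensor_name = preds[0]

# After: a missing key transparently yields [], so the length check alone is safe
preds = predecessors["not_in_graph"]
if len(preds) == 1:
    pred_op, tensor_name = preds[0]

One caveat of this idiom: indexing a defaultdict inserts the default value for missing keys, so the mapping can grow during lookups; that appears acceptable for a read-mostly pass like this one, but is worth keeping in mind.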
@@ -273,7 +272,10 @@ def integer_log2(value: float) -> Tuple[int, bool]:
             # operation to perform the scaling. Thus the
             # number of lsbs to round is the negative of the sum of log2
             # of the scale factors
-            lsbs_to_round = -(log2_input + log2_weights - log2_output)
+            lsbs_to_round = - (log2_input + log2_weights - log2_output)
+            # log2_output - log2_input - log2_weights
+            # TODO: check this part with Andrei
+            # How is it possible to have it like that?
             path_start_node.rounding_threshold_bits = lsbs_to_round
             path_start_node.lsbs_to_remove = lsbs_to_round
         else:
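A worked example of the sign convention questioned in the TODO above. The integer_log2 body here is only a plausible reconstruction matching the signature shown in the hunk header, and the scale factors are invented for illustration:

import math
from typing import Tuple

def integer_log2(value: float) -> Tuple[int, bool]:
    # Plausible reconstruction: rounded log2, plus a flag telling whether
    # the value is an exact power of two
    log2_value = int(round(math.log2(value)))
    return log2_value, math.isclose(2.0**log2_value, value)

# Invented power-of-two scale factors for input, weights and output
log2_input, _ = integer_log2(2.0**-3)    # -3
log2_weights, _ = integer_log2(2.0**-4)  # -4
log2_output, _ = integer_log2(2.0**-5)   # -5

# lsbs_to_round = -((-3) + (-4) - (-5)) = 2: the accumulator scale is
# 2**2 times finer than the output scale, so 2 low bits can be rounded away
lsbs_to_round = -(log2_input + log2_weights - log2_output)
assert lsbs_to_round == 2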
52 changes: 32 additions & 20 deletions tests/torch/test_brevitas_qat.py
@@ -1,5 +1,6 @@
"""Tests with brevitas quantization aware training."""

from typing import Optional
import brevitas.nn as qnn
import numpy
import pytest
@@ -21,7 +22,9 @@
 from concrete.ml.sklearn.qnn_module import SparseQuantNeuralNetwork
 from concrete.ml.torch.compile import compile_brevitas_qat_model
 from concrete.ml.quantization.quantizers import Int8ActPerTensorPoT, Int8WeightPerTensorPoT
-from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int8WeightPerTensorFloat, IntBias
+from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat
+from brevitas.quant.scaled_int import IntBias
+from concrete.ml.quantization.quantized_module import QuantizedModule
 from concrete.ml.quantization.base_quantized_op import QuantizedMixingOp
 from concrete.ml.quantization.post_training import PowerOfTwoScalingRoundPBSAdapter

@@ -409,57 +412,66 @@ def test_brevitas_constant_folding(default_configuration):
 )


-def test_brevitas_power_of_two(default_configuration):
+@pytest.mark.parametrize("rounding", [None, 4])
+@pytest.mark.parametrize("power_of_two", [True, False])
+@pytest.mark.parametrize("n_bits", [5])
+def test_brevitas_power_of_two(default_configuration, rounding: Optional[int], power_of_two: bool, n_bits: int):
"""Test that a network that does not quantize its inputs raises the right exception.
The network tested is not a valid QAT network for Concrete ML as it does not
quantize its inputs. However, in previous versions of Concrete ML a bug
in constant folding prevented the correct error being raised.
"""

     input_shape = 32
     output_shape = 2
     hidden_shape = 64
     batch_size = 128
     data = torch.randn((batch_size, input_shape))
+    model_n_bits = 6
+    rounding_n_bits = 5

     model = QuantCustomModel(
-        input_shape, output_shape, hidden_shape, n_bits=4,
-        act_quant=Int8ActPerTensorPoT,
-        weight_quant=Int8WeightPerTensorPoT,
-        bias_quant=IntBias,
+        input_shape, output_shape, hidden_shape, n_bits=model_n_bits,
+        act_quant=Int8ActPerTensorPoT if power_of_two else Int8ActPerTensorFloat,
+        weight_quant=Int8WeightPerTensorPoT if power_of_two else Int8WeightPerTensorFloat,
+        bias_quant=IntBias if power_of_two else None,
     )

+    # If rounding threshold is set -> nothing happens
+    # If Quantizer is not setup -> nothing happens
     quantized_module = compile_brevitas_qat_model(
         model.to("cpu"),
         torch_inputset=data,
         configuration=default_configuration,
-        rounding_threshold_bits=6,
+        rounding_threshold_bits=rounding_n_bits,
     )
+    # TODO: add checks

+    pot_should_be_applied = not rounding and power_of_two
     # Count the number of patterns that were optimized with roundPBS
     num_round_pbs_layers = 0
     for (_, node_op) in quantized_module.quant_layers_dict.values():
         if isinstance(node_op, QuantizedMixingOp):
             num_round_pbs_layers += 1 if node_op.rounding_threshold_bits is not None else 0
-            assert node_op.rounding_threshold_bits == node_op.lsbs_to_remove

+            if pot_should_be_applied:
+                assert node_op.rounding_threshold_bits == node_op.lsbs_to_remove
+            else:
+                assert node_op.rounding_threshold_bits != node_op.lsbs_to_remove
     # Apply the PowerOfTwoScalingRoundPBSAdapter again. The second time
     # the adapter will ignore already optimized patterns but report them
     # as ignored.
     adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
     round_pbs_patterns = adapter.process()

     # The power-of-two optimization will only work
     # when Relu activations are used and scaling factors are forced to be 2**s
-    assert (
-        len(round_pbs_patterns) == 0
-    ), "Expected number of round PBS optimized patterns was not matched"
-    # 3 layers
-    assert (
-        adapter.num_ignored_valid_patterns == 3 - 1
-    ), "Expected number of ignored round PBS optimizable patterns was not matched"
+    if pot_should_be_applied:
+        assert (
+            len(round_pbs_patterns) == 0
+        ), "Expected number of round PBS optimized patterns was not matched"
+        # 3 layers
+        assert (
+            adapter.num_ignored_valid_patterns == 3 - 1
+        ), "Expected number of ignored round PBS optimizable patterns was not matched"
+    else:
+        pass

+    # y_pred_clear_round = model.predict(x_test, fhe="disable")

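To spell out when the optimization should fire across the four parametrized cases, here is a small self-contained sketch of the pot_should_be_applied predicate used by the test; the case enumeration is illustrative, not part of the commit:

from itertools import product
from typing import Optional

def pot_should_be_applied(rounding: Optional[int], power_of_two: bool) -> bool:
    # Mirrors the predicate in the test: roundPBS substitution is only
    # expected with no manual rounding threshold AND power-of-two quantizers
    return not rounding and power_of_two

for rounding, power_of_two in product([None, 4], [True, False]):
    print(rounding, power_of_two, "->", pot_should_be_applied(rounding, power_of_two))
# Only the (None, True) case prints True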
