Commit d926f43

chore: adding more tests.
bcm-at-zama committed Feb 16, 2024
1 parent 9f4702f commit d926f43
Showing 2 changed files with 50 additions and 5 deletions.
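
This commit threads a new compress_evaluation_keys flag through the compilation tests. For context, here is a minimal sketch of toggling that flag through Concrete ML's compile_torch_model entry point; the model, input set, and n_bits value are illustrative, and the flag's availability on this function is assumed from the signatures exercised by the tests below:

    import torch
    from concrete.ml.torch.compile import compile_torch_model

    # Illustrative model and input set, not taken from the diff below
    model = torch.nn.Sequential(torch.nn.Linear(10, 3), torch.nn.ReLU())
    inputset = torch.randn(100, 10)

    # compress_evaluation_keys=True asks compilation to produce compressed FHE
    # evaluation keys; False keeps them uncompressed (both are exercised below)
    quantized_module = compile_torch_model(
        model,
        inputset,
        n_bits=2,
        compress_evaluation_keys=True,
    )
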
11 changes: 7 additions & 4 deletions tests/quantization/test_compilation.py
@@ -46,6 +46,7 @@
@pytest.mark.parametrize("n_bits", [2])
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("verbose", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_quantized_module_compilation(
input_output_feature,
model,
@@ -55,6 +56,7 @@ def test_quantized_module_compilation(
simulate,
check_graph_input_has_no_tlu,
verbose,
compress_evaluation_keys,
check_is_good_execution_for_cml_vs_circuit,
):
"""Test a neural network compilation for FHE inference."""
@@ -82,6 +84,7 @@ def test_quantized_module_compilation(
numpy_input,
default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=simulate)
@@ -104,6 +107,7 @@ def test_quantized_module_compilation(
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("verbose", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_quantized_cnn_compilation(
input_output_feature,
model,
@@ -112,6 +116,7 @@ def test_quantized_cnn_compilation(
check_graph_input_has_no_tlu,
simulate,
verbose,
compress_evaluation_keys,
check_is_good_execution_for_cml_vs_circuit,
):
"""Test a convolutional neural network compilation for FHE inference."""
@@ -142,6 +147,7 @@ def test_quantized_cnn_compilation(
numpy_input,
default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)
check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=simulate)
check_graph_input_has_no_tlu(quantized_model.fhe_circuit.graph)
@@ -317,10 +323,7 @@ def test_compile_multi_input_nn_with_input_tlus(
quantized_model = post_training_quant.quantize_module(numpy_input)

# Compile
quantized_model.compile(
numpy_input,
default_configuration,
)
quantized_model.compile(numpy_input, default_configuration)
check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=True)

if can_remove_tlu:
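
Note that stacked @pytest.mark.parametrize decorators combine multiplicatively, so the added compress_evaluation_keys axis doubles the number of generated cases for each test above. A self-contained sketch of the mechanism, with illustrative names:

    import pytest

    @pytest.mark.parametrize("simulate", [True, False])
    @pytest.mark.parametrize("compress_evaluation_keys", [True, False])
    def test_parametrize_product(simulate, compress_evaluation_keys):
        # pytest generates 2 x 2 = 4 test cases from the stacked decorators
        assert isinstance(simulate, bool)
        assert isinstance(compress_evaluation_keys, bool)
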
44 changes: 43 additions & 1 deletion tests/torch/test_compile_torch.py
@@ -70,7 +70,9 @@ def create_test_inputset(inputset, n_percent_inputset_examples_test):
return x_test


def get_and_compile_quantized_module(model, inputset, import_qat, n_bits, configuration, verbose):
def get_and_compile_quantized_module(
model, inputset, import_qat, n_bits, configuration, verbose, compress_evaluation_keys
):
"""Get and compile the quantized module built from the given model."""
quantized_numpy_module = build_quantized_module(
model,
@@ -87,6 +89,7 @@ def get_and_compile_quantized_module(model, inputset, import_qat, n_bits, config
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

return quantized_numpy_module
@@ -108,6 +111,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
get_and_compile=False,
input_shape=None,
is_brevitas_qat=False,
compress_evaluation_keys=True,
) -> QuantizedModule:
"""Test the different model architecture from torch numpy."""

@@ -164,6 +168,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
n_bits=n_bits,
configuration=default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

else:
@@ -174,6 +179,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
configuration=default_configuration,
n_bits=n_bits,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)
else:
if is_brevitas_qat:
@@ -185,6 +191,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
n_bits=n_bits,
configuration=default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

elif get_and_compile:
@@ -195,6 +202,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
n_bits=n_bits,
configuration=default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

else:
@@ -205,6 +213,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
configuration=default_configuration,
n_bits=n_bits,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

n_examples_test = 1
@@ -228,6 +237,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
n_bits=n_bits,
configuration=default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

else:
@@ -253,6 +263,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
n_bits=n_bits,
configuration=default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

else:
@@ -263,6 +274,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
configuration=default_configuration,
n_bits=n_bits,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

accuracy_test_rounding(
@@ -276,6 +288,7 @@ def compile_and_test_torch_or_onnx( # pylint: disable=too-many-locals, too-many
verbose=verbose,
check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
is_brevitas_qat=is_brevitas_qat,
compress_evaluation_keys=compress_evaluation_keys,
)

if dump_onnx:
@@ -298,6 +311,7 @@ def accuracy_test_rounding(
simulate,
verbose,
check_is_good_execution_for_cml_vs_circuit,
compress_evaluation_keys,
is_brevitas_qat=False,
):
"""Check rounding behavior.
@@ -329,6 +343,7 @@ def accuracy_test_rounding(
configuration=configuration,
rounding_threshold_bits=8,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

# and another quantized module with a rounding threshold equal to 2 bits
@@ -339,6 +354,7 @@ def accuracy_test_rounding(
configuration=configuration,
rounding_threshold_bits=2,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

else:
@@ -351,6 +367,7 @@ def accuracy_test_rounding(
n_bits=n_bits,
rounding_threshold_bits=8,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

# and another quantized module with a rounding threshold equal to 2 bits
@@ -362,6 +379,7 @@ def accuracy_test_rounding(
n_bits=n_bits,
verbose=verbose,
rounding_threshold_bits=2,
compress_evaluation_keys=compress_evaluation_keys,
)

n_percent_inputset_examples_test = 0.1
@@ -463,6 +481,9 @@ def accuracy_test_rounding(
@pytest.mark.parametrize("simulate", [True, False], ids=["FHE_simulation", "FHE"])
@pytest.mark.parametrize("is_onnx", [True, False], ids=["is_onnx", ""])
@pytest.mark.parametrize("get_and_compile", [True, False], ids=["get_and_compile", "compile"])
@pytest.mark.parametrize(
"compress_evaluation_keys", [True, False], ids=["compress_evaluation_keys", ""]
)
def test_compile_torch_or_onnx_networks(
input_output_feature,
model,
@@ -473,6 +494,7 @@ def test_compile_torch_or_onnx_networks(
get_and_compile,
check_is_good_execution_for_cml_vs_circuit,
is_weekly_option,
compress_evaluation_keys,
):
"""Test the different model architecture from torch numpy."""

@@ -495,6 +517,7 @@ def test_compile_torch_or_onnx_networks(
check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
verbose=False,
get_and_compile=get_and_compile,
compress_evaluation_keys=compress_evaluation_keys,
)


Expand All @@ -517,6 +540,7 @@ def test_compile_torch_or_onnx_networks(
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("is_onnx", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
model,
is_1d,
@@ -527,6 +551,7 @@ def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
check_graph_input_has_no_tlu,
check_graph_output_has_no_tlu,
check_is_good_execution_for_cml_vs_circuit,
compress_evaluation_keys,
):
"""Test the different model architecture from torch numpy."""

@@ -547,6 +572,7 @@ def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
verbose=False,
input_shape=input_shape,
compress_evaluation_keys=compress_evaluation_keys,
)

check_graph_input_has_no_tlu(q_module.fhe_circuit.graph)
@@ -598,6 +624,7 @@ def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("is_onnx", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_compile_torch_or_onnx_activations(
input_output_feature,
model,
@@ -606,6 +633,7 @@ def test_compile_torch_or_onnx_activations(
simulate,
is_onnx,
check_is_good_execution_for_cml_vs_circuit,
compress_evaluation_keys,
):
"""Test the different model architecture from torch numpy."""

@@ -622,6 +650,7 @@ def test_compile_torch_or_onnx_activations(
is_onnx,
check_is_good_execution_for_cml_vs_circuit,
verbose=False,
compress_evaluation_keys=compress_evaluation_keys,
)


Expand All @@ -640,13 +669,15 @@ def test_compile_torch_or_onnx_activations(
[pytest.param(n_bits) for n_bits in [1, 2]],
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_compile_torch_qat(
input_output_feature,
model,
n_bits,
default_configuration,
simulate,
check_is_good_execution_for_cml_vs_circuit,
compress_evaluation_keys,
):
"""Test the different model architecture from torch numpy."""

@@ -666,6 +697,7 @@ def test_compile_torch_qat(
is_onnx,
check_is_good_execution_for_cml_vs_circuit,
verbose=False,
compress_evaluation_keys=compress_evaluation_keys,
)


@@ -678,6 +710,7 @@ def test_compile_torch_qat(
[pytest.param(n_bits) for n_bits in [2]],
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_compile_brevitas_qat(
model_class,
input_output_feature,
@@ -686,6 +719,7 @@ def test_compile_brevitas_qat(
simulate,
default_configuration,
check_is_good_execution_for_cml_vs_circuit,
compress_evaluation_keys,
):
"""Test compile_brevitas_qat_model."""

@@ -710,6 +744,7 @@ def test_compile_brevitas_qat(
check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
verbose=False,
is_brevitas_qat=is_brevitas_qat,
compress_evaluation_keys=compress_evaluation_keys,
)


@@ -758,12 +793,14 @@ def test_compile_brevitas_qat(
pytest.param(nn.ReLU, id="relu"),
],
)
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_dump_torch_network(
model_class,
expected_onnx_str,
activation_function,
default_configuration,
check_is_good_execution_for_cml_vs_circuit,
compress_evaluation_keys,
):
"""This is a test which is equivalent to tests in test_dump_onnx.py, but for torch modules."""
input_output_feature = 7
@@ -783,10 +820,12 @@ def test_dump_torch_network(
dump_onnx=True,
expected_onnx_str=expected_onnx_str,
verbose=False,
compress_evaluation_keys=compress_evaluation_keys,
)


@pytest.mark.parametrize("verbose", [True, False], ids=["with_verbose", "without_verbose"])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
# pylint: disable-next=too-many-locals
def test_pretrained_mnist_qat(
default_configuration,
@@ -795,6 +834,7 @@
check_graph_output_has_no_tlu,
check_is_good_execution_for_cml_vs_circuit,
is_weekly_option,
compress_evaluation_keys,
):
"""Load a QAT MNIST model and confirm we get the same results in FHE simulation as with ONNX."""
if not is_weekly_option:
@@ -845,6 +885,7 @@ def test_pretrained_mnist_qat(
configuration=default_configuration,
n_bits=n_bits,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

quantized_numpy_module.check_model_is_compiled()
@@ -884,6 +925,7 @@ def test_pretrained_mnist_qat(
configuration=default_configuration,
n_bits=n_bits,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

# As this is a custom QAT network, the input goes through multiple univariate
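
With the widened signature above, callers of the test helper now pass the flag explicitly. A hedged usage sketch from inside the test module; torch_model, inputset, and default_configuration are placeholders for the fixtures and data built by the surrounding tests:

    # Hypothetical call site; torch_model, inputset and default_configuration
    # stand in for objects constructed elsewhere in the test module
    quantized_numpy_module = get_and_compile_quantized_module(
        model=torch_model,
        inputset=inputset,
        import_qat=False,
        n_bits=2,
        configuration=default_configuration,
        verbose=False,
        compress_evaluation_keys=True,
    )
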
