Skip to content

Commit

Permalink
chore: add more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
bcm-at-zama committed Feb 16, 2024
1 parent 9f4702f commit d3eb96e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 8 deletions.
11 changes: 7 additions & 4 deletions tests/quantization/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
@pytest.mark.parametrize("n_bits", [2])
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("verbose", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_quantized_module_compilation(
input_output_feature,
model,
Expand All @@ -55,6 +56,7 @@ def test_quantized_module_compilation(
simulate,
check_graph_input_has_no_tlu,
verbose,
compress_evaluation_keys,
check_is_good_execution_for_cml_vs_circuit,
):
"""Test a neural network compilation for FHE inference."""
Expand Down Expand Up @@ -82,6 +84,7 @@ def test_quantized_module_compilation(
numpy_input,
default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)

check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=simulate)
Expand All @@ -104,6 +107,7 @@ def test_quantized_module_compilation(
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("verbose", [True, False])
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_quantized_cnn_compilation(
input_output_feature,
model,
Expand All @@ -112,6 +116,7 @@ def test_quantized_cnn_compilation(
check_graph_input_has_no_tlu,
simulate,
verbose,
compress_evaluation_keys,
check_is_good_execution_for_cml_vs_circuit,
):
"""Test a convolutional neural network compilation for FHE inference."""
Expand Down Expand Up @@ -142,6 +147,7 @@ def test_quantized_cnn_compilation(
numpy_input,
default_configuration,
verbose=verbose,
compress_evaluation_keys=compress_evaluation_keys,
)
check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=simulate)
check_graph_input_has_no_tlu(quantized_model.fhe_circuit.graph)
Expand Down Expand Up @@ -317,10 +323,7 @@ def test_compile_multi_input_nn_with_input_tlus(
quantized_model = post_training_quant.quantize_module(numpy_input)

# Compile
quantized_model.compile(
numpy_input,
default_configuration,
)
quantized_model.compile(numpy_input, default_configuration)
check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=True)

if can_remove_tlu:
Expand Down
14 changes: 11 additions & 3 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,21 +1492,23 @@ def test_input_support(


@pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
def test_inference_methods(
model_class,
parameters,
load_data,
is_weekly_option,
check_float_array_equal,
default_configuration,
compress_evaluation_keys,
verbose=True,
):
"""Test inference methods."""
n_bits = get_n_bits_non_correctness(model_class)

model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option)

model.compile(x, default_configuration)
model.compile(x, default_configuration, compress_evaluation_keys=compress_evaluation_keys)

if verbose:
print("Run check_inference_methods")
Expand Down Expand Up @@ -1562,6 +1564,7 @@ def test_pipeline(
< min(N_BITS_LINEAR_MODEL_CRYPTO_PARAMETERS, N_BITS_THRESHOLD_TO_FORCE_EXECUTION_NOT_IN_FHE)
],
)
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
# pylint: disable=too-many-branches
def test_predict_correctness(
model_class,
Expand All @@ -1572,6 +1575,7 @@ def test_predict_correctness(
default_configuration,
check_is_good_execution_for_cml_vs_circuit,
is_weekly_option,
compress_evaluation_keys,
verbose=True,
):
"""Test prediction correctness between clear quantized and FHE simulation or execution."""
Expand All @@ -1591,7 +1595,7 @@ def test_predict_correctness(
if verbose:
print("Compile the model")

model.compile(x, default_configuration)
model.compile(x, default_configuration, compress_evaluation_keys=compress_evaluation_keys)

if verbose:
print(f"Check prediction correctness for {fhe_samples} samples.")
Expand Down Expand Up @@ -1623,6 +1627,7 @@ def test_predict_correctness(
< min(N_BITS_LINEAR_MODEL_CRYPTO_PARAMETERS, N_BITS_THRESHOLD_TO_FORCE_EXECUTION_NOT_IN_FHE)
],
)
@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
# pylint: disable=too-many-branches
def test_separated_inference(
model_class,
Expand All @@ -1633,6 +1638,7 @@ def test_separated_inference(
default_configuration,
is_weekly_option,
check_float_array_equal,
compress_evaluation_keys,
verbose=True,
):
"""Test prediction correctness between clear quantized and FHE simulation or execution."""
Expand All @@ -1653,7 +1659,9 @@ def test_separated_inference(
if verbose:
print("Compile the model")

fhe_circuit = model.compile(x, default_configuration)
fhe_circuit = model.compile(
x, default_configuration, compress_evaluation_keys=compress_evaluation_keys
)

if verbose:
print("Run check_separated_inference")
Expand Down
Loading

0 comments on commit d3eb96e

Please sign in to comment.