chore: use cp vl when circuit is CRT based (#619)
jfrery authored Apr 17, 2024
1 parent 678c866 commit cf6f224
Showing 3 changed files with 58 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/concrete/ml/quantization/quantized_module.py
@@ -522,9 +522,13 @@ def _fhe_forward(

# If the inference should be executed using simulation
if simulate:
is_crt_encoding = self.fhe_circuit.statistics["packing_key_switch_count"] != 0

# If the old simulation method should be used
if USE_OLD_VL:
# If the virtual library method should be used
# For now, use the virtual library when simulating
# circuits that use CRT encoding because the official simulation is too slow
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4391
if USE_OLD_VL or is_crt_encoding:
predict_method = partial(
self.fhe_circuit.graph, p_error=self.fhe_circuit.p_error
) # pragma: no cover
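
The branch above selects the simulation back-end from the compiled circuit's statistics: a non-zero packing key switch count indicates a CRT-based circuit, which is simulated with the graph-based virtual library instead of the official simulation. An illustrative sketch (not part of the commit) of the same dispatch, assuming fhe_circuit is a Concrete Python Circuit and select_simulation_method is a hypothetical helper:

from functools import partial

def select_simulation_method(fhe_circuit, use_old_vl):
    """Return a callable that simulates execution of the compiled circuit."""
    # A non-zero packing key switch count indicates a CRT-based circuit
    is_crt_encoding = fhe_circuit.statistics["packing_key_switch_count"] != 0

    if use_old_vl or is_crt_encoding:
        # Graph-based virtual library simulation, used while the official
        # simulation remains too slow for CRT-encoded circuits
        return partial(fhe_circuit.graph, p_error=fhe_circuit.p_error)

    # Default: the compiled circuit's own simulation entry point
    return fhe_circuit.simulate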
8 changes: 6 additions & 2 deletions src/concrete/ml/sklearn/base.py
@@ -645,9 +645,13 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.

# If the inference should be executed using simulation
if fhe == "simulate":
is_crt_encoding = self.fhe_circuit.statistics["packing_key_switch_count"] != 0

# If the old simulation method should be used
if USE_OLD_VL:
# If the virtual library method should be used
# For now, use the virtual library when simulating
# circuits that use CRT encoding because the official simulation is too slow
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4391
if USE_OLD_VL or is_crt_encoding:
predict_method = partial(
self.fhe_circuit.graph, p_error=self.fhe_circuit.p_error
) # pragma: no cover
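
The same fallback applies when a built-in model's predict method runs with fhe="simulate". A minimal usage sketch, assuming the standard Concrete ML scikit-learn workflow with synthetic data (model choice and shapes are illustrative):

import numpy
from concrete.ml.sklearn import LogisticRegression

X = numpy.random.uniform(size=(100, 10))
y = (X.sum(axis=1) > 5).astype(numpy.int64)

model = LogisticRegression(n_bits=8)
model.fit(X, y)

# Compilation produces the fhe_circuit whose statistics are inspected above
model.compile(X)

# Simulated inference; a CRT-based circuit falls back to the virtual library
y_pred = model.predict(X, fhe="simulate")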
46 changes: 46 additions & 0 deletions tests/quantization/test_quantized_module.py
@@ -512,3 +512,49 @@ def test_quantized_module_initialization_error():
ordered_module_output_names=None, # This makes the combination invalid
quant_layers_dict={"layer1": (["input1"], "QuantizedOp")},
)


@pytest.mark.parametrize("model_class, input_shape", [pytest.param(FC, (100, 32 * 32 * 3))])
def test_crt_circuit_creation_with_rounding(model_class, input_shape, default_configuration):
"""Test the creation of CRT and non-CRT circuits based on rounding settings."""

torch_fc_model = model_class(activation_function=nn.ReLU)
torch_fc_model.eval()

# Create random input
numpy_input = numpy.random.uniform(size=input_shape)
torch_input = torch.from_numpy(numpy_input).float()

# Compile with rounding_threshold_bits set to 4
quantized_model_with_rounding = compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=6,
p_error=0.01,
rounding_threshold_bits=4,
)

# Check that the packing key switch count is 0,
# meaning the circuit does not use CRT encoding
assert (
quantized_model_with_rounding.fhe_circuit.statistics.get("packing_key_switch_count", 0) == 0
), "Packing key switch count should be 0 when rounding_threshold_bits is set."

# Compile with rounding_threshold_bits = None,
# which should create a CRT-based encoding circuit
quantized_model_without_rounding = compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=4,
rounding_threshold_bits=None,
)

# Check that the packing key switch count is greater than 0,
# meaning the circuit uses CRT encoding
assert (
quantized_model_without_rounding.fhe_circuit.statistics.get("packing_key_switch_count", 0)
> 0
), "Packing key switch count should be > 0 when rounding_threshold_bits is not set."
