From e6d3f353c1827a8a982a3a3128adc8733a9be819 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jordan=20Fr=C3=A9ry?= Date: Fri, 26 Jan 2024 13:18:16 +0100 Subject: [PATCH] chore: refacto quantized module init + handle no input onnx --- pyproject.toml | 2 + src/concrete/ml/onnx/convert.py | 10 +++ src/concrete/ml/onnx/ops_impl.py | 13 ++-- .../ml/quantization/quantized_module.py | 51 +++++++++------ tests/quantization/test_quantized_module.py | 15 +++++ tests/torch/test_compile_torch.py | 64 +++++++++++++++++++ 6 files changed, 133 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8044a21db..4e9f97793 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,6 +143,8 @@ filterwarnings = [ "ignore:non-integer arguments to randrange\\(\\) have been deprecated since Python 3\\.10 and will be removed in a subsequent version:DeprecationWarning", "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", "ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning", + "ignore:Converting a tensor to a NumPy array might cause the trace to be incorrect.", + "ignore:torch.from_numpy results are registered as constants in the trace.", ] [tool.semantic_release] diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py index 96a524d7d..1aacbcd65 100644 --- a/src/concrete/ml/onnx/convert.py +++ b/src/concrete/ml/onnx/convert.py @@ -12,6 +12,7 @@ import torch from onnx import checker, helper +from ..common.debugging import assert_true from .onnx_utils import ( IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, @@ -149,6 +150,15 @@ def get_equivalent_numpy_forward_from_torch( input_names=arguments, ) equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path)) + + # Check if the inputs are present in the model's graph + for input_name in arguments: + assert_true( + any(input_name == node.name for node in equivalent_onnx_model.graph.input), + f"Input '{input_name}' 
is missing in the ONNX graph after export. " + "Verify the forward pass for issues.", + ) + # Remove the tempfile if we used one if use_tempfile: output_onnx_file_path.unlink() diff --git a/src/concrete/ml/onnx/ops_impl.py b/src/concrete/ml/onnx/ops_impl.py index 1b13f8ef7..7cae4eb27 100644 --- a/src/concrete/ml/onnx/ops_impl.py +++ b/src/concrete/ml/onnx/ops_impl.py @@ -1479,9 +1479,9 @@ def numpy_pad( def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]: """Execute ONNX cast in Numpy. - For traced values during compilation, it supports only booleans, which are converted to float. - For raw values (used in constant folding or shape computations), any cast is allowed. - + This function supports casting to booleans, floats, and double for traced values, + converting them accordingly. For raw values (used in constant folding or shape computations), + any cast is allowed. See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast Args: @@ -1495,7 +1495,12 @@ def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]: if isinstance(data, RawOpOutput): return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),) - assert_true(to == onnx.TensorProto.BOOL) + allowed_types = (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT, onnx.TensorProto.DOUBLE) + assert to in allowed_types, ( + f"Invalid 'to' data type: {onnx.TensorProto.DataType.Name(to)}. " + f"Only {', '.join(onnx.TensorProto.DataType.Name(t) for t in allowed_types)}" + " are allowed for casting." 
+ ) # Will be used for traced values return (data.astype(numpy.float64),) diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py index 124bf151e..74bf8712e 100644 --- a/src/concrete/ml/quantization/quantized_module.py +++ b/src/concrete/ml/quantization/quantized_module.py @@ -96,31 +96,46 @@ def __init__( quant_layers_dict: Optional[Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]] = None, onnx_model: Optional[onnx.ModelProto] = None, ): + + all_or_none_params = [ + ordered_module_input_names, + ordered_module_output_names, + quant_layers_dict, + ] + if not ( + all(v is None or v == {} for v in all_or_none_params) + or not any(v is None or v == {} for v in all_or_none_params) + ): + raise ValueError( + "Please either set all three 'ordered_module_input_names', " + "'ordered_module_output_names' and 'quant_layers_dict' or none of them." + ) + + self.ordered_module_input_names = ( + tuple(ordered_module_input_names) if ordered_module_input_names else () + ) + self.ordered_module_output_names = ( + tuple(ordered_module_output_names) if ordered_module_output_names else () + ) + self.quant_layers_dict = ( + copy.deepcopy(quant_layers_dict) if quant_layers_dict is not None else {} + ) + # Set base attributes for API consistency. 
This could be avoided if an abstract base class # is created for both Concrete ML models and QuantizedModule # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899 - self.fhe_circuit = None + self.input_quantizers: List[UniformQuantizer] = [] + self.output_quantizers: List[UniformQuantizer] = [] + self.fhe_circuit: Optional[Circuit] = None self._is_compiled = False - self.input_quantizers = [] - self.output_quantizers = [] self._onnx_model = onnx_model self._post_processing_params: Dict[str, Any] = {} - # If any of the arguments are not provided, skip the init - if not all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict]): - return - - # for mypy - assert isinstance(ordered_module_input_names, Iterable) - assert isinstance(ordered_module_output_names, Iterable) - assert all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict]) - self.ordered_module_input_names = tuple(ordered_module_input_names) - self.ordered_module_output_names = tuple(ordered_module_output_names) - - assert quant_layers_dict is not None - self.quant_layers_dict = copy.deepcopy(quant_layers_dict) - - self.output_quantizers = self._set_output_quantizers() + # Initialize output quantizers based on quant_layers_dict + if self.quant_layers_dict: + self.output_quantizers = self._set_output_quantizers() + else: + self.output_quantizers = [] # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127 def set_reduce_sum_copy(self): diff --git a/tests/quantization/test_quantized_module.py b/tests/quantization/test_quantized_module.py index 0ef20853b..75f65ca2a 100644 --- a/tests/quantization/test_quantized_module.py +++ b/tests/quantization/test_quantized_module.py @@ -496,3 +496,18 @@ def test_serialization(model_class, input_shape): QuantizedModule, equal_method=partial(quantized_module_predictions_are_equal, x=numpy_input), ) + + +def test_quantized_module_initialization_error(): + """Test initialization fails with mismatched 
parameters.""" + # Initialize with invalid parameters + with pytest.raises( + ValueError, + match=r"Please either set all three 'ordered_module_input_names', " + r"'ordered_module_output_names' and 'quant_layers_dict' or none of them.", + ): + QuantizedModule( + ordered_module_input_names=["input1", "input2"], + ordered_module_output_names=None, # This makes the combination invalid + quant_layers_dict={"layer1": (["input1"], "QuantizedOp")}, + ) diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index d5be6c0d1..5dd5d0a9e 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -1373,3 +1373,67 @@ def test_mono_parameter_rounding_warning( verbose=False, get_and_compile=False, ) + + +@pytest.mark.parametrize( + "cast_type, should_fail, error_message", + [ + (torch.bool, False, None), + (torch.float32, False, None), + (torch.float64, False, None), + (torch.int64, True, r"Invalid 'to' data type: INT64"), + ], +) +def test_compile_torch_model_with_cast(cast_type, should_fail, error_message): + """Test compiling a Torch model with various casts, expecting failure for invalid types.""" + torch_input = torch.randn(100, 28) + + class CastNet(nn.Module): + """Network with cast.""" + + def __init__(self, cast_to): + super().__init__() + self.threshold = torch.tensor(0.5, dtype=torch.float32) + self.cast_to = cast_to + + def forward(self, x): + """Forward pass with dynamic cast.""" + zeros = torch.zeros_like(x) + x = x + zeros + x = (x > self.threshold).to(self.cast_to) + return x + + model = CastNet(cast_type) + + if should_fail: + with pytest.raises(AssertionError, match=error_message): + compile_torch_model(model, torch_input, cast_type, rounding_threshold_bits=3) + else: + compile_torch_model(model, torch_input, cast_type, rounding_threshold_bits=3) + + +def test_onnx_no_input(): + """Test a torch model that has no input when converted to onnx.""" + + torch_input = torch.randn(100, 28) + + class 
NoInputNet(nn.Module): + """Network with no input in the onnx graph.""" + + def __init__(self): + super().__init__() + self.threshold = torch.tensor(0.5, dtype=torch.float32) + + def forward(self, x): + """Forward pass.""" + zeros = numpy.zeros_like(x) + x = x + zeros + x = (x > self.threshold).to(torch.float32) + return x + + model = NoInputNet() + + with pytest.raises( + AssertionError, match="Input 'x' is missing in the ONNX graph after export." + ): + compile_torch_model(model, torch_input, rounding_threshold_bits=3)