From 0740717770315b65e9416306a3b7054e1eb6d134 Mon Sep 17 00:00:00 2001
From: jfrery
Date: Wed, 24 Jan 2024 12:02:04 +0100
Subject: [PATCH] chore: review have a ValueError instead of assert_true in
 qmodule init

---
 src/concrete/ml/onnx/convert.py             |  3 +-
 src/concrete/ml/onnx/ops_impl.py            | 13 +++--
 .../ml/quantization/quantized_module.py     | 22 +++++----
 tests/quantization/test_quantized_module.py | 15 ++++++
 tests/torch/test_compile_torch.py           | 49 ++++++++++++++++---
 5 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py
index 2b076f26f..1aacbcd65 100644
--- a/src/concrete/ml/onnx/convert.py
+++ b/src/concrete/ml/onnx/convert.py
@@ -155,7 +155,8 @@ def get_equivalent_numpy_forward_from_torch(
     for input_name in arguments:
         assert_true(
             any(input_name == node.name for node in equivalent_onnx_model.graph.input),
-            f"Input '{input_name}' is not present in the ONNX model. Please check the onnx graph.",
+            f"Input '{input_name}' is missing in the ONNX graph after export. "
+            "Verify the forward pass for issues.",
         )

     # Remove the tempfile if we used one
diff --git a/src/concrete/ml/onnx/ops_impl.py b/src/concrete/ml/onnx/ops_impl.py
index fa4191b75..7cae4eb27 100644
--- a/src/concrete/ml/onnx/ops_impl.py
+++ b/src/concrete/ml/onnx/ops_impl.py
@@ -1479,10 +1479,9 @@ def numpy_pad(
 def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     """Execute ONNX cast in Numpy.

-    For traced values during compilation, it supports booleans and floats,
-    which are converted to float. For raw values (used in constant folding or shape computations),
+    For traced values during compilation, booleans, floats and doubles are supported and
+    all cast to float64. For raw values (used in constant folding or shape computations),
     any cast is allowed.
-
     See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast

     Args:
@@ -1496,8 +1495,12 @@ def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     if isinstance(data, RawOpOutput):
         return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),)

-    # Allow both bool and float types
-    assert_true(to in (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT))
+    allowed_types = (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT, onnx.TensorProto.DOUBLE)
+    assert to in allowed_types, (
+        f"Invalid 'to' data type: {onnx.TensorProto.DataType.Name(to)}. "
+        f"Only {', '.join(onnx.TensorProto.DataType.Name(t) for t in allowed_types)} "
+        "are allowed for casting."
+    )

     # Will be used for traced values
     return (data.astype(numpy.float64),)
diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py
index f3bcef1b1..74bf8712e 100644
--- a/src/concrete/ml/quantization/quantized_module.py
+++ b/src/concrete/ml/quantization/quantized_module.py
@@ -96,21 +96,21 @@ def __init__(
         quant_layers_dict: Optional[Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]] = None,
         onnx_model: Optional[onnx.ModelProto] = None,
     ):
-        # Set base attributes for API consistency. This could be avoided if an abstract base class
-        # is created for both Concrete ML models and QuantizedModule
-        # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899
         all_or_none_params = [
             ordered_module_input_names,
             ordered_module_output_names,
             quant_layers_dict,
         ]
-        assert_true(
+        if not (
             all(v is None or v == {} for v in all_or_none_params)
-            or not any(v is None or v == {} for v in all_or_none_params),
-            "All of ordered_module_input_names, ordered_module_output_names, "
-            "and quant_layers_dict must be provided if any one of them is provided.",
-        )
+            or not any(v is None or v == {} for v in all_or_none_params)
+        ):
+            raise ValueError(
+                "Please either set all three 'ordered_module_input_names', "
+                "'ordered_module_output_names' and 'quant_layers_dict' or none of them."
+            )
+
         self.ordered_module_input_names = (
             tuple(ordered_module_input_names) if ordered_module_input_names else ()
         )
@@ -120,9 +120,13 @@ def __init__(
         self.quant_layers_dict = (
             copy.deepcopy(quant_layers_dict) if quant_layers_dict is not None else {}
         )
+
+        # Set base attributes for API consistency. This could be avoided if an abstract base class
+        # is created for both Concrete ML models and QuantizedModule
+        # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899
         self.input_quantizers: List[UniformQuantizer] = []
         self.output_quantizers: List[UniformQuantizer] = []
-        self.fhe_circuit: Union[None, Circuit] = None
+        self.fhe_circuit: Optional[Circuit] = None
         self._is_compiled = False
         self._onnx_model = onnx_model
         self._post_processing_params: Dict[str, Any] = {}
diff --git a/tests/quantization/test_quantized_module.py b/tests/quantization/test_quantized_module.py
index 0ef20853b..75f65ca2a 100644
--- a/tests/quantization/test_quantized_module.py
+++ b/tests/quantization/test_quantized_module.py
@@ -496,3 +496,18 @@ def test_serialization(model_class, input_shape):
         QuantizedModule,
         equal_method=partial(quantized_module_predictions_are_equal, x=numpy_input),
     )
+
+
+def test_quantized_module_initialization_error():
+    """Test initialization fails with mismatched parameters."""
+    # Initialize with invalid parameters
+    with pytest.raises(
+        ValueError,
+        match=r"Please either set all three 'ordered_module_input_names', "
+        r"'ordered_module_output_names' and 'quant_layers_dict' or none of them.",
+    ):
+        QuantizedModule(
+            ordered_module_input_names=["input1", "input2"],
+            ordered_module_output_names=None,  # This makes the combination invalid
+            quant_layers_dict={"layer1": (["input1"], "QuantizedOp")},
+        )
diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py
index 57f3d66c8..5dd5d0a9e 100644
--- a/tests/torch/test_compile_torch.py
+++ b/tests/torch/test_compile_torch.py
@@ -1375,17 +1375,53 @@ def test_mono_parameter_rounding_warning(
     )


+@pytest.mark.parametrize(
+    "cast_type, should_fail, error_message",
+    [
+        (torch.bool, False, None),
+        (torch.float32, False, None),
+        (torch.float64, False, None),
+        (torch.int64, True, r"Invalid 'to' data type: INT64"),
+    ],
+)
+def test_compile_torch_model_with_cast(cast_type, should_fail, error_message):
+    """Test compiling a Torch model with various casts, expecting failure for invalid types."""
+    torch_input = torch.randn(100, 28)
+
+    class CastNet(nn.Module):
+        """Network with cast."""
+
+        def __init__(self, cast_to):
+            super().__init__()
+            self.threshold = torch.tensor(0.5, dtype=torch.float32)
+            self.cast_to = cast_to
+
+        def forward(self, x):
+            """Forward pass with dynamic cast."""
+            zeros = torch.zeros_like(x)
+            x = x + zeros
+            x = (x > self.threshold).to(self.cast_to)
+            return x
+
+    model = CastNet(cast_type)
+
+    if should_fail:
+        with pytest.raises(AssertionError, match=error_message):
+            compile_torch_model(model, torch_input, rounding_threshold_bits=3)
+    else:
+        compile_torch_model(model, torch_input, rounding_threshold_bits=3)
+
+
 def test_onnx_no_input():
     """Test a torch model that has no input when converted to onnx."""
     torch_input = torch.randn(100, 28)

-    class SimplifiedNet(nn.Module):
+    class NoInputNet(nn.Module):
         """Network with no input in the onnx graph."""

         def __init__(self):
             super().__init__()
-            self.fc1 = nn.Linear(28, 10)
             self.threshold = torch.tensor(0.5, dtype=torch.float32)

         def forward(self, x):
@@ -1393,12 +1429,11 @@ def forward(self, x):
             zeros = numpy.zeros_like(x)
             x = x + zeros
             x = (x > self.threshold).to(torch.float32)
-            x = self.fc1(x)
             return x

-    model = SimplifiedNet()
+    model = NoInputNet()

-    with pytest.raises(AssertionError) as excinfo:
+    with pytest.raises(
+        AssertionError, match="Input 'x' is missing in the ONNX graph after export."
+    ):
         compile_torch_model(model, torch_input, rounding_threshold_bits=3)
-
-    assert "Input 'x' is not present in the ONNX model" in str(excinfo.value)
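
Reviewer note, not part of the diff: a minimal sketch of the constructor contract this patch introduces, assuming the patched quantized_module.py; the argument value "input1" is hypothetical.

    from concrete.ml.quantization.quantized_module import QuantizedModule

    # Providing none of the three structural arguments remains valid.
    module = QuantizedModule()

    # Providing only a subset now raises a ValueError (previously assert_true
    # raised an AssertionError), so callers can catch it explicitly.
    try:
        QuantizedModule(ordered_module_input_names=["input1"])
    except ValueError as error:
        print(error)  # Please either set all three 'ordered_module_input_names', ...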
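
Likewise, a small sketch of the widened numpy_cast contract, assuming the patched ops_impl.py; the input arrays are illustrative.

    import numpy
    import onnx

    from concrete.ml.onnx.ops_impl import numpy_cast

    # DOUBLE is now an accepted target; traced values come back as float64.
    (result,) = numpy_cast(numpy.array([0.1, 0.9]), to=onnx.TensorProto.DOUBLE)
    assert result.dtype == numpy.float64

    # Unsupported targets fail with the new descriptive message.
    try:
        numpy_cast(numpy.array([1.0, 2.0]), to=onnx.TensorProto.INT64)
    except AssertionError as error:
        print(error)  # Invalid 'to' data type: INT64. Only BOOL, FLOAT, DOUBLE are allowed for casting.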