Skip to content

Commit

Permalink
chore: review
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Jan 24, 2024
1 parent ec2b6f5 commit 007afef
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
13 changes: 8 additions & 5 deletions src/concrete/ml/onnx/ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,10 +1479,9 @@ def numpy_pad(
def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
"""Execute ONNX cast in Numpy.
For traced values during compilation, it supports booleans and floats,
which are converted to float. For raw values (used in constant folding or shape computations),
This function supports casting to booleans, floats, and double for traced values,
converting them accordingly. For raw values (used in constant folding or shape computations),
any cast is allowed.
See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
Args:
Expand All @@ -1496,8 +1495,12 @@ def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
if isinstance(data, RawOpOutput):
return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),)

# Allow both bool and float types
assert_true(to in (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT))
allowed_types = (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT, onnx.TensorProto.DOUBLE)
assert to in allowed_types, (
f"Invalid 'to' data type: {onnx.TensorProto.DataType.Name(to)}. "
f"Only {', '.join(onnx.TensorProto.DataType.Name(t) for t in allowed_types)}"
"are allowed for casting."
)

# Will be used for traced values
return (data.astype(numpy.float64),)
Expand Down
13 changes: 7 additions & 6 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ def __init__(
quant_layers_dict: Optional[Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]] = None,
onnx_model: Optional[onnx.ModelProto] = None,
):
# Set base attributes for API consistency. This could be avoided if an abstract base class
# is created for both Concrete ML models and QuantizedModule
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899

all_or_none_params = [
ordered_module_input_names,
Expand All @@ -108,8 +105,8 @@ def __init__(
assert_true(
all(v is None or v == {} for v in all_or_none_params)
or not any(v is None or v == {} for v in all_or_none_params),
"All of ordered_module_input_names, ordered_module_output_names, "
"and quant_layers_dict must be provided if any one of them is provided.",
"Please either set all three 'ordered_module_input_names', "
"'ordered_module_output_names' and 'quant_layers_dict' or none of them.",
)
self.ordered_module_input_names = (
tuple(ordered_module_input_names) if ordered_module_input_names else ()
Expand All @@ -120,9 +117,13 @@ def __init__(
self.quant_layers_dict = (
copy.deepcopy(quant_layers_dict) if quant_layers_dict is not None else {}
)

# Set base attributes for API consistency. This could be avoided if an abstract base class
# is created for both Concrete ML models and QuantizedModule
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899
self.input_quantizers: List[UniformQuantizer] = []
self.output_quantizers: List[UniformQuantizer] = []
self.fhe_circuit: Union[None, Circuit] = None
self.fhe_circuit: Optional[Circuit] = None
self._is_compiled = False
self._onnx_model = onnx_model
self._post_processing_params: Dict[str, Any] = {}
Expand Down
47 changes: 40 additions & 7 deletions tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,30 +1375,63 @@ def test_mono_parameter_rounding_warning(
)


@pytest.mark.parametrize(
"cast_type, should_fail, error_message",
[
(torch.bool, False, None),
(torch.float32, False, None),
(torch.float64, False, None),
(torch.int64, True, r"Invalid 'to' data type: INT64"),
],
)
def test_compile_torch_model_with_cast(cast_type, should_fail, error_message):
"""Test compiling a Torch model with various casts, expecting failure for invalid types."""
torch_input = torch.randn(100, 28)

class CastNet(nn.Module):
"""Network with cast."""

def __init__(self, cast_to):
super().__init__()
self.threshold = torch.tensor(0.5, dtype=torch.float32)
self.cast_to = cast_to

def forward(self, x):
"""Forward pass with dynamic cast."""
zeros = torch.zeros_like(x)
x = x + zeros
x = (x > self.threshold).to(self.cast_to)
return x

model = CastNet(cast_type)

if should_fail:
with pytest.raises(AssertionError, match=error_message):
compile_torch_model(model, torch_input, cast_type, rounding_threshold_bits=3)
else:
compile_torch_model(model, torch_input, cast_type, rounding_threshold_bits=3)


def test_onnx_no_input():
"""Test a torch model that has no input when converted to onnx."""

torch_input = torch.randn(100, 28)

class SimplifiedNet(nn.Module):
class NoInputNet(nn.Module):
"""Network with no input in the onnx graph."""

def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28, 10)
self.threshold = torch.tensor(0.5, dtype=torch.float32)

def forward(self, x):
"""Forward pass."""
zeros = numpy.zeros_like(x)
x = x + zeros
x = (x > self.threshold).to(torch.float32)
x = self.fc1(x)
return x

model = SimplifiedNet()
model = NoInputNet()

with pytest.raises(AssertionError) as excinfo:
with pytest.raises(AssertionError, match="Input 'x' is not present in the ONNX model"):
compile_torch_model(model, torch_input, rounding_threshold_bits=3)

assert "Input 'x' is not present in the ONNX model" in str(excinfo.value)

0 comments on commit 007afef

Please sign in to comment.