Skip to content

Commit

Permalink
chore: add check proper keyword for rounding_threshold_bits (#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored Apr 19, 2024
1 parent 19342e9 commit c6917a3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/concrete/ml/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,7 @@ def process_rounding_threshold_bits(rounding_threshold_bits):
Raises:
NotImplementedError: If 'auto' rounding is specified but not implemented.
ValueError: If an invalid type or value is provided for rounding_threshold_bits.
KeyError: If the dict contains keys other than 'n_bits' and 'method'.
"""
n_bits_rounding: Union[None, str, int] = None
method: Exactness = Exactness.EXACT
Expand All @@ -632,9 +633,16 @@ def process_rounding_threshold_bits(rounding_threshold_bits):
if isinstance(rounding_threshold_bits, int):
n_bits_rounding = rounding_threshold_bits
elif isinstance(rounding_threshold_bits, dict):
valid_keys = {"n_bits", "method"}
if not valid_keys.issuperset(rounding_threshold_bits.keys()):
raise KeyError(
f"Invalid keys in rounding_threshold_bits. Allowed keys are {valid_keys}."
)
n_bits_rounding = rounding_threshold_bits.get("n_bits")
if n_bits_rounding == "auto":
raise NotImplementedError("Automatic rounding is not implemented yet.")
if not isinstance(n_bits_rounding, int):
raise ValueError("n_bits must be an integer.")
method = rounding_threshold_bits.get("method", method)
if not isinstance(method, Exactness):
method_str = method.upper()
Expand Down
12 changes: 11 additions & 1 deletion tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ def forward(self, x):
"Invalid type for rounding_threshold_bits. Must be int or dict.",
),
(
{"method": "INVALID_METHOD"},
{"n_bits": 4, "method": "INVALID_METHOD"},
ValueError,
"INVALID_METHOD is not a valid method. Must be one of EXACT, APPROXIMATE.",
),
Expand All @@ -1420,6 +1420,16 @@ def forward(self, x):
ValueError,
"n_bits_rounding must be between 2 and 8 inclusive",
),
(
{"invalid_key": 4},
KeyError,
"Invalid keys in rounding_threshold_bits. Allowed keys are {'n_bits', 'method'}.",
),
(
{"n_bits": "not_an_int"},
ValueError,
"n_bits must be an integer.",
),
],
)
def test_compile_torch_model_rounding_threshold_bits_errors(
Expand Down

0 comments on commit c6917a3

Please sign in to comment.