From 9717ec583ea8fc14f4598d39dbdf44755ab9878b Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 22 Mar 2024 15:03:15 +0100 Subject: [PATCH] fix(frontend-python): round_bit_pattern, prevent exactness argument misuse --- .../concrete/fhe/compilation/configuration.py | 8 ++++---- .../concrete/fhe/extensions/round_bit_pattern.py | 4 ++++ .../concrete/fhe/mlir/converter.py | 2 +- .../fhe/mlir/processors/process_rounding.py | 2 +- .../tests/extensions/test_round_bit_pattern.py | 15 +++++++++++++++ 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index ba0b643240..7a49954370 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -4,7 +4,7 @@ import platform from dataclasses import dataclass -from enum import Enum, IntEnum +from enum import Enum from pathlib import Path from typing import List, Optional, Tuple, Union, get_type_hints @@ -74,13 +74,13 @@ def parse(cls, string: str) -> "MultiParameterStrategy": raise ValueError(message) -class Exactness(IntEnum): +class Exactness(Enum): """ Exactness, to specify for specific operator the implementation preference (default and local). """ - EXACT = 0 - APPROXIMATE = 1 + APPROXIMATE = "approximate" + EXACT = "exact" @dataclass diff --git a/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py b/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py index c67380831d..932d947b99 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py +++ b/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py @@ -242,6 +242,10 @@ def round_bit_pattern( lsbs_to_remove = lsbs_to_remove.lsbs_to_remove + if not isinstance(exactness, (Exactness, None.__class__)): + msg = "exactness should be of type fhe.Exactness" + raise TypeError(msg) + assert isinstance(lsbs_to_remove, int) def evaluator( diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index bb8bda4dc1..dba0dee7ff 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -760,7 +760,7 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: variable_input.origin.properties["exactness"] or ctx.configuration.rounding_exactness ) - if exactness == Exactness.APPROXIMATE: + if exactness is Exactness.APPROXIMATE: # we clip values to enforce input precision exactly as queried original_bit_width = variable_input.origin.properties["original_bit_width"] lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"] diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py index 01abaaa8f5..01461e56ea 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py @@ -40,7 +40,7 @@ def apply(self, graph: Graph): exactness = self.rounding_exactness if original_lsbs_to_remove != 0 and final_lsbs_to_remove == 0: - if exactness != Exactness.APPROXIMATE: + if exactness is not Exactness.APPROXIMATE: self.replace_with_tlu(graph, node) continue diff --git a/frontends/concrete-python/tests/extensions/test_round_bit_pattern.py b/frontends/concrete-python/tests/extensions/test_round_bit_pattern.py index 4fd8789051..d067e7f35d 100644 --- a/frontends/concrete-python/tests/extensions/test_round_bit_pattern.py +++ b/frontends/concrete-python/tests/extensions/test_round_bit_pattern.py @@ -2,6 +2,8 @@ Tests of 'round_bit_pattern' extension. """ +import pytest + from concrete import fhe @@ -35,3 +37,16 @@ def test_dump_load_auto_rounder(): assert loaded.input_max == 20 assert loaded.input_bit_width == 5 assert loaded.lsbs_to_remove == 2 + + +def test_bad_exactness(): + """ + Test for incorrect 'exactness' argument. + """ + + @fhe.compiler({"a": "encrypted"}) + def f(a): + return fhe.round_bit_pattern(a, lsbs_to_remove=1, exactness=True) + + with pytest.raises(TypeError): + f.compile([0])