Skip to content

Commit

Permalink
feat(frontend-python): relax bit-width assignment of clear values
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Mar 21, 2024
1 parent 2471b37 commit 79b72db
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 22 deletions.
12 changes: 3 additions & 9 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,9 @@ def is_bit_width_compatible(self, *args: Optional[Union[ConversionType, Conversi
args = tuple(arg.type if isinstance(arg, Conversion) else arg for arg in args)

def check(type1, type2):
return (
(type1.bit_width + 1) == type2.bit_width
if type1.is_encrypted and type2.is_clear
else (
type1.bit_width == (type2.bit_width + 1)
if type1.is_clear and type2.is_encrypted
else type1.bit_width == type2.bit_width
)
)
if type1.is_encrypted and type2.is_encrypted:
return type1.bit_width == type2.bit_width
return True

reference = args[0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,17 @@ def apply_many(self, graphs: Dict[str, Graph]):
for node, bit_width in bit_widths.items():
assert isinstance(node.output.dtype, Integer)
new_bit_width = model[bit_width].as_long()

if node.output.is_clear:
new_bit_width += 1

node.properties["original_bit_width"] = node.properties.get(
original_bit_width = node.properties.get(
"bit_width_hint",
node.output.dtype.bit_width,
)

if node.output.is_clear:
new_bit_width = original_bit_width
if not node.output.dtype.is_signed:
new_bit_width += 1

node.properties["original_bit_width"] = original_bit_width
node.output.dtype.bit_width = new_bit_width
for graph in graphs.values():
graph.bit_width_constraints = optimizer
Expand Down
16 changes: 8 additions & 8 deletions frontends/concrete-python/tests/mlir/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,8 +1421,8 @@ def test_converter_bad_convert(
%1 = "FHE.reinterpret_precision"(%0) : (!FHE.eint<8>) -> !FHE.eint<6>
%cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31]> : tensor<64xi64>
%2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<5>
%c100_i9 = arith.constant 100 : i9
%3 = "FHE.add_eint_int"(%arg0, %c100_i9) : (!FHE.eint<8>, i9) -> !FHE.eint<8>
%c100_i8 = arith.constant 100 : i8
%3 = "FHE.add_eint_int"(%arg0, %c100_i8) : (!FHE.eint<8>, i8) -> !FHE.eint<8>
return %2, %3 : !FHE.eint<5>, !FHE.eint<8>
}
}
Expand Down Expand Up @@ -1501,9 +1501,9 @@ def test_converter_bad_convert(
%1 = "FHE.reinterpret_precision"(%0) : (!FHE.esint<8>) -> !FHE.esint<6>
%cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, -16, -16, -15, -15, -14, -14, -13, -13, -12, -12, -11, -11, -10, -10, -9, -9, -8, -8, -7, -7, -6, -6, -5, -5, -4, -4, -3, -3, -2, -2, -1, -1]> : tensor<64xi64>
%2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.esint<6>, tensor<64xi64>) -> !FHE.esint<5>
%c100_i9 = arith.constant 100 : i9
%c100_i8 = arith.constant 100 : i8
%3 = "FHE.to_unsigned"(%arg0) : (!FHE.esint<8>) -> !FHE.eint<8>
%4 = "FHE.add_eint_int"(%3, %c100_i9) : (!FHE.eint<8>, i9) -> !FHE.eint<8>
%4 = "FHE.add_eint_int"(%3, %c100_i8) : (!FHE.eint<8>, i8) -> !FHE.eint<8>
return %2, %4 : !FHE.esint<5>, !FHE.eint<8>
}
}
Expand Down Expand Up @@ -1615,7 +1615,7 @@ def test_converter_convert_multi_precision(
module {
func.func @main(%arg0: tensor<2x!FHE.eint<6>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
%c2_i7 = arith.constant 2 : i7
%c2_i3 = arith.constant 2 : i3
%c16_i7 = arith.constant 16 : i7
%from_elements = tensor.from_elements %c16_i7 : tensor<1xi7>
%0 = "FHELinalg.mul_eint_int"(%arg0, %from_elements) : (tensor<2x!FHE.eint<6>>, tensor<1xi7>) -> tensor<2x!FHE.eint<6>>
Expand Down Expand Up @@ -1702,7 +1702,7 @@ def test_converter_convert_composition(function, parameters, expected_mlir, help
%0 = x # EncryptedScalar<uint4> ∈ [0, 10]
%1 = 2 # ClearScalar<uint3> ∈ [2, 2]
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
%3 = 100 # ClearScalar<uint8> ∈ [100, 100]
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
return %4
Expand Down Expand Up @@ -1742,9 +1742,9 @@ def test_converter_process_multi_precision(function, parameters, expected_graph,
"""
%0 = x # EncryptedScalar<uint8> ∈ [0, 10]
%1 = 2 # ClearScalar<uint9> ∈ [2, 2]
%1 = 2 # ClearScalar<uint3> ∈ [2, 2]
%2 = power(%0, %1) # EncryptedScalar<uint8> ∈ [0, 100]
%3 = 100 # ClearScalar<uint9> ∈ [100, 100]
%3 = 100 # ClearScalar<uint8> ∈ [100, 100]
%4 = add(%2, %3) # EncryptedScalar<uint8> ∈ [100, 200]
return %4
Expand Down

0 comments on commit 79b72db

Please sign in to comment.