Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(frontend): support higher bitwidth computation when using TFHE-rs #1136

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,9 +985,20 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) ->
] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1))
padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table)
# set padding bits (where necessary) in the final result
return ctx.add(result_type, sum_result, padding_bits_inc)

return sum_result
result = ctx.add(result_type, sum_result, padding_bits_inc)
else:
result = sum_result

# even if TFHE-rs value are using non-variable bit-width, we want the output
# to be pluggable into the rest of the computation. For example, two 8bits TFHE-rs integers
# could be used in a 9bits addition. If we don't cast, it won't pass the bitwidth
# compatibility check.
output_bit_width = ctx.typeof(node).bit_width
casted_result_type = ctx.tensor(
ctx.esint(output_bit_width) if dtype.is_signed else ctx.eint(output_bit_width),
result_shape,
)
return ctx.cast(casted_result_type, result)

def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
Expand Down
10 changes: 10 additions & 0 deletions frontends/concrete-python/tests/execution/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ def lut_add_lut(x, y):
TFHERS_UINT_8_3_2_4096,
id="x + y",
),
# make sure Concrete ciphertexts can use more than 8 bits
pytest.param(
lambda x, y: (x + y) % 213,
{
"x": {"range": [128, 255], "status": "encrypted"},
"y": {"range": [128, 255], "status": "encrypted"},
},
TFHERS_UINT_8_3_2_4096,
id="mod(x + y)",
),
pytest.param(
lambda x, y: x + y,
{
Expand Down
Loading