Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Apr 26, 2024
1 parent cc30c55 commit c8e6a7c
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 2 deletions.
2 changes: 1 addition & 1 deletion frontends/concrete-python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pytest: pytest-default
pytest-default:
export LD_PRELOAD=$(RUNTIME_LIBRARY)
export PYTHONPATH=$(BINDINGS_DIRECTORY)
pytest tests -svv -n auto \
pytest tests -svv -n 1 \
--cov=concrete.fhe \
--cov-fail-under=95 \
--cov-report=term-missing:skip-covered \
Expand Down
7 changes: 7 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..dtypes import Integer
from ..extensions.bits import MAX_EXTRACTABLE_BIT, MIN_EXTRACTABLE_BIT
from ..representation import Graph, GraphProcessor, Node
from ..tfhers.dtypes import TFHERSIntegerType
from ..values import ValueDescription
from .conversion import Conversion, ConversionType
from .utils import MAXIMUM_TLU_BIT_WIDTH, Comparison, _FromElementsOp
Expand Down Expand Up @@ -116,6 +117,12 @@ def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType:
assert isinstance(value.dtype, Integer)
bit_width = value.dtype.bit_width

# TODO: compute the real shape (and what about the element type? only unsigned? or not eint at all?)
if isinstance(value.dtype, TFHERSIntegerType):
pad_width = value.dtype.pad_width
msg_width = value.dtype.msg_width
return self.eint(value.dtype.bit_width)

if value.is_clear:
result = self.i(bit_width)
elif value.dtype.is_signed:
Expand Down
20 changes: 20 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,4 +814,24 @@ def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion
assert len(preds) == 0
return ctx.zeros(ctx.typeof(node))

def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
pad_width, msg_width = (
node.properties["kwargs"]["pad_width"],
node.properties["kwargs"]["msg_width"],
)
# PBS values to tfhers integer bitwidth
# width = tfhers_int.bit_width
# Dot with some powers of two to encode in a single value
return preds[0]

def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
pad_width, msg_width = (
node.properties["kwargs"]["pad_width"],
node.properties["kwargs"]["msg_width"],
)
# extract bits and put them in a tensor of ct based on crypto params
return preds[0]

# pylint: enable=missing-function-docstring,unused-argument
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ def min_max(self, node: Node, preds: List[Node]):
inputs_and_output_share_precision,
}

tfhers_to_native = {
inputs_and_output_share_precision,
}

array = {
inputs_and_output_share_precision,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,12 @@ def generic_error_message() -> str:
# if it fails we raise the exception below
pass

if not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
# it's only normal to have them in input nodes
from ..tfhers.values import TFHERSInteger

if self.operation == Operation.Input and isinstance(result, TFHERSInteger):
pass
elif not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
message = (
f"{generic_error_message()} resulted in {repr(result)} "
f"of type {result.__class__.__name__} "
Expand Down
3 changes: 3 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dtypes import TFHERSIntegerType, int8, int16, uint8, uint16
from .tracing import from_native, to_native
from .values import int8_2_2_value, int16_2_2_value, uint8_2_2_value, uint16_2_2_value
43 changes: 43 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Declaration of `TFHERSIntegerType` class.
"""

from functools import partial
from typing import Any

from ..dtypes import Integer


class TFHERSIntegerType(Integer):
pad_width: int
msg_width: int

def __init__(self, is_signed: bool, bit_width: int, pad_width: int, msg_width: int):
super().__init__(is_signed, bit_width)
self.pad_width = pad_width
self.msg_width = msg_width

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and super().__eq__(other)
and self.pad_width == other.pad_width
and self.msg_width == other.msg_width
)

def __str__(self) -> str:
return (
f"tfhers_{('int' if self.is_signed else 'uint')}"
f"{self.bit_width}_{self.pad_width}_{self.msg_width}"
)


int8 = partial(TFHERSIntegerType, True, 8)
uint8 = partial(TFHERSIntegerType, False, 8)
int16 = partial(TFHERSIntegerType, True, 16)
uint16 = partial(TFHERSIntegerType, False, 16)

int8_2_2 = int8(2, 2)
uint8_2_2 = uint8(2, 2)
int16_2_2 = int16(2, 2)
uint16_2_2 = uint16(2, 2)
84 changes: 84 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from copy import deepcopy
from typing import Union

from ..dtypes import Integer
from ..representation import Node
from ..tracing import Tracer
from ..values import EncryptedScalar, EncryptedTensor
from .dtypes import TFHERSIntegerType, uint16
from .values import TFHERSInteger


def to_native(value: Union[Tracer, TFHERSInteger]):
if isinstance(value, Tracer):
dtype = value.output.dtype
if not isinstance(dtype, TFHERSIntegerType):
raise TypeError("tracer didn't contain an output of TFHEInteger type. Type is:", dtype)
return _trace_to_native(value, dtype)
assert isinstance(value, TFHERSInteger)
dtype = value.dtype
return _eval_to_native(value, dtype.pad_width, dtype.msg_width)


def from_native(value, dtype_to: TFHERSIntegerType):
if isinstance(value, Tracer):
return _trace_from_native(value, dtype_to)
assert isinstance(value, TFHERSInteger)
dtype = value.dtype
return _eval_from_native(value, dtype.pad_width, dtype.msg_width)


def _trace_to_native(tfhers_int: Tracer, dtype: TFHERSIntegerType):
# TODO: compute the right descriptor
output = EncryptedScalar(Integer(dtype.is_signed, dtype.bit_width))

computation = Node.generic(
"tfhers_to_native",
deepcopy(
[
tfhers_int.output,
]
),
output,
_eval_to_native,
args=(),
kwargs={"pad_width": dtype.pad_width, "msg_width": dtype.msg_width},
)
return Tracer(
computation,
input_tracers=[
tfhers_int,
],
)


def _trace_from_native(native_int: Tracer, dtype_to: TFHERSIntegerType):
# TODO: compute the right descriptor
output = EncryptedScalar(dtype_to)

computation = Node.generic(
"tfhers_from_native",
deepcopy(
[
native_int.output,
]
),
output,
_eval_from_native,
args=(),
kwargs={"pad_width": dtype_to.pad_width, "msg_width": dtype_to.msg_width},
)
return Tracer(
computation,
input_tracers=[
native_int,
],
)


def _eval_to_native(tfhers_int: TFHERSInteger, pad_width: int, msg_width: int):
return tfhers_int.value


def _eval_from_native(native_value, pad_width: int, msg_width: int):
return native_value
63 changes: 63 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from functools import partial
from typing import Union

import numpy as np

from .dtypes import TFHERSIntegerType, int8_2_2, int16_2_2, uint8_2_2, uint16_2_2


# TODO: maybe merge the integer and the type with the type having an optional param value
class TFHERSInteger:
_value: Union[int, np.ndarray]
_dtype: TFHERSIntegerType
_shape: tuple

# TODO: the value might need access to crypto parameters to compute the type and shape
def __init__(
self,
dtype: TFHERSIntegerType,
value: Union[int, np.ndarray],
):
if isinstance(value, int):
self._shape = ()
elif isinstance(value, np.ndarray):
if value.max() > dtype.max():
raise ValueError(
"ndarray value has bigger elements than what the dtype can support"
)
if value.min() < dtype.min():
raise ValueError(
"ndarray value has smaller elements than what the dtype can support"
)
self._shape = value.shape
else:
raise TypeError("value can either be an int or ndarray")

self._value = value
self._dtype = dtype

@property
def dtype(self) -> TFHERSIntegerType:
# type has to return the type of a single ct after encoding, not the TFHERS type
return self._dtype

@property
def shape(self) -> tuple:
# shape has to return the shape considering encoding
return self._shape

@property
def value(self) -> Union[int, np.ndarray]:
return self._value

def min(self):
return self.dtype.min()

def max(self):
return self.dtype.max()


int8_2_2_value = partial(TFHERSInteger, int8_2_2)
int16_2_2_value = partial(TFHERSInteger, int16_2_2)
uint8_2_2_value = partial(TFHERSInteger, uint8_2_2)
uint16_2_2_value = partial(TFHERSInteger, uint16_2_2)
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def of(value: Any, is_encrypted: bool = False) -> "ValueDescription":
dtype=Float(16), shape=value.shape, is_encrypted=is_encrypted
)

# to avoid cyclic import issue
from ..tfhers.values import TFHERSInteger

if isinstance(value, TFHERSInteger):
return ValueDescription(dtype=value.dtype, shape=value.shape, is_encrypted=True)

message = f"Concrete cannot represent {repr(value)}"
raise ValueError(message)

Expand Down

0 comments on commit c8e6a7c

Please sign in to comment.