-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
237 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dtypes import TFHERSIntegerType, int8, int16, uint8, uint16 | ||
from .tracing import from_native, to_native | ||
from .values import int8_2_2_value, int16_2_2_value, uint8_2_2_value, uint16_2_2_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
""" | ||
Declaration of `TFHERSIntegerType` class. | ||
""" | ||
|
||
from functools import partial | ||
from typing import Any | ||
|
||
from ..dtypes import Integer | ||
|
||
|
||
class TFHERSIntegerType(Integer): | ||
pad_width: int | ||
msg_width: int | ||
|
||
def __init__(self, is_signed: bool, bit_width: int, pad_width: int, msg_width: int): | ||
super().__init__(is_signed, bit_width) | ||
self.pad_width = pad_width | ||
self.msg_width = msg_width | ||
|
||
def __eq__(self, other: Any) -> bool: | ||
return ( | ||
isinstance(other, self.__class__) | ||
and super().__eq__(other) | ||
and self.pad_width == other.pad_width | ||
and self.msg_width == other.msg_width | ||
) | ||
|
||
def __str__(self) -> str: | ||
return ( | ||
f"tfhers_{('int' if self.is_signed else 'uint')}" | ||
f"{self.bit_width}_{self.pad_width}_{self.msg_width}" | ||
) | ||
|
||
|
||
int8 = partial(TFHERSIntegerType, True, 8) | ||
uint8 = partial(TFHERSIntegerType, False, 8) | ||
int16 = partial(TFHERSIntegerType, True, 16) | ||
uint16 = partial(TFHERSIntegerType, False, 16) | ||
|
||
int8_2_2 = int8(2, 2) | ||
uint8_2_2 = uint8(2, 2) | ||
int16_2_2 = int16(2, 2) | ||
uint16_2_2 = uint16(2, 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from copy import deepcopy | ||
from typing import Union | ||
|
||
from ..dtypes import Integer | ||
from ..representation import Node | ||
from ..tracing import Tracer | ||
from ..values import EncryptedScalar, EncryptedTensor | ||
from .dtypes import TFHERSIntegerType, uint16 | ||
from .values import TFHERSInteger | ||
|
||
|
||
def to_native(value: Union[Tracer, TFHERSInteger]): | ||
if isinstance(value, Tracer): | ||
dtype = value.output.dtype | ||
if not isinstance(dtype, TFHERSIntegerType): | ||
raise TypeError("tracer didn't contain an output of TFHEInteger type. Type is:", dtype) | ||
return _trace_to_native(value, dtype) | ||
assert isinstance(value, TFHERSInteger) | ||
dtype = value.dtype | ||
return _eval_to_native(value, dtype.pad_width, dtype.msg_width) | ||
|
||
|
||
def from_native(value, dtype_to: TFHERSIntegerType): | ||
if isinstance(value, Tracer): | ||
return _trace_from_native(value, dtype_to) | ||
assert isinstance(value, TFHERSInteger) | ||
dtype = value.dtype | ||
return _eval_from_native(value, dtype.pad_width, dtype.msg_width) | ||
|
||
|
||
def _trace_to_native(tfhers_int: Tracer, dtype: TFHERSIntegerType): | ||
# TODO: compute the right descriptor | ||
output = EncryptedScalar(Integer(dtype.is_signed, dtype.bit_width)) | ||
|
||
computation = Node.generic( | ||
"tfhers_to_native", | ||
deepcopy( | ||
[ | ||
tfhers_int.output, | ||
] | ||
), | ||
output, | ||
_eval_to_native, | ||
args=(), | ||
kwargs={"pad_width": dtype.pad_width, "msg_width": dtype.msg_width}, | ||
) | ||
return Tracer( | ||
computation, | ||
input_tracers=[ | ||
tfhers_int, | ||
], | ||
) | ||
|
||
|
||
def _trace_from_native(native_int: Tracer, dtype_to: TFHERSIntegerType): | ||
# TODO: compute the right descriptor | ||
output = EncryptedScalar(dtype_to) | ||
|
||
computation = Node.generic( | ||
"tfhers_from_native", | ||
deepcopy( | ||
[ | ||
native_int.output, | ||
] | ||
), | ||
output, | ||
_eval_from_native, | ||
args=(), | ||
kwargs={"pad_width": dtype_to.pad_width, "msg_width": dtype_to.msg_width}, | ||
) | ||
return Tracer( | ||
computation, | ||
input_tracers=[ | ||
native_int, | ||
], | ||
) | ||
|
||
|
||
def _eval_to_native(tfhers_int: TFHERSInteger, pad_width: int, msg_width: int): | ||
return tfhers_int.value | ||
|
||
|
||
def _eval_from_native(native_value, pad_width: int, msg_width: int): | ||
return native_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from functools import partial | ||
from typing import Union | ||
|
||
import numpy as np | ||
|
||
from .dtypes import TFHERSIntegerType, int8_2_2, int16_2_2, uint8_2_2, uint16_2_2 | ||
|
||
|
||
# TODO: maybe merge the integer and the type with the type having an optional param value | ||
class TFHERSInteger: | ||
_value: Union[int, np.ndarray] | ||
_dtype: TFHERSIntegerType | ||
_shape: tuple | ||
|
||
# TODO: the value might need access to crypto parameters to compute the type and shape | ||
def __init__( | ||
self, | ||
dtype: TFHERSIntegerType, | ||
value: Union[int, np.ndarray], | ||
): | ||
if isinstance(value, int): | ||
self._shape = () | ||
elif isinstance(value, np.ndarray): | ||
if value.max() > dtype.max(): | ||
raise ValueError( | ||
"ndarray value has bigger elements than what the dtype can support" | ||
) | ||
if value.min() < dtype.min(): | ||
raise ValueError( | ||
"ndarray value has smaller elements than what the dtype can support" | ||
) | ||
self._shape = value.shape | ||
else: | ||
raise TypeError("value can either be an int or ndarray") | ||
|
||
self._value = value | ||
self._dtype = dtype | ||
|
||
@property | ||
def dtype(self) -> TFHERSIntegerType: | ||
# type has to return the type of a single ct after encoding, not the TFHERS type | ||
return self._dtype | ||
|
||
@property | ||
def shape(self) -> tuple: | ||
# shape has to return the shape considering encoding | ||
return self._shape | ||
|
||
@property | ||
def value(self) -> Union[int, np.ndarray]: | ||
return self._value | ||
|
||
def min(self): | ||
return self.dtype.min() | ||
|
||
def max(self): | ||
return self.dtype.max() | ||
|
||
|
||
int8_2_2_value = partial(TFHERSInteger, int8_2_2) | ||
int16_2_2_value = partial(TFHERSInteger, int16_2_2) | ||
uint8_2_2_value = partial(TFHERSInteger, uint8_2_2) | ||
uint16_2_2_value = partial(TFHERSInteger, uint16_2_2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters