Skip to content

Commit

Permalink
fix(frontend-python): optimize extract bits
Browse files Browse the repository at this point in the history
lsb and tlu calls were not minimized
  • Loading branch information
rudy-6-4 committed Apr 10, 2024
1 parent f506f5f commit 863b485
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 43 deletions.
4 changes: 2 additions & 2 deletions frontends/concrete-python/concrete/fhe/extensions/bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __getitem__(self, index: Union[int, np.integer, slice]) -> Tracer:

def evaluator(x, bits): # pylint: disable=redefined-outer-name
if isinstance(bits, (int, np.integer)):
return (x & (1 << bits)) >> bits
return (x >> bits) & 1

assert isinstance(bits, slice)

Expand All @@ -126,7 +126,7 @@ def evaluator(x, bits): # pylint: disable=redefined-outer-name

result = 0
for i, bit in enumerate(range(start, stop, step)):
value = (x & (1 << bit)) >> bit
value = (x >> bit) & 1
result += value << i

return result
Expand Down
178 changes: 138 additions & 40 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@

# pylint: enable=import-error,no-name-in-module

# See https://raw.githubusercontent.com/zama-ai/concrete/main/compilers/concrete-optimizer/v0-parameters/ref/v0_last_128
# Provide a very coarse way to compare 2 alternative code generation
LUT_COSTS_V0_NORM2_0 = {
1: 29,
2: 33,
3: 45,
4: 74,
5: 101,
6: 231,
7: 535,
8: 1721,
9: 3864,
10: 8697,
11: 19522,
}


def default(param, value):
"""Handle optional parameter with default value."""
return value if param is None else param


class Context:
"""
Expand Down Expand Up @@ -125,6 +146,27 @@ def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType:

return result if value.is_scalar else self.tensor(result, value.shape)

def fork_type(
self,
type_: ConversionType,
bit_width: Optional[int] = None,
is_signed: Optional[int] = None,
shape: Optional[Tuple[int, ...]] = None,
) -> ConversionType:
"""
Fork a type with some properties update.
"""
bit_width = default(param=bit_width, value=type_.bit_width)
is_signed = default(param=is_signed, value=type_.is_signed)
shape = default(param=shape, value=type_.shape)
return self.typeof(
ValueDescription(
dtype=Integer(is_signed=is_signed, bit_width=bit_width),
shape=shape,
is_encrypted=type_.is_encrypted,
)
)

# utilities

def location(self) -> MlirLocation:
Expand Down Expand Up @@ -2173,6 +2215,22 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion:
return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL})

def shift_left(self, x: Conversion, rank: int) -> Conversion:
assert rank >= 0
assert rank < x.bit_width
shifter = 2**rank
shifter = self.constant(self.i(x.bit_width + 1), shifter)
return self.mul(x.type, x, shifter)

def reduce_precision(self, x: Conversion, bit_width: int) -> Conversion:
assert bit_width > 0
assert bit_width <= x.bit_width
if bit_width == x.bit_width:
return x
scaled_x = self.shift_left(x, x.type.bit_width - bit_width)
x = self.reinterpret(scaled_x, bit_width=bit_width)
return x

def extract_bits(
self,
resulting_type: ConversionType,
Expand Down Expand Up @@ -2200,66 +2258,105 @@ def extract_bits(
start = bits.start or MIN_EXTRACTABLE_BIT
stop = bits.stop or (MAX_EXTRACTABLE_BIT if step > 0 else (MIN_EXTRACTABLE_BIT - 1))

bits_and_their_positions = []
for position, bit in enumerate(range(start, stop, step)):
bits_and_their_positions.append((bit, position))
bits = list(range(start, stop, step))

bits_and_their_positions = ((bit, position) for position, bit in enumerate(bits))

bits_and_their_positions = sorted(
bits_and_their_positions,
key=lambda bit_and_its_position: bit_and_its_position[0],
)

current_bit = 0
# we optimize bulk extract in low precision, used for identity
max_bit = x.original_bit_width
max_real_bit = max(
(bit for bit in bits if bit <= max_bit and x.is_unsigned),
default=0
)
cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf"))
cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max_real_bit + 1)
if cost_one_tlu < cost_many_lsbs:

def tlu_cell_with_positive_value(i):
return x.type.is_unsigned or i < 2 ** (x.bit_width - 1)

def tlu_cell_input_value(i):
if tlu_cell_with_positive_value(i):
return i
return -(2 ** (x.bit_width) - i)

table = [
sum(
((tlu_cell_input_value(i) >> bit) & 1) << position
for bit, position in bits_and_their_positions
)
+ (0 if tlu_cell_with_positive_value(i) else 2 ** (resulting_type.bit_width + 1))
for i in range(2**x.bit_width)
]
tlu_result = self.tlu(resulting_type, x, table)
return self.to_signedness(tlu_result, of=resulting_type)

current_bit = 0

lsb: Optional[Conversion] = None
result: Optional[Conversion] = None

for index, (bit, position) in enumerate(bits_and_their_positions):
for bit, position in bits_and_their_positions:
if bit >= max_bit and x.is_unsigned:
break

last = index == len(bits_and_their_positions) - 1
while bit != (current_bit - 1):
if bit == (max_bit - 1) and x.bit_width == 1 and x.is_unsigned:
if (
bit == (max_bit - 1)
and x.bit_width == resulting_type.bit_width == 1
and x.is_unsigned
):
lsb_bit_witdh = 1
lsb = x
elif last and bit == current_bit:
lsb = self.lsb(resulting_type, x)
else:
lsb = self.lsb(x.type, x)
lsb_bit_witdh = max(resulting_type.bit_width - position, x.bit_width)
lsb_type = self.fork_type(x.type, lsb_bit_witdh)
lsb = self.lsb(lsb_type, x)

# check that we only need to shift to emulate the initial and final position
# position are expressed for the final bit_width
initial_position = resulting_type.bit_width - x.bit_width
actual_position = resulting_type.bit_width - lsb.type.bit_width
delta_precision = initial_position - actual_position
assert 0 <= delta_precision < resulting_type.bit_width
assert (
actual_position <= initial_position
), "extract_bits: Cannot get back to initial precision"
assert (
actual_position <= position
), "extract_bits: Cannot get back to final precision"

current_bit += 1

if current_bit >= max_bit:
break

if not last or bit != (current_bit - 1):
cleared = self.sub(x.type, x, lsb)
x = self.reinterpret(cleared, bit_width=(x.bit_width - 1))
clearing_bit = self.reduce_precision(lsb, x.bit_width)
cleared = self.sub(x.type, x, clearing_bit)
x = self.reinterpret(cleared, bit_width=(x.bit_width - 1))

assert lsb is not None
lsb = self.to_signedness(lsb, of=resulting_type)

if lsb.bit_width > resulting_type.bit_width:
difference = (lsb.bit_width - resulting_type.bit_width) + position
shifter = self.constant(self.i(lsb.bit_width + 1), 2**difference)
shifted = self.mul(lsb.type, lsb, shifter)
lsb = self.reinterpret(shifted, bit_width=resulting_type.bit_width)

elif lsb.bit_width < resulting_type.bit_width:
shift = 2 ** (lsb.bit_width - 1)
if shift != 1:
shifter = self.constant(self.i(lsb.bit_width + 1), shift)
shifted = self.mul(lsb.type, lsb, shifter)
lsb = self.reinterpret(shifted, bit_width=1)
lsb = self.tlu(resulting_type, lsb, [0 << position, 1 << position])

elif position != 0:
shifter = self.constant(self.i(lsb.bit_width + 1), 2**position)
lsb = self.mul(lsb.type, lsb, shifter)
bit_value = self.to_signedness(lsb, of=resulting_type)
bit_value = self.reinterpret(
bit_value, bit_width=max(resulting_type.bit_width, max_bit)
)

assert lsb is not None
result = lsb if result is None else self.add(resulting_type, result, lsb)
delta_precision = position - actual_position
assert actual_position < 0 or 0 <= delta_precision < resulting_type.bit_width, (
position,
actual_position,
resulting_type.bit_width,
)
if delta_precision:
bit_value = self.shift_left(bit_value, delta_precision)

bit_value = self.reinterpret(bit_value, bit_width=resulting_type.bit_width)

result = bit_value if result is None else self.add(resulting_type, result, bit_value)

return result if result is not None else self.zeros(resulting_type)

Expand Down Expand Up @@ -3209,8 +3306,8 @@ def round_bit_pattern(
unskewed = x
if approx_conf.symetrize_deltas:
highest_supported_precision = 62
delta_precision = highest_supported_precision - x.type.bit_width
full_precision = x.type.bit_width + delta_precision
delta_precision = highest_supported_precision - x.bit_width
full_precision = x.bit_width + delta_precision
half_in_extra_precision = (
1 << (delta_precision - 1)
) - 1 # slightly smaller then half
Expand Down Expand Up @@ -3763,11 +3860,12 @@ def reinterpret(
) -> Conversion:
assert x.is_encrypted

if x.bit_width == bit_width:
result_unsigned = x.is_unsigned if signed is None else not signed

if x.bit_width == bit_width and x.is_unsigned == result_unsigned:
return x

result_signed = x.is_unsigned if signed is None else signed
resulting_element_type = (self.eint if result_signed else self.esint)(bit_width)
resulting_element_type = (self.eint if result_unsigned else self.esint)(bit_width)
resulting_type = self.tensor(resulting_element_type, shape=x.shape)

operation = (
Expand Down
Loading

0 comments on commit 863b485

Please sign in to comment.