Skip to content

Commit

Permalink
fix(frontend-python): REVIEW TO SQUASH
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Apr 3, 2024
1 parent ad0e6b5 commit b135255
Showing 1 changed file with 36 additions and 41 deletions.
77 changes: 36 additions & 41 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@

# pylint: enable=import-error,no-name-in-module


LUT_COSTS = {
# See https://raw.githubusercontent.com/zama-ai/concrete/main/compilers/concrete-optimizer/v0-parameters/ref/v0_last_128
# Provide a very coarse way to compare 2 alternative code generation
LUT_COSTS_V0_NORM2_0 = {
1: 29,
2: 33,
3: 45,
Expand Down Expand Up @@ -140,6 +141,15 @@ def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType:

return result if value.is_scalar else self.tensor(result, value.shape)

def fork_type(self, type_, bit_width):
return self.typeof(
ValueDescription(
dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width),
shape=type_.shape,
is_encrypted=type_.is_encrypted,
)
)

# utilities

def location(self) -> MlirLocation:
Expand Down Expand Up @@ -2188,17 +2198,17 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion:
return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL})

def shift_left(self, x: Conversion, rank):
def shift_left(self, x: Conversion, rank: int) -> Conversion:
assert rank >= 0
assert rank < x.bit_width
shifter = 2**rank
shifter = self.constant(self.i(x.bit_width + 1), shifter)
return self.mul(x.type, x, shifter)

def reduce_precision(self, x: Conversion, bit_width):
def reduce_precision(self, x: Conversion, bit_width: int) -> Conversion:
assert bit_width > 0
assert bit_width <= x.type.bit_width
if bit_width == x.type.bit_width:
assert bit_width <= x.bit_width
if bit_width == x.bit_width:
return x
scaled_x = self.shift_left(x, x.type.bit_width - bit_width)
x = self.reinterpret(scaled_x, bit_width=bit_width)
Expand Down Expand Up @@ -2241,44 +2251,29 @@ def extract_bits(
)

# we optimize bulk extract in low precision, used for identity
if (
len(bits) > 1
and LUT_COSTS.get(x.type.bit_width, float("inf")) <= (max(bits) + 1) * LUT_COSTS[1]
):
cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf"))
cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1)
if cost_one_tlu < cost_many_lsbs:

def is_positive(v):
return x.type.is_unsigned or v < 2 ** (x.type.bit_width - 1)
def tlu_cell_with_positive_value(i):
return x.type.is_unsigned or i < 2 ** (x.bit_width - 1)

def to_signed(v):
if is_positive(v):
return v
return -(2 ** (x.type.bit_width) - v)
def tlu_cell_input_value(i):
if tlu_cell_with_positive_value(i):
return i
return -(2 ** (x.bit_width) - i)

table = [
sum(
((to_signed(v) >> bit) & 1) << position
((tlu_cell_input_value(i) >> bit) & 1) << position
for bit, position in bits_and_their_positions
)
+ (0 if is_positive(v) else 2 ** (resulting_type.bit_width + 1))
for v in range(2**x.type.bit_width)
+ (0 if tlu_cell_with_positive_value(i) else 2 ** (resulting_type.bit_width + 1))
for i in range(2**x.bit_width)
]
tlu_result = self.tlu(resulting_type, x, table)
return self.to_signedness(tlu_result, of=resulting_type)

def same_type_as(type_, bit_width):
return self.typeof(
ValueDescription(
dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width),
shape=type_.shape,
is_encrypted=type_.is_encrypted,
)
)

def reduce_precision(x, delta):
assert delta < x.type.bit_width
new_bit_witdh = x.type.bit_width - delta
return self.reduce_precision(x, new_bit_witdh)

current_bit = 0
max_bit = x.original_bit_width

Expand All @@ -2297,14 +2292,16 @@ def reduce_precision(x, delta):
lsb_bit_witdh = 1
lsb = x
else:
lsb_bit_witdh = max(resulting_type.bit_width - position, x.type.bit_width)
lsb_type = same_type_as(x.type, lsb_bit_witdh)
lsb_bit_witdh = max(resulting_type.bit_width - position, x.bit_width)
lsb_type = self.fork_type(x.type, lsb_bit_witdh)
lsb = self.lsb(lsb_type, x)

# check that we only need to shift to emulate the initial and final position
# position are expressed for the final bit_width
initial_position = resulting_type.bit_width - x.type.bit_width
initial_position = resulting_type.bit_width - x.bit_width
actual_position = resulting_type.bit_width - lsb.type.bit_width
delta_precision = initial_position - actual_position
assert 0 <= delta_precision < resulting_type.bit_width
assert (
actual_position <= initial_position
), "extract_bits: Cannot get back to initial precision"
Expand All @@ -2317,9 +2314,7 @@ def reduce_precision(x, delta):
if current_bit >= max_bit:
break

delta_precision = initial_position - actual_position
assert 0 <= delta_precision < resulting_type.bit_width
clearing_bit = reduce_precision(lsb, delta_precision)
clearing_bit = self.reduce_precision(lsb, x.bit_width)
cleared = self.sub(x.type, x, clearing_bit)
x = self.reinterpret(cleared, bit_width=(x.bit_width - 1))

Expand Down Expand Up @@ -3290,8 +3285,8 @@ def round_bit_pattern(
unskewed = x
if approx_conf.symetrize_deltas:
highest_supported_precision = 62
delta_precision = highest_supported_precision - x.type.bit_width
full_precision = x.type.bit_width + delta_precision
delta_precision = highest_supported_precision - x.bit_width
full_precision = x.bit_width + delta_precision
half_in_extra_precision = (
1 << (delta_precision - 1)
) - 1 # slightly smaller then half
Expand Down

0 comments on commit b135255

Please sign in to comment.