Skip to content

Commit

Permalink
fix(frontend-python): optimize extract bits
Browse files Browse the repository at this point in the history
lsb and tlu calls were not minimized
  • Loading branch information
rudy-6-4 committed Apr 3, 2024
1 parent 916c337 commit fb5bcd7
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 40 deletions.
4 changes: 2 additions & 2 deletions frontends/concrete-python/concrete/fhe/extensions/bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __getitem__(self, index: Union[int, np.integer, slice]) -> Tracer:

def evaluator(x, bits): # pylint: disable=redefined-outer-name
if isinstance(bits, (int, np.integer)):
return (x & (1 << bits)) >> bits
return (x >> bits) & 1

assert isinstance(bits, slice)

Expand All @@ -126,7 +126,7 @@ def evaluator(x, bits): # pylint: disable=redefined-outer-name

result = 0
for i, bit in enumerate(range(start, stop, step)):
value = (x & (1 << bit)) >> bit
value = (x >> bit) & 1
result += value << i

return result
Expand Down
156 changes: 119 additions & 37 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@
# pylint: enable=import-error,no-name-in-module


LUT_COSTS = {
1: 29,
2: 33,
3: 45,
4: 74,
5: 101,
6: 231,
7: 535,
8: 1721,
9: 3864,
10: 8697,
11: 19522,
}


class Context:
"""
Context class, to perform operations on conversions.
Expand Down Expand Up @@ -2173,6 +2188,22 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion:
return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL})

def shift_left(self, x: Conversion, rank):
assert rank >= 0
assert rank < x.bit_width
shifter = 2**rank
shifter = self.constant(self.i(x.bit_width + 1), shifter)
return self.mul(x.type, x, shifter)

def reduce_precision(self, x: Conversion, bit_width):
assert bit_width > 0
assert bit_width <= x.type.bit_width
if bit_width == x.type.bit_width:
return x
scaled_x = self.shift_left(x, x.type.bit_width - bit_width)
x = self.reinterpret(scaled_x, bit_width=bit_width)
return x

def extract_bits(
self,
resulting_type: ConversionType,
Expand Down Expand Up @@ -2200,66 +2231,116 @@ def extract_bits(
start = bits.start or MIN_EXTRACTABLE_BIT
stop = bits.stop or (MAX_EXTRACTABLE_BIT if step > 0 else (MIN_EXTRACTABLE_BIT - 1))

bits_and_their_positions = []
for position, bit in enumerate(range(start, stop, step)):
bits_and_their_positions.append((bit, position))
bits = list(range(start, stop, step))

bits_and_their_positions = ((bit, position) for position, bit in enumerate(bits))

bits_and_their_positions = sorted(
bits_and_their_positions,
key=lambda bit_and_its_position: bit_and_its_position[0],
)

# we optimize bulk extract in low precision, used for identity
if (
len(bits) > 1
and LUT_COSTS.get(x.type.bit_width, float("inf")) <= (max(bits) + 1) * LUT_COSTS[1]
):

def is_positive(v):
return x.type.is_unsigned or v < 2 ** (x.type.bit_width - 1)

def to_signed(v):
if is_positive(v):
return v
return -(2 ** (x.type.bit_width) - v)

table = [
sum(
((to_signed(v) >> bit) & 1) << position
for bit, position in bits_and_their_positions
)
+ (0 if is_positive(v) else 2 ** (resulting_type.bit_width + 1))
for v in range(2**x.type.bit_width)
]
tlu_result = self.tlu(resulting_type, x, table)
return self.to_signedness(tlu_result, of=resulting_type)

def same_type_as(type_, bit_width):
return self.typeof(
ValueDescription(
dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width),
shape=type_.shape,
is_encrypted=type_.is_encrypted,
)
)

def reduce_precision(x, delta):
assert delta < x.type.bit_width
new_bit_witdh = x.type.bit_width - delta
return self.reduce_precision(x, new_bit_witdh)

current_bit = 0
max_bit = x.original_bit_width

lsb: Optional[Conversion] = None
result: Optional[Conversion] = None

for index, (bit, position) in enumerate(bits_and_their_positions):
for bit, position in bits_and_their_positions:
if bit >= max_bit and x.is_unsigned:
break

last = index == len(bits_and_their_positions) - 1
while bit != (current_bit - 1):
if bit == (max_bit - 1) and x.bit_width == 1 and x.is_unsigned:
if (
bit == (max_bit - 1)
and x.bit_width == resulting_type.bit_width == 1
and x.is_unsigned
):
lsb_bit_witdh = 1
lsb = x
elif last and bit == current_bit:
lsb = self.lsb(resulting_type, x)
else:
lsb = self.lsb(x.type, x)
lsb_bit_witdh = max(resulting_type.bit_width - position, x.type.bit_width)
lsb_type = same_type_as(x.type, lsb_bit_witdh)
lsb = self.lsb(lsb_type, x)

# check that we only need to shift to emulate the initial and final position
# position are expressed for the final bit_width
initial_position = resulting_type.bit_width - x.type.bit_width
actual_position = resulting_type.bit_width - lsb.type.bit_width
assert (
actual_position <= initial_position
), "extract_bits: Cannot get back to initial precision"
assert (
actual_position <= position
), "extract_bits: Cannot get back to final precision"

current_bit += 1

if current_bit >= max_bit:
break

if not last or bit != (current_bit - 1):
cleared = self.sub(x.type, x, lsb)
x = self.reinterpret(cleared, bit_width=(x.bit_width - 1))
delta_precision = initial_position - actual_position
assert 0 <= delta_precision < resulting_type.bit_width
clearing_bit = reduce_precision(lsb, delta_precision)
cleared = self.sub(x.type, x, clearing_bit)
x = self.reinterpret(cleared, bit_width=(x.bit_width - 1))

assert lsb is not None
lsb = self.to_signedness(lsb, of=resulting_type)

if lsb.bit_width > resulting_type.bit_width:
difference = (lsb.bit_width - resulting_type.bit_width) + position
shifter = self.constant(self.i(lsb.bit_width + 1), 2**difference)
shifted = self.mul(lsb.type, lsb, shifter)
lsb = self.reinterpret(shifted, bit_width=resulting_type.bit_width)

elif lsb.bit_width < resulting_type.bit_width:
shift = 2 ** (lsb.bit_width - 1)
if shift != 1:
shifter = self.constant(self.i(lsb.bit_width + 1), shift)
shifted = self.mul(lsb.type, lsb, shifter)
lsb = self.reinterpret(shifted, bit_width=1)
lsb = self.tlu(resulting_type, lsb, [0 << position, 1 << position])

elif position != 0:
shifter = self.constant(self.i(lsb.bit_width + 1), 2**position)
lsb = self.mul(lsb.type, lsb, shifter)
bit_value = self.to_signedness(lsb, of=resulting_type)
bit_value = self.reinterpret(
bit_value, bit_width=max(resulting_type.bit_width, max_bit)
)

assert lsb is not None
result = lsb if result is None else self.add(resulting_type, result, lsb)
delta_precision = position - actual_position
assert actual_position < 0 or 0 <= delta_precision < resulting_type.bit_width, (
position,
actual_position,
resulting_type.bit_width,
)
if delta_precision:
bit_value = self.shift_left(bit_value, delta_precision)

bit_value = self.reinterpret(bit_value, bit_width=resulting_type.bit_width)

result = bit_value if result is None else self.add(resulting_type, result, bit_value)

return result if result is not None else self.zeros(resulting_type)

Expand Down Expand Up @@ -3763,11 +3844,12 @@ def reinterpret(
) -> Conversion:
assert x.is_encrypted

if x.bit_width == bit_width:
result_unsigned = x.is_unsigned if signed is None else not signed

if x.bit_width == bit_width and x.is_unsigned == result_unsigned:
return x

result_signed = x.is_unsigned if signed is None else signed
resulting_element_type = (self.eint if result_signed else self.esint)(bit_width)
resulting_element_type = (self.eint if result_unsigned else self.esint)(bit_width)
resulting_type = self.tensor(resulting_element_type, shape=x.shape)

operation = (
Expand Down
139 changes: 138 additions & 1 deletion frontends/concrete-python/tests/execution/test_bit_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,144 @@ def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers):

compiler = fhe.Compiler(operation, {"x": "encrypted"})
circuit = compiler.compile(inputset, helpers.configuration())

print(circuit.mlir)
values = inputset if len(inputset) <= 8 else random.sample(inputset, 8)
for value in values:
helpers.check_execution(circuit, operation, value, retries=3)


def mlir_count_ops(mlir, operation):
"""
Count op in mlir.
"""
return sum(operation in line for line in mlir.splitlines())


def test_highest_bit_extraction_mlir(helpers):
"""
Test bit extraction of the highest bit. Saves one lsb.
"""

precision = 8
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x):
return fhe.bits(x)[precision - 1]

circuit = operation.compile(inputset, helpers.configuration())
assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1
assert mlir_count_ops(circuit.mlir, "lookup") == 0


def test_bits_extraction_to_same_bitwidth_mlir(helpers):
"""
Test bit extraction to same.
"""

precision = 8
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x):
return tuple(fhe.bits(x)[i] for i in range(precision))

circuit = operation.compile(inputset, helpers.configuration())
assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1
assert mlir_count_ops(circuit.mlir, "lookup") == 0


def test_bits_extraction_to_bigger_bitwidth_mlir(helpers):
"""
Test bit extraction to bigger bitwidth.
"""

precision = 8
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x):
return tuple(fhe.bits(x)[i] + (2**precision + 1) for i in range(precision))

circuit = operation.compile(inputset, helpers.configuration())
print(circuit.mlir)
assert mlir_count_ops(circuit.mlir, "lsb") == precision
assert mlir_count_ops(circuit.mlir, "lookup") == 0


def test_seq_bits_extraction_to_same_bitwidth_mlir(helpers):
"""
Test sequential bit extraction to smaller bitwidth.
"""

precision = 8
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x):
return tuple(fhe.bits(x)[i] + (2**precision - 2) for i in range(precision))

circuit = operation.compile(inputset, helpers.configuration())
assert mlir_count_ops(circuit.mlir, "lsb") == precision
assert mlir_count_ops(circuit.mlir, "lookup") == 0


def test_seq_bits_extraction_to_smaller_bitwidth_mlir(helpers):
"""
Test sequential bit extraction to smaller bitwidth.
"""

precision = 8
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x):
return tuple(fhe.bits(x)[i] for i in range(precision))

circuit = operation.compile(inputset, helpers.configuration())
assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1
assert mlir_count_ops(circuit.mlir, "lookup") == 0


def test_seq_bits_extraction_to_bigger_bitwidth_mlir(helpers):
"""
Test sequential bit extraction to bigger bitwidth.
"""

precision = 8
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x):
return tuple(fhe.bits(x)[i] + 2 ** (precision + 1) for i in range(precision))

circuit = operation.compile(inputset, helpers.configuration())
assert mlir_count_ops(circuit.mlir, "lsb") == precision
assert mlir_count_ops(circuit.mlir, "lookup") == 0


def test_bit_extract_to_1_tlu(helpers):
"""
Test bit extract as 1 tlu for small precision.
"""
precision = 3
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x):
return fhe.bits(x)[0:2]

circuit = operation.compile(inputset, helpers.configuration())
assert mlir_count_ops(circuit.mlir, "lsb") == 0
assert mlir_count_ops(circuit.mlir, "lookup") == 1

precision = 4
inputset = list(range(2**precision))

@fhe.compiler({"x": "encrypted"})
def operation(x): # pylint: disable=function-redefined
return fhe.bits(x)[0:2]

circuit = operation.compile(inputset, helpers.configuration())
assert mlir_count_ops(circuit.mlir, "lsb") == 2
assert mlir_count_ops(circuit.mlir, "lookup") == 0

0 comments on commit fb5bcd7

Please sign in to comment.