Skip to content

Commit

Permalink
feat(frontend-python): experimental synthesis extension
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Sep 10, 2024
1 parent edaa208 commit 705b5fe
Show file tree
Hide file tree
Showing 16 changed files with 1,808 additions and 12 deletions.
1 change: 1 addition & 0 deletions frontends/concrete-python/.pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ disable=raw-checker-failed,
wrong-import-order,
unsubscriptable-object,
no-else-continue,
no-else-return,
unnecessary-comprehension

# Enable the message, report, category or checker with the given id(s). You can
Expand Down
7 changes: 4 additions & 3 deletions frontends/concrete-python/.ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ select = [
"PLC", "PLE", "PLR", "PLW", "RUF"
]
ignore = [
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001", "SIM105",
"RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003", "PLW2901",
"A", "D", "FBT", "T20", "ANN", "N805", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001", "SIM105",
"RET504", "RET505", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003", "PLW2901",
"PLR0915", "C416", "PLR0911", "PLR0912", "PLR0913", "RUF005", "PLR2004", "S110", "PLC1901",
"E731", "RET507", "SIM102", "N805",
"E731", "RET507", "SIM102", "SIM108",
"Q000",
]

[per-file-ignores]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MultiParameterStrategy,
MultivariateStrategy,
ParameterSelectionStrategy,
SynthesisConfig,
)
from .keys import Keys
from .module import FheFunction, FheModule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,31 @@ class ApproximateRoundingConfig:
"""


@dataclass
class SynthesisConfig:
"""
Controls the behavior of synthesis.
"""

start_tlu_at_precision: int = 7
"""
Starting synthesis at the given TLU input precision, but keep the original TLU when it's faster.
Used to make high precision TLU faster by rewritting them with several lower precisions TLU.
"""

force_tlu_at_precision: int = 17
"""
Force synthesis at the given TLU input precision, even if it's slower than the original TLU.
Used to replace any high precision TLU by several lower precisions TLU.
"""

maximal_tlu_input_bit_width: int = 8
"""
Maximal bit_width for TLU generated by synthesis.
Used if you want guarantees on the maximum input bit_width of TLU after synthesis.
"""


class ComparisonStrategy(str, Enum):
"""
ComparisonStrategy, to specify implementation preference for comparisons.
Expand Down Expand Up @@ -994,6 +1019,7 @@ class Configuration:
dynamic_assignment_check_out_of_bounds: bool
simulate_encrypt_run_decrypt: bool
composable: bool
synthesis_config: SynthesisConfig

def __init__(
self,
Expand Down Expand Up @@ -1063,6 +1089,7 @@ def __init__(
dynamic_indexing_check_out_of_bounds: bool = True,
dynamic_assignment_check_out_of_bounds: bool = True,
simulate_encrypt_run_decrypt: bool = False,
synthesis_config: Optional[SynthesisConfig] = None,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1170,6 +1197,8 @@ def __init__(

self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt

self.synthesis_config = synthesis_config or SynthesisConfig()

self._validate()

class Keep:
Expand Down Expand Up @@ -1245,6 +1274,7 @@ def fork(
dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP,
dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP,
simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP,
synthesis_config: Union[Keep, Optional[SynthesisConfig]] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Provide synthesis main entry points."""

from .fhe_function import lut, verilog_expression
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# pylint: disable=missing-module-docstring,missing-function-docstring

from dataclasses import dataclass

import numpy as np

from concrete.fhe.extensions.synthesis.verilog_source import Ty


class EvalContext:
"""
This is a reduced context with similar method as `concrete.fhe.mlir.Context`.
It provides a clear evaluation backend for tlu_circuit_to_mlir.
Until the all internal_api are used directly by concrete-python,
this helps to keep all tests for previous backend and api.
For now only synthesis of TLU is supported by concrete-python.
"""

# For all `EvalContext` method look at `Context` documentation.

@dataclass
class Ty:
"""Equivalent for `ConversionType`."""

bit_width: int
is_tensor: bool = False
shape: tuple = () # not used

@dataclass
class Val:
"""Equivalent for `Conversion`. Contains the evaluation result."""

value: 'int | np.ndarray'
type: Ty

def __init__(self, value, type_):
try:
value = int(value)
except TypeError:
pass
self.value = value
self.type = type_

def fork_type(self, type_, bit_width=None, shape=None):
return self.Ty(bit_width=bit_width or type_.bit_width, shape=shape or type_.shape)

def i(self, size):
return self.Ty(size)

def constant(self, type_: Ty, value: int):
return self.Val(value, type_)

def mul(self, type_: Ty, a: Val, b: Val):
assert isinstance(b.value, int)
assert a.type == type_
return self.Val(a.value * b.value, type_)

def add(self, type_: Ty, a: Val, b: Val):
assert a.type == b.type == type_
return self.Val(a.value + b.value, type_)

def sub(self, type_: Ty, a: Val, b: Val):
assert isinstance(b.value, int)
assert a.type == type_
return self.Val(a.value - b.value, type_)

def tlu(self, type_: Ty, arg: Val, tlu_content, **_kwargs):
if isinstance(arg, int):
v = self.Val(tlu_content[arg.value], type_)
else:
v = np.vectorize(lambda v: int(tlu_content[v]))(arg.value)
return self.Val(v, type_)

def extract_bits(self, type_: Ty, arg: Val, bit_index, **_kwargs):
return self.Val((arg.value >> bit_index) & 1, type_)

def to_unsigned(self, arg: Val):
def aux(value):
if value < 0:
return 2**arg.type.bit_width + value
return value

if isinstance(arg.value, int):
v = aux(arg.value)
else:
v = np.vectorize(aux)(arg.value)
return self.Val(v, arg.type)

def to_signed(self, arg: Val):
def aux(value):
assert value >= 0
negative = value >= 2 ** (arg.type.bit_width - 1)
if negative:
return -(2**arg.type.bit_width - arg.value)
return value

if isinstance(arg.value, int):
v = aux(arg.value)
else:
v = np.vectorize(aux)(arg.value)
return self.Val(v, arg.type)

def index(self, type_: Ty, tensor: Val, index):
assert isinstance(tensor.value, list), type(tensor.value)
assert len(index) == 1
(index,) = index
return self.Val(tensor.value[index], self.Ty(type_.bit_width, is_tensor=False))

def reinterpret(self, arg, bit_width=None):
arg_bit_width = arg.type.bit_width
if bit_width is None:
bit_width = arg_bit_width
if bit_width == arg_bit_width:
return arg
shift = 2 ** (bit_width - arg_bit_width)
if isinstance(arg, int):
v = arg.value * shift
else:
v = np.vectorize(lambda v: v * shift)(arg.value)
return self.Val(v, self.Ty(bit_width=bit_width))

def safe_reduce_precision(self, arg, bit_width):
if arg.type.bit_width == bit_width:
return arg
assert arg.type.bit_width > bit_width
shift = arg.type.bit_width - bit_width
shifted = self.mul(arg.type, arg, self.constant(self.i(bit_width + 1), 2**shift))
return self.reinterpret(shifted, bit_width)
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""
INTERNAL extension to synthesize a fhe compatible function from verilog code.
"""

from collections import Counter
from typing import Optional

import concrete.fhe.dtypes as fhe_dtypes
import concrete.fhe.tracing.typing as fhe_typing
from concrete.fhe.dtypes.integer import Integer
from concrete.fhe.extensions.synthesis.eval_context import EvalContext
from concrete.fhe.extensions.synthesis.luts_to_fhe import tlu_circuit_to_mlir
from concrete.fhe.extensions.synthesis.luts_to_graph import to_graph
from concrete.fhe.extensions.synthesis.verilog_source import (
Ty,
verilog_from_expression,
verilog_from_tlu,
)
from concrete.fhe.extensions.synthesis.verilog_to_luts import yosys_lut_synthesis
from concrete.fhe.values.value_description import ValueDescription


class FheFunction:
"""Main class to synthesize verilog."""

def __init__(
self,
*,
verilog,
name,
params=None,
result_name="result",
yosys_dot_file=False,
verbose=False,
):
assert params
self.name = name
self.verilog = verilog
if verbose:
print()
print(f"Verilog, {name}:")
print(verilog)
print()
if verbose:
print("Synthesis")
self.circuit = yosys_lut_synthesis(
verilog, yosys_dot_file=yosys_dot_file, circuit_name=name
)
if verbose:
print()
print(f"TLUs counts, {self.tlu_counts()}:")
print()
self.params = params
self.result_name = result_name

self.mlir = tlu_circuit_to_mlir(self.circuit, self.params, result_name, verbose)

def __call__(self, **kwargs):
"""
Evaluate using mlir generation with a direct evaluation context.
This is useful for testing purpose.
"""
args = []
for name, type_ in self.params.items():
if name == "result":
continue
if isinstance(type_, list):
val = EvalContext.Val(
kwargs[name], EvalContext.Ty(type_[0].dtype.bit_width, is_tensor=True)
)
else:
val = EvalContext.Val(kwargs[name], EvalContext.Ty(type_.dtype.bit_width))
args.append(val)
result_ty = self.params["result"]
if isinstance(result_ty, list):
eval_ty = EvalContext.Ty(result_ty, is_tensor=True)
else:
eval_ty = EvalContext.Ty(result_ty, is_tensor=False)
result = self.mlir(EvalContext(), eval_ty, args)
if isinstance(result_ty, list):
return [r.value for r in result]
else:
return result.value

def tlu_counts(self):
"""Count the number of tlus in the synthesized tracer keyed by input precision."""
counter = Counter()
for node in self.circuit.nodes:
if len(node.arguments) == 1:
print(node)
counter.update({len(node.arguments): 1})
return dict(sorted(counter.items()))

def is_faster_than_1_tlu(self, reference_costs):
"""Verify that synthesis is faster than the original tlu."""
costs = 0
for node in self.circuit.nodes:
zero_cost = len(node.arguments) <= 1
if zero_cost:
# constant or inversion (converted to substraction)
continue
else:
costs += reference_costs[len(node.arguments)]
try:
return costs <= reference_costs[self.params["a"].dtype.bit_width]
except KeyError:
return True

def graph(self, *, filename=None, view=True, **kwargs):
"""Render the synthesized tracer as a graph."""
graph = to_graph(self.name, self.circuit.nodes)
graph.render(filename=filename, view=view, cleanup=filename is None, **kwargs)


def lut(table: 'list[int]', out_type: Optional[ValueDescription] = None, **kwargs):
"""Synthesize a lookup function from a table."""
# assert not signed # TODO signed case
if isinstance(out_type, list):
msg = "Multi-message output is not supported"
raise TypeError(msg)
if out_type:
assert isinstance(out_type.dtype, Integer)
v_out_type = Ty(
bit_width=out_type.dtype.bit_width,
is_signed=out_type.dtype.is_signed,
)
verilog, v_out_type = verilog_from_tlu(table, signed_input=False, out_type=v_out_type)
if "name" not in kwargs:
kwargs.setdefault("name", "lut")
if "params" not in kwargs:
dtype = fhe_dtypes.Integer.that_can_represent(len(table) - 1)
a_ty = getattr(fhe_typing, f"uint{dtype.bit_width}")
assert a_ty
kwargs["params"] = {"a": a_ty, "result": out_type}
return FheFunction(verilog=verilog, **kwargs)


def _uniformize_as_list(v):
return v if isinstance(v, (list, tuple)) else [v]


def verilog_expression(
in_params: 'dict[str, ValueDescription]', expression: str, out_type: ValueDescription, **kwargs
):
"""Synthesize a lookup function from a verilog function."""
result_name = "result"
if result_name in in_params:
result_name = f"{result_name}_{hash(expression)}"
in_params = dict(in_params)
in_params[result_name] = out_type
verilog_params = {
name: Ty(
bit_width=sum(ty.dtype.bit_width for ty in _uniformize_as_list(type_list)),
is_signed=any(ty.dtype.is_signed for ty in _uniformize_as_list(type_list)),
)
for name, type_list in in_params.items()
}
verilog = verilog_from_expression(verilog_params, expression, result_name)
if "name" not in kwargs:
kwargs.setdefault("name", expression)
return FheFunction(verilog=verilog, params=in_params, result_name=result_name, **kwargs)
Loading

0 comments on commit 705b5fe

Please sign in to comment.