Skip to content

Commit

Permalink
feat(frontend): extend bridge to work with Client
Browse files Browse the repository at this point in the history
this is needed in a deployment usecase where we don't have a circuit or
module
  • Loading branch information
youben11 committed Mar 7, 2025
1 parent 108918d commit 76ebd79
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 116 deletions.
5 changes: 5 additions & 0 deletions frontends/concrete-python/concrete/fhe/compilation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ..internal.utils import assert_that
from ..representation import Graph
from ..tfhers.specs import TFHERSClientSpecs
from .client import Client
from .composition import CompositionRule
from .configuration import Configuration
Expand Down Expand Up @@ -720,12 +721,15 @@ def __init__(
self.mlir_module = mlir
self.compilation_context = compilation_context

tfhers_specs = TFHERSClientSpecs.from_graphs(graphs)

def init_simulation():
simulation_server = Server.create(
self.mlir_module,
self.configuration.fork(fhe_simulation=True),
is_simulated=True,
compilation_context=self.compilation_context,
tfhers_specs=tfhers_specs,
)
simulation_client = Client(simulation_server.client_specs, is_simulated=True)
return SimulationRt(simulation_client, simulation_server)
Expand All @@ -741,6 +745,7 @@ def init_execution():
compilation_context=self.compilation_context,
composition_rules=composition_rules,
is_simulated=False,
tfhers_specs=tfhers_specs,
)
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
Expand Down
27 changes: 24 additions & 3 deletions frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from concrete.compiler import lookup_runtime_lib, set_compiler_logging, set_llvm_debug_flag
from mlir.ir import Module as MlirModule

from ..tfhers.specs import TFHERSClientSpecs
from .composition import CompositionClause, CompositionRule
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
Expand All @@ -60,24 +61,27 @@ class Server:
_mlir: Optional[str]
_configuration: Optional[Configuration]
_composition_rules: Optional[List[CompositionRule]]
_tfhers_specs: Optional[TFHERSClientSpecs]

def __init__(
self,
library: Library,
is_simulated: bool,
composition_rules: Optional[List[CompositionRule]],
tfhers_specs: Optional[TFHERSClientSpecs] = None,
):
self.is_simulated = is_simulated
self._library = library
self._mlir = None
self._composition_rules = composition_rules
self._tfhers_specs = tfhers_specs

@property
def client_specs(self) -> ClientSpecs:
"""
Return the associated client specs.
"""
return ClientSpecs(self._library.get_program_info())
return ClientSpecs(self._library.get_program_info(), tfhers_specs=self._tfhers_specs)

@staticmethod
def create(
Expand All @@ -86,6 +90,7 @@ def create(
is_simulated: bool = False,
compilation_context: Optional[CompilationContext] = None,
composition_rules: Optional[Iterable[CompositionRule]] = None,
tfhers_specs: Optional[TFHERSClientSpecs] = None,
) -> "Server":
"""
Create a server using MLIR and output sign information.
Expand All @@ -105,6 +110,9 @@ def create(
composition_rules (Iterable[Tuple[str, int, str, int]]):
composition rules to be applied when compiling
tfhers_specs (Optional[TFHERSClientSpecs]):
TFHE-rs client specs
"""

backend = Backend.GPU if configuration.use_gpu else Backend.CPU
Expand Down Expand Up @@ -220,7 +228,10 @@ def create(
composition_rules = composition_rules if composition_rules else None

result = Server(
library=library, is_simulated=is_simulated, composition_rules=composition_rules
library=library,
is_simulated=is_simulated,
composition_rules=composition_rules,
tfhers_specs=tfhers_specs,
)

# pylint: disable=protected-access
Expand Down Expand Up @@ -332,6 +343,11 @@ def load(path: Union[str, Path], **kwargs) -> "Server":
else None
)

tfhers_specs = None
if (output_dir_path / "client.specs.json").exists():
with open(output_dir_path / "client.specs.json", "rb") as f:
tfhers_specs = ClientSpecs.deserialize(f.read()).tfhers_specs

if (output_dir_path / "circuit.mlir").exists():
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
mlir = f.read()
Expand All @@ -340,7 +356,11 @@ def load(path: Union[str, Path], **kwargs) -> "Server":
configuration = Configuration().fork(**jsonpickle.loads(f.read())).fork(**kwargs)

return Server.create(
mlir, configuration, is_simulated, composition_rules=composition_rules
mlir,
configuration,
is_simulated,
composition_rules=composition_rules,
tfhers_specs=tfhers_specs,
)

library = Library(str(output_dir_path))
Expand All @@ -349,6 +369,7 @@ def load(path: Union[str, Path], **kwargs) -> "Server":
library,
is_simulated,
composition_rules,
tfhers_specs=tfhers_specs,
)

def run(
Expand Down
40 changes: 30 additions & 10 deletions frontends/concrete-python/concrete/fhe/compilation/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
"""

# pylint: disable=import-error,no-member,no-name-in-module

from typing import Any
import json
from typing import Any, Optional

# mypy: disable-error-code=attr-defined
from concrete.compiler import ProgramInfo

# pylint: enable=import-error,no-member,no-name-in-module
from concrete import fhe


class ClientSpecs:
Expand All @@ -18,28 +19,39 @@ class ClientSpecs:
"""

program_info: ProgramInfo
tfhers_specs: Optional["fhe.tfhers.TFHERSClientSpecs"]

def __init__(self, program_info: ProgramInfo):
def __init__(
self,
program_info: ProgramInfo,
tfhers_specs: Optional["fhe.tfhers.TFHERSClientSpecs"] = None,
):
self.program_info = program_info
self.tfhers_specs = tfhers_specs

def __eq__(self, other: Any): # pragma: no cover
return self.program_info.serialize() == other.program_info.serialize()
return (
self.program_info.serialize() == other.program_info.serialize()
and self.tfhers_specs == other.tfhers_specs
)

def serialize(self) -> bytes:
"""
Serialize client specs into a string representation.
Serialize client specs into bytes.
Returns:
bytes:
serialized client specs
"""

return self.program_info.serialize()
program_info = json.loads(self.program_info.serialize())
if self.tfhers_specs is not None:
program_info["tfhers_specs"] = self.tfhers_specs.to_dict()
return json.dumps(program_info).encode("utf-8")

@staticmethod
def deserialize(serialized_client_specs: bytes) -> "ClientSpecs":
"""
Create client specs from its string representation.
Create client specs from bytes.
Args:
serialized_client_specs (bytes):
Expand All @@ -49,6 +61,14 @@ def deserialize(serialized_client_specs: bytes) -> "ClientSpecs":
ClientSpecs:
deserialized client specs
"""
program_info_dict = json.loads(serialized_client_specs)
tfhers_specs_dict = program_info_dict.get("tfhers_specs", None)

if tfhers_specs_dict is not None:
del program_info_dict["tfhers_specs"]
tfhers_specs = fhe.tfhers.TFHERSClientSpecs.from_dict(tfhers_specs_dict)
else:
tfhers_specs = None

program_info = ProgramInfo.deserialize(serialized_client_specs)
return ClientSpecs(program_info)
program_info = ProgramInfo.deserialize(json.dumps(program_info_dict).encode("utf-8"))
return ClientSpecs(program_info, tfhers_specs)
3 changes: 2 additions & 1 deletion frontends/concrete-python/concrete/fhe/tfhers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
uint16,
uint16_2_2,
)
from .specs import TFHERSClientSpecs
from .tracing import from_native, to_native
from .values import TFHERSInteger

Expand All @@ -39,7 +40,7 @@ def get_type_from_params(
"""

# Read crypto parameters from TFHE-rs in the json file
with open(path_to_params_json) as f:
with open(path_to_params_json, "r", encoding="utf-8") as f:
crypto_param_dict = json.load(f)

return get_type_from_params_dict(crypto_param_dict, is_signed, precision)
Expand Down
Loading

0 comments on commit 76ebd79

Please sign in to comment.