diff --git a/docs/guides/configure.md b/docs/guides/configure.md index e84d8170dc..247cbcbeb9 100644 --- a/docs/guides/configure.md +++ b/docs/guides/configure.md @@ -151,3 +151,11 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * When this option is set to `True`, encrypt and decrypt are identity functions, and run is a wrapper around simulation. In other words, this option allows to switch off the encryption to quickly test if a function has expected semantic (without paying the price of FHE execution). * This is extremely unsafe and should only be used during development. * For this reason, it requires **enable\_unsafe\_features** to be set to `True`. +* **auto\_schedule\_run**: bool = False + * Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized. + * For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished. + * E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available. + * If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically. + * Automatic scheduling behavior can be override locally by calling directly a variant of `run`: + * `run_sync`: forces the fhe function to occur in the current thread, not in the background, + * `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]` diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 90c856c9a9..c1eaf55115 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -994,6 +994,7 @@ class Configuration: dynamic_assignment_check_out_of_bounds: bool simulate_encrypt_run_decrypt: bool composable: bool + auto_schedule_run: bool def __init__( self, @@ -1063,6 +1064,7 @@ def __init__( dynamic_indexing_check_out_of_bounds: bool = True, dynamic_assignment_check_out_of_bounds: bool = True, simulate_encrypt_run_decrypt: bool = False, + auto_schedule_run: bool = False, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1170,6 +1172,8 @@ def __init__( self.simulate_encrypt_run_decrypt = simulate_encrypt_run_decrypt + self.auto_schedule_run = auto_schedule_run + self._validate() class Keep: @@ -1245,6 +1249,7 @@ def fork( dynamic_indexing_check_out_of_bounds: Union[Keep, bool] = KEEP, dynamic_assignment_check_out_of_bounds: Union[Keep, bool] = KEEP, simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP, + auto_schedule_run: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py index 3f15cdb100..72e4797f4d 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -4,8 +4,12 @@ # pylint: disable=import-error,no-member,no-name-in-module +import asyncio from pathlib import Path +from threading import Thread from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future import numpy as np from concrete.compiler import ( @@ -29,14 +33,38 @@ # pylint: enable=import-error,no-member,no-name-in-module -class ExecutionRt(NamedTuple): +class ExecutionRt: """ Runtime object class for execution. """ client: Client server: Server + auto_schedule_run: bool + fhe_executor_pool: ThreadPoolExecutor + fhe_waiter_loop: asyncio.BaseEventLoop + fhe_waiter_thread: Thread # daemon thread + + def __init__(self, client, server, auto_schedule_run): + self.client = client + self.server = server + self.auto_schedule_run = auto_schedule_run + if auto_schedule_run: + self.fhe_executor_pool = ThreadPoolExecutor() + self.fhe_waiter_loop = asyncio.new_event_loop() + def loop_thread(): + asyncio.set_event_loop(self.fhe_waiter_loop) + self.fhe_waiter_loop.run_forever() + self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True) + self.fhe_waiter_thread.start() + else: + self.fhe_executor_pool = None + self.fhe_waiter_loop = None + self.fhe_waiter_thread = None + def __del__(self): + if self.fhe_waiter_loop: + self.fhe_waiter_loop.stop() # daemon cleanup class SimulationRt(NamedTuple): """ @@ -186,10 +214,47 @@ def encrypt( assert isinstance(self.runtime, ExecutionRt) return self.runtime.client.encrypt(*args, function_name=self.name) - def run( + def run_sync( self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], ) -> Union[Value, Tuple[Value, ...]]: + """ + Evaluate the function synchronuously. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ + + return self._run(True, *args) + + def run_async(self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]] + ) -> 'Union[Future[Value], Future[Tuple[Value, ...]]]': + """ + Evaluate the function asynchronuously. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ + if isinstance(self.runtime, ExecutionRt) and not self.runtime.fhe_executor_pool: + self.runtime = ExecutionRt(self.runtime.client, self.runtime.server, True) + self.runtime.auto_schedule_run = False + + return self._run(False, *args) + + def run( + self, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...], Future]: """ Evaluate the function. @@ -201,15 +266,63 @@ def run( Union[Value, Tuple[Value, ...]]: result(s) of evaluation """ + if isinstance(self.runtime, ExecutionRt): + auto_schedule_run = self.runtime.auto_schedule_run + else: + auto_schedule_run = False + return self._run(not auto_schedule_run, *args) + def _run( + self, + sync, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...], Future]: + """ + Evaluate the function. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ if self.configuration.simulate_encrypt_run_decrypt: return self.simulate(*args) assert isinstance(self.runtime, ExecutionRt) - return self.runtime.server.run( + + fhe_work = lambda *args:self.runtime.server.run( *args, evaluation_keys=self.runtime.client.evaluation_keys, function_name=self.name ) + def args_ready(args): + return [arg.result() if isinstance(arg, Future) else arg for arg in args] + + if sync: + return fhe_work(*args_ready(args)) + + all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args) + + fhe_work_future = lambda *args:self.runtime.fhe_executor_pool.submit(fhe_work, *args) + if all_args_done: + return fhe_work_future(*args_ready(args)) + + # waiting args to be ready with async coroutines + # it only required one thread to run unlimited waits vs unlimited sync threads + async def wait_async(arg): + if not isinstance(arg, Future): + return arg + if arg.done(): + return arg.result() + return await asyncio.wrap_future(arg, loop=self.runtime.fhe_waiter_loop) + async def args_ready_and_submit(*args): + args = [await wait_async(arg) for arg in args] + return await wait_async(fhe_work_future(*args)) + run_async = args_ready_and_submit(*args) + return asyncio.run_coroutine_threadsafe(run_async, self.runtime.fhe_waiter_loop) + def decrypt( self, *results: Union[Value, Tuple[Value, ...]], @@ -230,6 +343,7 @@ def decrypt( return results if len(results) != 1 else results[0] # type: ignore assert isinstance(self.runtime, ExecutionRt) + results = [res.result() if isinstance(res, Future) else res for res in results] return self.runtime.client.decrypt(*results, function_name=self.name) def encrypt_run_decrypt(self, *args: Any) -> Any: @@ -585,7 +699,7 @@ def __init__( keyset_cache_directory = self.configuration.insecure_key_cache_location client = Client(server.client_specs, keyset_cache_directory) - self.runtime = ExecutionRt(client, server) + self.runtime = ExecutionRt(client, server, self.configuration.auto_schedule_run) @property def mlir(self) -> str: diff --git a/frontends/concrete-python/tests/compilation/test_modules.py b/frontends/concrete-python/tests/compilation/test_modules.py index df3f6ec679..a4d2e63676 100644 --- a/frontends/concrete-python/tests/compilation/test_modules.py +++ b/frontends/concrete-python/tests/compilation/test_modules.py @@ -2,6 +2,7 @@ Tests of everything related to modules. """ +from concurrent.futures import Future import inspect import re import tempfile @@ -325,7 +326,6 @@ def dec(x): fhe_simulation=True, ) - assert module.client is None assert module.keys is None assert module.inc.simulate(5) == 6 assert module.dec.simulate(5) == 4 @@ -718,3 +718,78 @@ def function(x): output = client.decrypt(deserialized_result, function_name="inc") assert output == 11 + +class IncDec: + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return fhe.refresh(x + 1) + + @fhe.function({"x": "encrypted"}) + def dec(x): + return fhe.refresh(x - 1) + + precision = 4 + + inputset = list(range(1, 2**precision-1)) + to_compile = {"inc": inputset, "dec": inputset} + +def test_run_async(): + """ + Test `run_async` with `auto_schedule_run=False` configuration option. + """ + + module = IncDec.Module.compile(IncDec.to_compile) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run_async(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run(a) + assert isinstance(b, type(encrypted_x)) + + result = module.inc.decrypt(b) + assert result == sample_x + +def test_run_sync(): + """ + Test `run_sync` with `auto_schedule_run=True` configuration option. + """ + + conf = fhe.Configuration(auto_schedule_run=True) + module = IncDec.Module.compile(IncDec.to_compile, conf) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run_sync(a) + assert isinstance(b, type(encrypted_x)) + + result = module.inc.decrypt(b) + assert result == sample_x + +def test_run_auto_schedule(): + """ + Test `run` with `auto_schedule_run=True` configuration option. + """ + + conf = fhe.Configuration(auto_schedule_run=True) + module = IncDec.Module.compile(IncDec.to_compile, conf) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run(a) + assert isinstance(b, Future) + + result = module.inc.decrypt(b) + assert result == sample_x