diff --git a/docs/guides/configure.md b/docs/guides/configure.md index 88e0b5b3f0..49e7470a12 100644 --- a/docs/guides/configure.md +++ b/docs/guides/configure.md @@ -210,3 +210,11 @@ When options are specified both in the `configuration` and as kwargs in the `com #### verbose: bool = False - Print details related to compilation. +#### auto_schedule_run: bool = False + - Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized. + - For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished. + - E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available. + - If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically. + - Automatic scheduling behavior can be override locally by calling directly a variant of `run`: + - `run_sync`: forces the fhe function to occur in the current thread, not in the background, + - `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]` diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index a1711be835..873e2a4cf8 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -195,7 +195,7 @@ def run( result(s) of evaluation """ - return self._function.run(*args) + return self._function.run_sync(*args) def decrypt( self, diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 421dda91d6..5cfe97b721 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -997,6 +997,7 @@ class Configuration: composable: bool range_restriction: Optional[RangeRestriction] keyset_restriction: Optional[KeysetRestriction] + auto_schedule_run: bool def __init__( self, @@ -1068,6 +1069,7 @@ def __init__( simulate_encrypt_run_decrypt: bool = False, range_restriction: Optional[RangeRestriction] = None, keyset_restriction: Optional[KeysetRestriction] = None, + auto_schedule_run: bool = False, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1177,6 +1179,8 @@ def __init__( self.range_restriction = range_restriction self.keyset_restriction = keyset_restriction + self.auto_schedule_run = auto_schedule_run + self._validate() class Keep: @@ -1254,6 +1258,7 @@ def fork( simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP, range_restriction: Union[Keep, Optional[RangeRestriction]] = KEEP, keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP, + auto_schedule_run: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py index 279c26dff6..7b5e4d4748 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -4,8 +4,11 @@ # pylint: disable=import-error,no-member,no-name-in-module +import asyncio +from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path -from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union +from threading import Thread +from typing import Any, Awaitable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union import numpy as np from concrete.compiler import CompilationContext, LweSecretKey, Parameter @@ -24,13 +27,40 @@ # pylint: enable=import-error,no-member,no-name-in-module -class ExecutionRt(NamedTuple): +class ExecutionRt: """ Runtime object class for execution. """ client: Client server: Server + auto_schedule_run: bool + fhe_executor_pool: ThreadPoolExecutor + fhe_waiter_loop: asyncio.BaseEventLoop + fhe_waiter_thread: Thread # daemon thread + + def __init__(self, client, server, auto_schedule_run): + self.client = client + self.server = server + self.auto_schedule_run = auto_schedule_run + if auto_schedule_run: + self.fhe_executor_pool = ThreadPoolExecutor() + self.fhe_waiter_loop = asyncio.new_event_loop() + + def loop_thread(): + asyncio.set_event_loop(self.fhe_waiter_loop) + self.fhe_waiter_loop.run_forever() + + self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True) + self.fhe_waiter_thread.start() + else: + self.fhe_executor_pool = None + self.fhe_waiter_loop = None + self.fhe_waiter_thread = None + + def __del__(self): + if self.fhe_waiter_loop: + self.fhe_waiter_loop.stop() # daemon cleanup class SimulationRt(NamedTuple): @@ -177,12 +207,12 @@ def encrypt( return tuple(args) if len(args) > 1 else args[0] # type: ignore return self.execution_runtime.val.client.encrypt(*args, function_name=self.name) - def run( + def run_sync( self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], - ) -> Union[Value, Tuple[Value, ...]]: + ) -> Any: """ - Evaluate the function. + Evaluate the function synchronuously. Args: *args (Value): @@ -193,17 +223,115 @@ def run( result(s) of evaluation """ + return self._run(True, *args) + + def run_async( + self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]] + ) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]: + """ + Evaluate the function asynchronuously. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Awaitable[Value], Awaitable[Tuple[Value, ...]]]: + result(s) a future of the evaluation + """ + if ( + isinstance(self.execution_runtime.val, ExecutionRt) + and not self.execution_runtime.val.fhe_executor_pool + ): + client = self.execution_runtime.val.client + server = self.execution_runtime.val.server + self.execution_runtime = Lazy(lambda: ExecutionRt(client, server, True)) + self.execution_runtime.val.auto_schedule_run = False + + return self._run(False, *args) + + def run( + self, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]: + """ + Evaluate the function. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]: + result(s) of evaluation or future of result(s) of evaluation if configured with async_run=True + """ + if isinstance(self.execution_runtime.val, ExecutionRt): + auto_schedule_run = self.execution_runtime.val.auto_schedule_run + else: + auto_schedule_run = False # pragma: no cover + return self._run(not auto_schedule_run, *args) + + def _run( + self, + sync: bool, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]: + """ + Evaluate the function. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]: + result(s) of evaluation if sync=True else future of result(s) of evaluation + """ if self.configuration.simulate_encrypt_run_decrypt: return self._simulate_decrypt(self._simulate_run(*args)) # type: ignore - return self.execution_runtime.val.server.run( + + assert isinstance(self.execution_runtime.val, ExecutionRt) + + fhe_work = lambda *args: self.execution_runtime.val.server.run( *args, evaluation_keys=self.execution_runtime.val.client.evaluation_keys, function_name=self.name, ) + def args_ready(args): + return [arg.result() if isinstance(arg, Future) else arg for arg in args] + + if sync: + return fhe_work(*args_ready(args)) + + all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args) + + fhe_work_future = lambda *args: self.execution_runtime.val.fhe_executor_pool.submit( + fhe_work, *args + ) + if all_args_done: + return fhe_work_future(*args_ready(args)) # type: ignore + + # waiting args to be ready with async coroutines + # it only required one thread to run unlimited waits vs unlimited sync threads + async def wait_async(arg): + if not isinstance(arg, Future): + return arg # pragma: no cover + if arg.done(): + return arg.result() # pragma: no cover + return await asyncio.wrap_future(arg, loop=self.execution_runtime.val.fhe_waiter_loop) + + async def args_ready_and_submit(*args): + args = [await wait_async(arg) for arg in args] + return await wait_async(fhe_work_future(*args)) + + run_async = args_ready_and_submit(*args) + return asyncio.run_coroutine_threadsafe( + run_async, self.execution_runtime.val.fhe_waiter_loop + ) # type: ignore + def decrypt( - self, - *results: Union[Value, Tuple[Value, ...]], + self, *results: Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]] ) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: """ Decrypt result(s) of evaluation. @@ -220,6 +348,8 @@ def decrypt( if self.configuration.simulate_encrypt_run_decrypt: return tuple(results) if len(results) > 1 else results[0] # type: ignore + assert isinstance(self.execution_runtime.val, ExecutionRt) + results = [res.result() if isinstance(res, Future) else res for res in results] return self.execution_runtime.val.client.decrypt(*results, function_name=self.name) def encrypt_run_decrypt(self, *args: Any) -> Any: @@ -620,7 +750,9 @@ def init_execution(): execution_client = Client( execution_server.client_specs, keyset_cache_directory, is_simulated=False ) - return ExecutionRt(execution_client, execution_server) + return ExecutionRt( + execution_client, execution_server, self.configuration.auto_schedule_run + ) self.execution_runtime = Lazy(init_execution) if configuration.fhe_execution: diff --git a/frontends/concrete-python/tests/compilation/test_modules.py b/frontends/concrete-python/tests/compilation/test_modules.py index 07dc33bc12..d593b5376c 100644 --- a/frontends/concrete-python/tests/compilation/test_modules.py +++ b/frontends/concrete-python/tests/compilation/test_modules.py @@ -4,7 +4,9 @@ import inspect import tempfile +from concurrent.futures import Future from pathlib import Path +from typing import Awaitable import numpy as np import pytest @@ -955,3 +957,83 @@ def inc(x, y): }, helpers.configuration().fork(), ) + + +class IncDec: + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return fhe.refresh(x + 1) + + @fhe.function({"x": "encrypted"}) + def dec(x): + return fhe.refresh(x - 1) + + precision = 4 + + inputset = list(range(1, 2**precision - 1)) + to_compile = {"inc": inputset, "dec": inputset} + + +def test_run_async(): + """ + Test `run_async` with `auto_schedule_run=False` configuration option. + """ + + module = IncDec.Module.compile(IncDec.to_compile) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run_async(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run(a) + assert isinstance(b, type(encrypted_x)) + + result = module.inc.decrypt(b) + assert result == sample_x + del module + + +def test_run_sync(): + """ + Test `run_sync` with `auto_schedule_run=True` configuration option. + """ + + conf = fhe.Configuration(auto_schedule_run=True) + module = IncDec.Module.compile(IncDec.to_compile, conf) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run_sync(a) + assert isinstance(b, type(encrypted_x)) + + result = module.inc.decrypt(b) + assert result == sample_x + + +def test_run_auto_schedule(): + """ + Test `run` with `auto_schedule_run=True` configuration option. + """ + + conf = fhe.Configuration(auto_schedule_run=True) + module = IncDec.Module.compile(IncDec.to_compile, conf) + + sample_x = 2 + encrypted_x = module.inc.encrypt(sample_x) + + a = module.inc.run(encrypted_x) + assert isinstance(a, Future) + + b = module.dec.run(a) + assert isinstance(b, Future) + + result = module.inc.decrypt(b) + assert result == sample_x