Skip to content

Commit

Permalink
feat(frontend-python): module run are scheduled and parallelized in a…
Browse files Browse the repository at this point in the history
… worker pool
  • Loading branch information
rudy-6-4 authored and BourgerieQuentin committed Dec 6, 2024
1 parent 570e05c commit c292a8b
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 10 deletions.
8 changes: 8 additions & 0 deletions docs/guides/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,11 @@ When options are specified both in the `configuration` and as kwargs in the `com
#### verbose: bool = False
- Print details related to compilation.

#### auto_schedule_run: bool = False
- Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized.
- For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished.
- E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available.
- If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically.
- Automatic scheduling behavior can be override locally by calling directly a variant of `run`:
- `run_sync`: forces the fhe function to occur in the current thread, not in the background,
- `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]`
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def run(
result(s) of evaluation
"""

return self._function.run(*args)
return self._function.run_sync(*args)

def decrypt(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,7 @@ class Configuration:
composable: bool
range_restriction: Optional[RangeRestriction]
keyset_restriction: Optional[KeysetRestriction]
auto_schedule_run: bool

def __init__(
self,
Expand Down Expand Up @@ -1068,6 +1069,7 @@ def __init__(
simulate_encrypt_run_decrypt: bool = False,
range_restriction: Optional[RangeRestriction] = None,
keyset_restriction: Optional[KeysetRestriction] = None,
auto_schedule_run: bool = False,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1177,6 +1179,8 @@ def __init__(
self.range_restriction = range_restriction
self.keyset_restriction = keyset_restriction

self.auto_schedule_run = auto_schedule_run

self._validate()

class Keep:
Expand Down Expand Up @@ -1254,6 +1258,7 @@ def fork(
simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP,
range_restriction: Union[Keep, Optional[RangeRestriction]] = KEEP,
keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP,
auto_schedule_run: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
150 changes: 141 additions & 9 deletions frontends/concrete-python/concrete/fhe/compilation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

# pylint: disable=import-error,no-member,no-name-in-module

import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from threading import Thread
from typing import Any, Awaitable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union

import numpy as np
from concrete.compiler import CompilationContext, LweSecretKey, Parameter
Expand All @@ -24,13 +27,40 @@
# pylint: enable=import-error,no-member,no-name-in-module


class ExecutionRt(NamedTuple):
class ExecutionRt:
"""
Runtime object class for execution.
"""

client: Client
server: Server
auto_schedule_run: bool
fhe_executor_pool: ThreadPoolExecutor
fhe_waiter_loop: asyncio.BaseEventLoop
fhe_waiter_thread: Thread # daemon thread

def __init__(self, client, server, auto_schedule_run):
self.client = client
self.server = server
self.auto_schedule_run = auto_schedule_run
if auto_schedule_run:
self.fhe_executor_pool = ThreadPoolExecutor()
self.fhe_waiter_loop = asyncio.new_event_loop()

def loop_thread():
asyncio.set_event_loop(self.fhe_waiter_loop)
self.fhe_waiter_loop.run_forever()

self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True)
self.fhe_waiter_thread.start()
else:
self.fhe_executor_pool = None
self.fhe_waiter_loop = None
self.fhe_waiter_thread = None

def __del__(self):
if self.fhe_waiter_loop:
self.fhe_waiter_loop.stop() # pragma: no cover


class SimulationRt(NamedTuple):
Expand Down Expand Up @@ -177,12 +207,12 @@ def encrypt(
return tuple(args) if len(args) > 1 else args[0] # type: ignore
return self.execution_runtime.val.client.encrypt(*args, function_name=self.name)

def run(
def run_sync(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...]]:
) -> Any:
"""
Evaluate the function.
Evaluate the function synchronuously.
Args:
*args (Value):
Expand All @@ -193,17 +223,115 @@ def run(
result(s) of evaluation
"""

return self._run(True, *args)

def run_async(
self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]]
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
"""
Evaluate the function asynchronuously.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Awaitable[Value], Awaitable[Tuple[Value, ...]]]:
result(s) a future of the evaluation
"""
if (
isinstance(self.execution_runtime.val, ExecutionRt)
and not self.execution_runtime.val.fhe_executor_pool
):
client = self.execution_runtime.val.client
server = self.execution_runtime.val.server
self.execution_runtime = Lazy(lambda: ExecutionRt(client, server, True))
self.execution_runtime.val.auto_schedule_run = False

return self._run(False, *args)

def run(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
"""
Evaluate the function.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
result(s) of evaluation or future of result(s) of evaluation if configured with async_run=True
"""
if isinstance(self.execution_runtime.val, ExecutionRt):
auto_schedule_run = self.execution_runtime.val.auto_schedule_run
else:
auto_schedule_run = False # pragma: no cover
return self._run(not auto_schedule_run, *args)

def _run(
self,
sync: bool,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
"""
Evaluate the function.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
result(s) of evaluation if sync=True else future of result(s) of evaluation
"""
if self.configuration.simulate_encrypt_run_decrypt:
return self._simulate_decrypt(self._simulate_run(*args)) # type: ignore
return self.execution_runtime.val.server.run(

assert isinstance(self.execution_runtime.val, ExecutionRt)

fhe_work = lambda *args: self.execution_runtime.val.server.run(
*args,
evaluation_keys=self.execution_runtime.val.client.evaluation_keys,
function_name=self.name,
)

def args_ready(args):
return [arg.result() if isinstance(arg, Future) else arg for arg in args]

if sync:
return fhe_work(*args_ready(args))

all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args)

fhe_work_future = lambda *args: self.execution_runtime.val.fhe_executor_pool.submit(
fhe_work, *args
)
if all_args_done:
return fhe_work_future(*args_ready(args)) # type: ignore

# waiting args to be ready with async coroutines
# it only required one thread to run unlimited waits vs unlimited sync threads
async def wait_async(arg):
if not isinstance(arg, Future):
return arg # pragma: no cover
if arg.done():
return arg.result() # pragma: no cover
return await asyncio.wrap_future(arg, loop=self.execution_runtime.val.fhe_waiter_loop)

async def args_ready_and_submit(*args):
args = [await wait_async(arg) for arg in args]
return await wait_async(fhe_work_future(*args))

run_async = args_ready_and_submit(*args)
return asyncio.run_coroutine_threadsafe(
run_async, self.execution_runtime.val.fhe_waiter_loop
) # type: ignore

def decrypt(
self,
*results: Union[Value, Tuple[Value, ...]],
self, *results: Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
"""
Decrypt result(s) of evaluation.
Expand All @@ -220,6 +348,8 @@ def decrypt(
if self.configuration.simulate_encrypt_run_decrypt:
return tuple(results) if len(results) > 1 else results[0] # type: ignore

assert isinstance(self.execution_runtime.val, ExecutionRt)
results = [res.result() if isinstance(res, Future) else res for res in results]
return self.execution_runtime.val.client.decrypt(*results, function_name=self.name)

def encrypt_run_decrypt(self, *args: Any) -> Any:
Expand Down Expand Up @@ -620,7 +750,9 @@ def init_execution():
execution_client = Client(
execution_server.client_specs, keyset_cache_directory, is_simulated=False
)
return ExecutionRt(execution_client, execution_server)
return ExecutionRt(
execution_client, execution_server, self.configuration.auto_schedule_run
)

self.execution_runtime = Lazy(init_execution)
if configuration.fhe_execution:
Expand Down
82 changes: 82 additions & 0 deletions frontends/concrete-python/tests/compilation/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import inspect
import tempfile
from concurrent.futures import Future
from pathlib import Path
from typing import Awaitable

import numpy as np
import pytest
Expand Down Expand Up @@ -955,3 +957,83 @@ def inc(x, y):
},
helpers.configuration().fork(),
)


class IncDec:
@fhe.module()
class Module:
@fhe.function({"x": "encrypted"})
def inc(x):
return fhe.refresh(x + 1)

@fhe.function({"x": "encrypted"})
def dec(x):
return fhe.refresh(x - 1)

precision = 4

inputset = list(range(1, 2**precision - 1))
to_compile = {"inc": inputset, "dec": inputset}


def test_run_async():
"""
Test `run_async` with `auto_schedule_run=False` configuration option.
"""

module = IncDec.Module.compile(IncDec.to_compile)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run_async(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run(a)
assert isinstance(b, type(encrypted_x))

result = module.inc.decrypt(b)
assert result == sample_x
del module


def test_run_sync():
"""
Test `run_sync` with `auto_schedule_run=True` configuration option.
"""

conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run_sync(a)
assert isinstance(b, type(encrypted_x))

result = module.inc.decrypt(b)
assert result == sample_x


def test_run_auto_schedule():
"""
Test `run` with `auto_schedule_run=True` configuration option.
"""

conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)

sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)

a = module.inc.run(encrypted_x)
assert isinstance(a, Future)

b = module.dec.run(a)
assert isinstance(b, Future)

result = module.inc.decrypt(b)
assert result == sample_x

0 comments on commit c292a8b

Please sign in to comment.