diff --git a/mosaic/comms/comms.py b/mosaic/comms/comms.py index 6b369b64..a12002e7 100644 --- a/mosaic/comms/comms.py +++ b/mosaic/comms/comms.py @@ -488,7 +488,7 @@ def __init__(self, uid, address, port, self._heartbeat_timeout = None self._heartbeat_attempts = 0 self._heartbeat_max_attempts = 5 - self._heartbeat_interval = 30 + self._heartbeat_interval = 1 self._shaken = False @@ -1587,10 +1587,18 @@ async def disconnect(self, sender_id, uid, notify=False): if notify is True: for connected_id, connection in self._send_socket.items(): + if connection.state == 'disconnected': + continue await self.send_async(connected_id, method='disconnect', uid=uid) + if 'node' in uid and uid in self._runtime._nodes: + node_index = self._runtime._nodes[uid].indices[0] + for worker in self._runtime.workers: + if worker.indices[0] == node_index: + await self.disconnect(sender_id, worker.uid, notify=notify) + async def handshake(self, uid, address, port): """ Start handshake with remote ``uid``, located at a certain ``address`` and ``port``. diff --git a/mosaic/core/__init__.py b/mosaic/core/__init__.py index 747e3b8c..427fa4df 100644 --- a/mosaic/core/__init__.py +++ b/mosaic/core/__init__.py @@ -1,3 +1,4 @@ +from .base import RuntimeDisconnectedError from .task import * from .tessera import * diff --git a/mosaic/core/base.py b/mosaic/core/base.py index 77954e11..9a17a496 100644 --- a/mosaic/core/base.py +++ b/mosaic/core/base.py @@ -6,7 +6,11 @@ from ..profile import profiler -__all__ = ['RemoteBase', 'ProxyBase'] +__all__ = ['RemoteBase', 'ProxyBase', 'RuntimeDisconnectedError'] + + +class RuntimeDisconnectedError(Exception): + pass class Base: @@ -84,6 +88,17 @@ async def __init_async__(self, *args, **kwargs): async def init(self, *args, **kwargs): pass + def deregister_runtime(self, uid): + if uid != self.runtime_id: + return + + if self._init_future.done(): + self._init_future = Future() + + self.init_future.set_exception( + RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid) + ) + def __repr__(self): NotImplementedError('Unimplemented Base method __repr__') diff --git a/mosaic/core/task.py b/mosaic/core/task.py index 5e5c5faa..7536c76d 100644 --- a/mosaic/core/task.py +++ b/mosaic/core/task.py @@ -8,7 +8,7 @@ from cached_property import cached_property from .. import types -from .base import Base, RemoteBase, ProxyBase +from .base import Base, RemoteBase, ProxyBase, RuntimeDisconnectedError from ..utils import Future, MultiError @@ -482,6 +482,24 @@ async def init(self): if self._state == 'init': self.state_changed('queued') + def deregister_runtime(self, uid): + if uid != self.runtime_id: + return + + super().deregister_runtime(uid) + + self.state_changed('failed') + + try: + self._done_future.set_exception( + RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid) + ) + except asyncio.InvalidStateError: + pass + else: + # Once done release local copy of the arguments + self._cleanup() + @cached_property def runtime_id(self): """ diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index 52eca70a..b8b3f7c1 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -207,10 +207,10 @@ async def init_cluster(self, **kwargs): cpu_mask = _cpu_mask(1, 1, num_cpus) cmd = (f'srun {ssh_flags} --nodes=1 --ntasks=1 --tasks-per-node={num_cpus} ' - f'--cpu-bind=mask_cpu:{cpu_mask} ' + f'--cpu-bind=mask_cpu:{cpu_mask} --mem-bind=local ' f'--oversubscribe ' f'--distribution=block:block ' - f'--hint=nomultithread ' + f'--hint=nomultithread --no-kill ' f'--nodelist={node_address} ' f'{remote_cmd}') diff --git a/mosaic/runtime/node.py b/mosaic/runtime/node.py index f90a731a..90d265d9 100644 --- a/mosaic/runtime/node.py +++ b/mosaic/runtime/node.py @@ -143,8 +143,7 @@ def start_worker(*args, **extra_kwargs): worker_proxy = RuntimeProxy(name='worker', indices=indices) worker_subprocess = subprocess(start_worker)(name=worker_proxy.uid, daemon=False, - cpu_affinity=worker_cpus.get(worker_index, None), - mem_affinity=worker_nodes.get(worker_index, None)) + cpu_affinity=worker_cpus.get(worker_index, None)) worker_subprocess.start_process() worker_proxy.subprocess = worker_subprocess diff --git a/mosaic/runtime/runtime.py b/mosaic/runtime/runtime.py index 52d982e0..4cae8d3e 100644 --- a/mosaic/runtime/runtime.py +++ b/mosaic/runtime/runtime.py @@ -13,7 +13,7 @@ from ..utils import SpillBuffer from ..utils.event_loop import EventLoop from ..comms import CommsManager -from ..core import Task +from ..core import Task, RuntimeDisconnectedError from ..profile import profiler, global_profiler from ..utils.utils import memory_limit, cpu_count @@ -262,11 +262,17 @@ async def barrier(self, timeout=None): def async_for(self, *iterables, **kwargs): assert not self._inside_async_for - safe = kwargs.pop('safe', False) + safe = kwargs.pop('safe', True) + timeout = kwargs.pop('timeout', None) + max_await = kwargs.pop('max_await', None) async def _async_for(func): self._inside_async_for = True + available_workers = self.num_workers + if available_workers <= 0: + raise RuntimeError('No workers available to complete async workload') + worker_queue = asyncio.Queue() for worker in self._workers.values(): await worker_queue.put(worker) @@ -279,8 +285,31 @@ async def call(*iters): return res - tasks = [call(*each) for each in zip(*iterables)] - gather = await asyncio.gather(*tasks) + tasks = [asyncio.create_task(call(*each)) for each in zip(*iterables)] + + gather = [] + for task in asyncio.as_completed(tasks, timeout=timeout): + try: + res = await task + gather.append(res) + except Exception as exc: + if safe: + self.logger.info('Runtime failed, retiring worker: %s' % exc) + available_workers -= 1 + if available_workers <= 0: + for other_task in tasks: + other_task.cancel() + with contextlib.suppress(RuntimeDisconnectedError, asyncio.CancelledError): + await other_task + raise RuntimeError('No workers available to complete async workload') + else: + raise + + if max_await is not None and len(gather) > max_await: + for other_task in tasks: + other_task.close() + break + await self.barrier() self._inside_async_for = False @@ -292,17 +321,8 @@ async def call(*iters): @contextlib.asynccontextmanager async def _exclusive_proxy(self, queue, safe=False): proxy = await queue.get() - - try: - yield proxy - except Exception as exc: - if safe: - self.logger.info('Runtime %s failed, retiring runtime:\n\n%s' - % (proxy.uid, exc)) - else: - raise - else: - await queue.put(proxy) + yield proxy + await queue.put(proxy) @property def address(self): @@ -584,6 +604,27 @@ def proxy_from_uid(self, uid, proxy=None): else: return found_proxy + def remove_proxy_from_uid(self, uid, proxy=None): + """ + Remove a proxy from a UID. + + Parameters + ---------- + uid : str + proxy : BaseProxy + + Returns + ------- + + """ + proxy = proxy or self.proxy(uid=uid) + + if hasattr(self, '_' + proxy.name + 's'): + del getattr(self, '_' + proxy.name + 's')[uid] + + elif hasattr(self, '_' + proxy.name): + setattr(self, '_' + proxy.name, None) + @staticmethod def proxy(name=None, indices=(), uid=None): """ @@ -771,7 +812,28 @@ def disconnect(self, sender_id, uid): ------- """ - pass + # deregister if remote uid held a proxy to a local tessera + for obj in self._tessera.values(): + obj.deregister_proxy(uid) + + # deregister if remote uid held a proxy to a local task + for obj in self._task.values(): + obj.deregister_proxy(uid) + + # deregister if local tessera proxy points to remote uid + for obj in self._tessera_proxy.values(): + obj.deregister_runtime(uid) + + # deregister if local tessera proxy array points to remote uid + for obj in self._tessera_proxy_array.values(): + obj.deregister_runtime(uid) + + # deregister if local task proxy points to remote uid + for obj in self._task_proxy.values(): + obj.deregister_runtime(uid) + + # remove remote runtime from local runtime + self.remove_proxy_from_uid(uid) async def stop(self, sender_id=None): """ diff --git a/mosaic/utils/subprocess.py b/mosaic/utils/subprocess.py index a85dbd0d..9c7a2c2a 100644 --- a/mosaic/utils/subprocess.py +++ b/mosaic/utils/subprocess.py @@ -370,10 +370,14 @@ def mem_affinity(self, nodes): """ try: import numa + from numa import LIBNUMA except Exception: return numa.memory.set_membind_nodes(*nodes) + op_res = LIBNUMA.numa_set_bind_policy(0) + if op_res == -1: + raise Exception('set_bind_policy failed') def subprocess(target): diff --git a/stride/__init__.py b/stride/__init__.py index 7afd0a01..62c9664a 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -84,7 +84,7 @@ async def forward(problem, pde, *args, **kwargs): dump = kwargs.pop('dump', True) shot_ids = kwargs.pop('shot_ids', None) deallocate = kwargs.pop('deallocate', False) - safe = kwargs.pop('safe', False) + safe = kwargs.pop('safe', True) if dump is True: try: @@ -206,7 +206,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa restart_id = kwargs.pop('restart_id', -1) dump = kwargs.pop('dump', True) - safe = kwargs.pop('safe', False) + safe = kwargs.pop('safe', True) f_min = kwargs.pop('f_min', None) f_max = kwargs.pop('f_max', None)