Skip to content

Commit

Permalink
Add more robust failure tolerance
Browse files Browse the repository at this point in the history
  • Loading branch information
ccuetom committed Oct 27, 2023
1 parent 45d9d7c commit e270b55
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 23 deletions.
10 changes: 9 additions & 1 deletion mosaic/comms/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def __init__(self, uid, address, port,
self._heartbeat_timeout = None
self._heartbeat_attempts = 0
self._heartbeat_max_attempts = 5
self._heartbeat_interval = 30
self._heartbeat_interval = 1

self._shaken = False

Expand Down Expand Up @@ -1587,10 +1587,18 @@ async def disconnect(self, sender_id, uid, notify=False):

if notify is True:
for connected_id, connection in self._send_socket.items():
if connection.state == 'disconnected':
continue
await self.send_async(connected_id,
method='disconnect',
uid=uid)

if 'node' in uid and uid in self._runtime._nodes:
node_index = self._runtime._nodes[uid].indices[0]
for worker in self._runtime.workers:
if worker.indices[0] == node_index:
await self.disconnect(sender_id, worker.uid, notify=notify)

async def handshake(self, uid, address, port):
"""
Start handshake with remote ``uid``, located at a certain ``address`` and ``port``.
Expand Down
1 change: 1 addition & 0 deletions mosaic/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from .base import RuntimeDisconnectedError
from .task import *
from .tessera import *
17 changes: 16 additions & 1 deletion mosaic/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from ..profile import profiler


__all__ = ['RemoteBase', 'ProxyBase']
__all__ = ['RemoteBase', 'ProxyBase', 'RuntimeDisconnectedError']


class RuntimeDisconnectedError(Exception):
pass


class Base:
Expand Down Expand Up @@ -84,6 +88,17 @@ async def __init_async__(self, *args, **kwargs):
async def init(self, *args, **kwargs):
pass

def deregister_runtime(self, uid):
if uid != self.runtime_id:
return

if self._init_future.done():
self._init_future = Future()

self.init_future.set_exception(
RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid)
)

def __repr__(self):
NotImplementedError('Unimplemented Base method __repr__')

Expand Down
20 changes: 19 additions & 1 deletion mosaic/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cached_property import cached_property

from .. import types
from .base import Base, RemoteBase, ProxyBase
from .base import Base, RemoteBase, ProxyBase, RuntimeDisconnectedError
from ..utils import Future, MultiError


Expand Down Expand Up @@ -482,6 +482,24 @@ async def init(self):
if self._state == 'init':
self.state_changed('queued')

def deregister_runtime(self, uid):
if uid != self.runtime_id:
return

super().deregister_runtime(uid)

self.state_changed('failed')

try:
self._done_future.set_exception(
RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid)
)
except asyncio.InvalidStateError:
pass
else:
# Once done release local copy of the arguments
self._cleanup()

@cached_property
def runtime_id(self):
"""
Expand Down
3 changes: 1 addition & 2 deletions mosaic/runtime/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def start_worker(*args, **extra_kwargs):
worker_proxy = RuntimeProxy(name='worker', indices=indices)
worker_subprocess = subprocess(start_worker)(name=worker_proxy.uid,
daemon=False,
cpu_affinity=worker_cpus.get(worker_index, None),
mem_affinity=worker_nodes.get(worker_index, None))
cpu_affinity=worker_cpus.get(worker_index, None))
worker_subprocess.start_process()
worker_proxy.subprocess = worker_subprocess

Expand Down
94 changes: 78 additions & 16 deletions mosaic/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..utils import SpillBuffer
from ..utils.event_loop import EventLoop
from ..comms import CommsManager
from ..core import Task
from ..core import Task, RuntimeDisconnectedError
from ..profile import profiler, global_profiler
from ..utils.utils import memory_limit, cpu_count

Expand Down Expand Up @@ -262,11 +262,17 @@ async def barrier(self, timeout=None):
def async_for(self, *iterables, **kwargs):
assert not self._inside_async_for

safe = kwargs.pop('safe', False)
safe = kwargs.pop('safe', True)
timeout = kwargs.pop('timeout', None)
max_await = kwargs.pop('max_await', None)

async def _async_for(func):
self._inside_async_for = True

available_workers = self.num_workers
if available_workers <= 0:
raise RuntimeError('No workers available to complete async workload')

worker_queue = asyncio.Queue()
for worker in self._workers.values():
await worker_queue.put(worker)
Expand All @@ -279,8 +285,31 @@ async def call(*iters):

return res

tasks = [call(*each) for each in zip(*iterables)]
gather = await asyncio.gather(*tasks)
tasks = [asyncio.create_task(call(*each)) for each in zip(*iterables)]

gather = []
for task in asyncio.as_completed(tasks, timeout=timeout):
try:
res = await task
gather.append(res)
except Exception as exc:
if safe:
self.logger.info('Runtime failed, retiring worker: %s' % exc)
available_workers -= 1
if available_workers <= 0:
for other_task in tasks:
other_task.cancel()
with contextlib.suppress(RuntimeDisconnectedError, asyncio.CancelledError):
await other_task
raise RuntimeError('No workers available to complete async workload')
else:
raise

if max_await is not None and len(gather) > max_await:
for other_task in tasks:
other_task.close()
break

await self.barrier()

self._inside_async_for = False
Expand All @@ -292,17 +321,8 @@ async def call(*iters):
@contextlib.asynccontextmanager
async def _exclusive_proxy(self, queue, safe=False):
proxy = await queue.get()

try:
yield proxy
except Exception as exc:
if safe:
self.logger.info('Runtime %s failed, retiring runtime:\n\n%s'
% (proxy.uid, exc))
else:
raise
else:
await queue.put(proxy)
yield proxy
await queue.put(proxy)

@property
def address(self):
Expand Down Expand Up @@ -584,6 +604,27 @@ def proxy_from_uid(self, uid, proxy=None):
else:
return found_proxy

def remove_proxy_from_uid(self, uid, proxy=None):
"""
Remove a proxy from a UID.
Parameters
----------
uid : str
proxy : BaseProxy
Returns
-------
"""
proxy = proxy or self.proxy(uid=uid)

if hasattr(self, '_' + proxy.name + 's'):
del getattr(self, '_' + proxy.name + 's')[uid]

elif hasattr(self, '_' + proxy.name):
setattr(self, '_' + proxy.name, None)

@staticmethod
def proxy(name=None, indices=(), uid=None):
"""
Expand Down Expand Up @@ -771,7 +812,28 @@ def disconnect(self, sender_id, uid):
-------
"""
pass
# deregister if remote uid held a proxy to a local tessera
for obj in self._tessera.values():
obj.deregister_proxy(uid)

# deregister if remote uid held a proxy to a local task
for obj in self._task.values():
obj.deregister_proxy(uid)

# deregister if local tessera proxy points to remote uid
for obj in self._tessera_proxy.values():
obj.deregister_runtime(uid)

# deregister if local tessera proxy array points to remote uid
for obj in self._tessera_proxy_array.values():
obj.deregister_runtime(uid)

# deregister if local task proxy points to remote uid
for obj in self._task_proxy.values():
obj.deregister_runtime(uid)

# remove remote runtime from local runtime
self.remove_proxy_from_uid(uid)

async def stop(self, sender_id=None):
"""
Expand Down
4 changes: 2 additions & 2 deletions stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def forward(problem, pde, *args, **kwargs):
dump = kwargs.pop('dump', True)
shot_ids = kwargs.pop('shot_ids', None)
deallocate = kwargs.pop('deallocate', False)
safe = kwargs.pop('safe', False)
safe = kwargs.pop('safe', True)

if dump is True:
try:
Expand Down Expand Up @@ -206,7 +206,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
restart_id = kwargs.pop('restart_id', -1)

dump = kwargs.pop('dump', True)
safe = kwargs.pop('safe', False)
safe = kwargs.pop('safe', True)

f_min = kwargs.pop('f_min', None)
f_max = kwargs.pop('f_max', None)
Expand Down

0 comments on commit e270b55

Please sign in to comment.