Commit

Try safe_recv()

borzunov committed Aug 24, 2023
1 parent 6dc02d7 commit 5003c16

Showing 9 changed files with 25 additions and 16 deletions.
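
Every hunk below makes the same substitution: a direct Connection.recv() call on a multiprocessing pipe becomes safe_recv(pipe), with the helper imported from hivemind.utils.compat. That module is not touched by this commit, so the helper's actual implementation is not shown here. As a rough illustration only, a wrapper of this kind might look like the sketch below, assuming its job is simply to retry the blocking read when it gets interrupted (that behavior is an assumption, not taken from the commit):

    # Hypothetical sketch of a safe_recv() helper; hivemind/utils/compat.py is not part of
    # this diff, so the retry-on-interruption behavior shown here is an assumption.
    from multiprocessing.connection import Connection

    def safe_recv(conn: Connection):
        """Receive one object from the pipe, retrying if the blocking read is interrupted."""
        while True:
            try:
                return conn.recv()
            except InterruptedError:
                # The read was interrupted before completing; try again.
                continue

The call sites themselves change in one way only: pipe.recv() becomes safe_recv(pipe); the surrounding control flow and error handling are left as they were.
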
5 changes: 3 additions & 2 deletions hivemind/averaging/averager.py
@@ -37,6 +37,7 @@
enter_asynchronously,
switch_to_uvloop,
)
+ from hivemind.utils.compat import safe_recv
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -313,7 +314,7 @@ async def _run():
if not self._inner_pipe.poll():
continue
try:
- method, args, kwargs = self._inner_pipe.recv()
+ method, args, kwargs = safe_recv(self._inner_pipe)
except (OSError, ConnectionError, RuntimeError) as e:
logger.exception(e)
await asyncio.sleep(self.request_timeout)
@@ -774,7 +775,7 @@ def _background_thread_fetch_current_state(
"""
while True:
try:
- trigger, future = pipe.recv()
+ trigger, future = safe_recv(pipe)
except BaseException as e:
logger.debug(f"Averager background thread finished: {repr(e)}")
break
3 changes: 2 additions & 1 deletion hivemind/dht/dht.py
@@ -14,6 +14,7 @@
from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
from hivemind.p2p import P2P, PeerID
from hivemind.utils import MPFuture, get_logger, switch_to_uvloop
+ from hivemind.utils.compat import safe_recv
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration

logger = get_logger(__name__)
@@ -126,7 +127,7 @@ async def _run():
if not self._inner_pipe.poll():
continue
try:
- method, args, kwargs = self._inner_pipe.recv()
+ method, args, kwargs = safe_recv(self._inner_pipe)
except (OSError, ConnectionError, RuntimeError) as e:
logger.exception(e)
await asyncio.sleep(self._node.protocol.wait_timeout)
5 changes: 3 additions & 2 deletions hivemind/moe/server/server.py
@@ -26,6 +26,7 @@
from hivemind.moe.server.runtime import Runtime
from hivemind.p2p import PeerInfo
from hivemind.proto.runtime_pb2 import CompressionType
+ from hivemind.utils.compat import safe_recv
from hivemind.utils.logging import get_logger
from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

@@ -314,7 +315,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
runner.start()
# once the server is ready, runner will send us
# either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
- start_ok, data = pipe.recv()
+ start_ok, data = safe_recv(pipe)
if start_ok:
yield data
pipe.send("SHUTDOWN") # on exit from context, send shutdown signal
@@ -339,7 +340,7 @@ def _server_runner(pipe, *args, **kwargs):
try:
dht_maddrs = server.dht.get_visible_maddrs()
pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
- pipe.recv() # wait for shutdown signal
+ safe_recv(pipe) # wait for shutdown signal

finally:
logger.info("Shutting down server...")
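The two server.py hunks above spell out a small parent/child handshake over the pipe: the child reports either (True, PeerInfo(dht_peer_id, dht_maddrs)) or (False, exception) once startup completes, the parent yields the peer info to the caller, and on exit it sends "SHUTDOWN", which the child's safe_recv(pipe) is blocked on. A self-contained toy version of that exchange, with the server startup replaced by a placeholder payload (the function names here are illustrative, not hivemind's):

    # Toy reconstruction of the handshake used by background_server/_server_runner above.
    # The (start_ok, data) tuple and the "SHUTDOWN" signal follow the diff; everything else
    # (names, placeholder payload, plain recv instead of safe_recv) is simplified.
    import multiprocessing as mp
    from contextlib import contextmanager

    def _runner_sketch(pipe):
        pipe.send((True, "peer-info-placeholder"))  # report successful startup to the parent
        pipe.recv()  # block until the parent sends "SHUTDOWN"

    @contextmanager
    def background_server_sketch():
        parent_pipe, child_pipe = mp.Pipe()
        runner = mp.Process(target=_runner_sketch, args=(child_pipe,))
        runner.start()
        start_ok, data = parent_pipe.recv()  # (True, peer info) or (False, exception)
        if not start_ok:
            raise data  # the child reported a startup failure
        try:
            yield data  # hand the peer info to the caller
        finally:
            parent_pipe.send("SHUTDOWN")  # unblocks the child's receive
            runner.join()

    if __name__ == "__main__":
        with background_server_sketch() as info:
            print("server is up:", info)
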
5 changes: 3 additions & 2 deletions hivemind/moe/server/task_pool.py
@@ -15,6 +15,7 @@
import torch

from hivemind.utils import get_logger
+ from hivemind.utils.compat import safe_recv
from hivemind.utils.mpfuture import InvalidStateError, MPFuture

logger = get_logger(__name__)
@@ -195,7 +196,7 @@ def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):

while True:
logger.debug(f"{self.name} waiting for results from runtime")
- batch_index, batch_outputs_or_exception = self.outputs_receiver.recv()
+ batch_index, batch_outputs_or_exception = safe_recv(self.outputs_receiver)
batch_tasks = pending_batches.pop(batch_index)

if isinstance(batch_outputs_or_exception, BaseException):
@@ -234,7 +235,7 @@ def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[to
if not self.batch_receiver.poll(timeout):
raise TimeoutError()

- batch_index, batch_inputs = self.batch_receiver.recv()
+ batch_index, batch_inputs = safe_recv(self.batch_receiver)
batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
return batch_index, batch_inputs

3 changes: 2 additions & 1 deletion hivemind/utils/mpfuture.py
@@ -14,6 +14,7 @@

import torch # used for py3.7-compatible shared memory

+ from hivemind.utils.compat import safe_recv
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)
@@ -182,7 +183,7 @@ def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection)
if cls._pipe_waiter_thread is not threading.current_thread():
break # backend was reset, a new background thread has started

- uid, update_type, payload = receiver_pipe.recv()
+ uid, update_type, payload = safe_recv(receiver_pipe)
future = None
future_ref = cls._active_futures.pop(uid, None)
if future_ref is not None:
7 changes: 4 additions & 3 deletions tests/test_dht_crypto.py
@@ -8,6 +8,7 @@
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.node import DHTNode
from hivemind.dht.validation import DHTRecord
+ from hivemind.utils.compat import safe_recv
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.timed_storage import get_dht_time

@@ -79,8 +80,8 @@ def test_validator_instance_is_picklable():


def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
- validator = conn.recv()
- record = conn.recv()
+ validator = safe_recv(conn)
+ record = safe_recv(conn)

record = dataclasses.replace(record, value=validator.sign_value(record))

@@ -101,7 +102,7 @@ def test_signing_in_different_process():
)
parent_conn.send(record)

- signed_record = parent_conn.recv()
+ signed_record = safe_recv(parent_conn)
assert b"[signature:" in signed_record.value
assert validator.validate(signed_record)

3 changes: 2 additions & 1 deletion tests/test_dht_protocol.py
@@ -12,6 +12,7 @@
from hivemind.dht import DHTID
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.storage import DictionaryDHTValue
+ from hivemind.utils.compat import safe_recv

logger = get_logger(__name__)

@@ -56,7 +57,7 @@ def launch_protocol_listener(
dht_id = DHTID.generate()
process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
process.start()
- peer_id, visible_maddrs = local_conn.recv()
+ peer_id, visible_maddrs = safe_recv(local_conn)

return dht_id, process, peer_id, visible_maddrs

5 changes: 3 additions & 2 deletions tests/test_p2p_daemon.py
@@ -13,6 +13,7 @@

from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
from hivemind.proto import dht_pb2, test_pb2
+ from hivemind.utils.compat import safe_recv
from hivemind.utils.serializer import MSGPackSerializer

from test_utils.networking import get_free_port
@@ -328,8 +329,8 @@ async def test_call_peer_different_processes():
proc = mp.Process(target=server_target, args=(handler_name, server_side, response_received))
proc.start()

- peer_id = client_side.recv()
- peer_maddrs = client_side.recv()
+ peer_id = safe_recv(client_side)
+ peer_maddrs = safe_recv(client_side)

client = await P2P.create(initial_peers=peer_maddrs)
client_pid = client._child.pid
5 changes: 3 additions & 2 deletions tests/test_util_modules.py
@@ -13,6 +13,7 @@
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+ from hivemind.utils.compat import safe_recv
from hivemind.utils.asyncio import (
achain,
aenumerate,
@@ -260,7 +261,7 @@ def _check_result_and_set(future):
p = mp.Process(target=_future_creator)
p.start()

- future1, future2 = receiver.recv()
+ future1, future2 = safe_recv(receiver)
future1.set_result(123)

with pytest.raises(RuntimeError):
@@ -309,7 +310,7 @@ def _run_peer():
p = mp.Process(target=_run_peer)
p.start()

- some_fork_futures = receiver.recv()
+ some_fork_futures = safe_recv(receiver)

time.sleep(0.1) # giving enough time for the futures to be destroyed
assert len(hivemind.MPFuture._active_futures) == 700
