Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tests on macOS #585

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from
26 changes: 20 additions & 6 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,19 @@ on:
pull_request:

jobs:
run_tests:

runs-on: ubuntu-latest
run_tests:
strategy:
matrix:
python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11' ]
include:
- { os: 'ubuntu', python-version: '3.8' }
- { os: 'ubuntu', python-version: '3.9' }
- { os: 'ubuntu', python-version: '3.10' }
- { os: 'ubuntu', python-version: '3.11' }
- { os: 'macos', python-version: '3.8' }
- { os: 'macos', python-version: '3.11' }
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v3
Expand All @@ -38,8 +45,15 @@ jobs:
- name: Test
run: |
cd tests
export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
pytest --durations=0 --durations-min=1.0 -v
export HIVEMIND_MEMORY_SHARING_STRATEGY=${{ matrix.os == 'ubuntu' && 'file_descriptor' || 'file_system' }}

ulimit -n 8192
export no_proxy=* # See https://github.com/kevlened/pytest-parallel/issues/93#issuecomment-839913651
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
export OMP_NUM_THREADS=1

pytest test_dht* --durations=0 --durations-min=1.0 -v

build_and_test_p2pd:
runs-on: ubuntu-latest
timeout-minutes: 10
Expand Down Expand Up @@ -71,8 +85,8 @@ jobs:
cd tests
export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
pytest -k "p2p" -v
codecov_in_develop_mode:

codecov_in_develop_mode:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
Expand Down
5 changes: 3 additions & 2 deletions hivemind/averaging/averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
enter_asynchronously,
switch_to_uvloop,
)
from hivemind.utils.compat import safe_recv
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
Expand Down Expand Up @@ -313,7 +314,7 @@ async def _run():
if not self._inner_pipe.poll():
continue
try:
method, args, kwargs = self._inner_pipe.recv()
method, args, kwargs = safe_recv(self._inner_pipe)
except (OSError, ConnectionError, RuntimeError) as e:
logger.exception(e)
await asyncio.sleep(self.request_timeout)
Expand Down Expand Up @@ -774,7 +775,7 @@ def _background_thread_fetch_current_state(
"""
while True:
try:
trigger, future = pipe.recv()
trigger, future = safe_recv(pipe)
except BaseException as e:
logger.debug(f"Averager background thread finished: {repr(e)}")
break
Expand Down
3 changes: 2 additions & 1 deletion hivemind/dht/dht.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
from hivemind.p2p import P2P, PeerID
from hivemind.utils import MPFuture, get_logger, switch_to_uvloop
from hivemind.utils.compat import safe_recv
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration

logger = get_logger(__name__)
Expand Down Expand Up @@ -126,7 +127,7 @@ async def _run():
if not self._inner_pipe.poll():
continue
try:
method, args, kwargs = self._inner_pipe.recv()
method, args, kwargs = safe_recv(self._inner_pipe)
except (OSError, ConnectionError, RuntimeError) as e:
logger.exception(e)
await asyncio.sleep(self._node.protocol.wait_timeout)
Expand Down
5 changes: 3 additions & 2 deletions hivemind/moe/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from hivemind.moe.server.runtime import Runtime
from hivemind.p2p import PeerInfo
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compat import safe_recv
from hivemind.utils.logging import get_logger
from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

Expand Down Expand Up @@ -314,7 +315,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
runner.start()
# once the server is ready, runner will send us
# either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
start_ok, data = pipe.recv()
start_ok, data = safe_recv(pipe)
if start_ok:
yield data
pipe.send("SHUTDOWN") # on exit from context, send shutdown signal
Expand All @@ -339,7 +340,7 @@ def _server_runner(pipe, *args, **kwargs):
try:
dht_maddrs = server.dht.get_visible_maddrs()
pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
pipe.recv() # wait for shutdown signal
safe_recv(pipe) # wait for shutdown signal

finally:
logger.info("Shutting down server...")
Expand Down
5 changes: 3 additions & 2 deletions hivemind/moe/server/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch

from hivemind.utils import get_logger
from hivemind.utils.compat import safe_recv
from hivemind.utils.mpfuture import InvalidStateError, MPFuture

logger = get_logger(__name__)
Expand Down Expand Up @@ -195,7 +196,7 @@ def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):

while True:
logger.debug(f"{self.name} waiting for results from runtime")
batch_index, batch_outputs_or_exception = self.outputs_receiver.recv()
batch_index, batch_outputs_or_exception = safe_recv(self.outputs_receiver)
batch_tasks = pending_batches.pop(batch_index)

if isinstance(batch_outputs_or_exception, BaseException):
Expand Down Expand Up @@ -234,7 +235,7 @@ def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[to
if not self.batch_receiver.poll(timeout):
raise TimeoutError()

batch_index, batch_inputs = self.batch_receiver.recv()
batch_index, batch_inputs = safe_recv(self.batch_receiver)
batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
return batch_index, batch_inputs

Expand Down
23 changes: 23 additions & 0 deletions hivemind/utils/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import multiprocessing as mp
import time
from typing import Any

from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


def safe_recv(pipe: mp.connection.Connection) -> Any:
# Needed for macOS, see https://github.com/urllib3/urllib3/issues/63#issuecomment-4609289

while True:
try:
return pipe.recv()
except Exception as e:
if (isinstance(e, BlockingIOError) and str(e) == "[Errno 35] Resource temporarily unavailable") or (
isinstance(e, EOFError) and str(e) == "Ran out of input"
):
logger.warning(repr(e))
time.sleep(0)
continue
raise
3 changes: 2 additions & 1 deletion hivemind/utils/mpfuture.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch # used for py3.7-compatible shared memory

from hivemind.utils.compat import safe_recv
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -182,7 +183,7 @@ def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection)
if cls._pipe_waiter_thread is not threading.current_thread():
break # backend was reset, a new background thread has started

uid, update_type, payload = receiver_pipe.recv()
uid, update_type, payload = safe_recv(receiver_pipe)
future = None
future_ref = cls._active_futures.pop(uid, None)
if future_ref is not None:
Expand Down
7 changes: 4 additions & 3 deletions tests/test_dht_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.node import DHTNode
from hivemind.dht.validation import DHTRecord
from hivemind.utils.compat import safe_recv
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.timed_storage import get_dht_time

Expand Down Expand Up @@ -79,8 +80,8 @@ def test_validator_instance_is_picklable():


def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
validator = conn.recv()
record = conn.recv()
validator = safe_recv(conn)
record = safe_recv(conn)

record = dataclasses.replace(record, value=validator.sign_value(record))

Expand All @@ -101,7 +102,7 @@ def test_signing_in_different_process():
)
parent_conn.send(record)

signed_record = parent_conn.recv()
signed_record = safe_recv(parent_conn)
assert b"[signature:" in signed_record.value
assert validator.validate(signed_record)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_dht_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from hivemind.dht import DHTID
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.storage import DictionaryDHTValue
from hivemind.utils.compat import safe_recv

logger = get_logger(__name__)

Expand Down Expand Up @@ -56,7 +57,7 @@ def launch_protocol_listener(
dht_id = DHTID.generate()
process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
process.start()
peer_id, visible_maddrs = local_conn.recv()
peer_id, visible_maddrs = safe_recv(local_conn)

return dht_id, process, peer_id, visible_maddrs

Expand Down
5 changes: 3 additions & 2 deletions tests/test_p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
from hivemind.proto import dht_pb2, test_pb2
from hivemind.utils.compat import safe_recv
from hivemind.utils.serializer import MSGPackSerializer

from test_utils.networking import get_free_port
Expand Down Expand Up @@ -328,8 +329,8 @@ async def test_call_peer_different_processes():
proc = mp.Process(target=server_target, args=(handler_name, server_side, response_received))
proc.start()

peer_id = client_side.recv()
peer_maddrs = client_side.recv()
peer_id = safe_recv(client_side)
peer_maddrs = safe_recv(client_side)

client = await P2P.create(initial_peers=peer_maddrs)
client_pid = client._child.pid
Expand Down
5 changes: 3 additions & 2 deletions tests/test_util_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
cancel_and_wait,
enter_asynchronously,
)
from hivemind.utils.compat import safe_recv
from hivemind.utils.mpfuture import InvalidStateError
from hivemind.utils.performance_ema import PerformanceEMA

Expand Down Expand Up @@ -260,7 +261,7 @@ def _check_result_and_set(future):
p = mp.Process(target=_future_creator)
p.start()

future1, future2 = receiver.recv()
future1, future2 = safe_recv(receiver)
future1.set_result(123)

with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -309,7 +310,7 @@ def _run_peer():
p = mp.Process(target=_run_peer)
p.start()

some_fork_futures = receiver.recv()
some_fork_futures = safe_recv(recv)

time.sleep(0.1) # giving enough time for the futures to be destroyed
assert len(hivemind.MPFuture._active_futures) == 700
Expand Down