diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index ad3df8f66a..180cee5604 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -2,6 +2,7 @@ import asyncio import functools +import time from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from typing import Any, Generic, cast @@ -425,6 +426,7 @@ async def observe_value( signal: SignalR[SignalDatatypeT], timeout: float | None = None, done_status: Status | None = None, + done_timeout: float | None = None, ) -> AsyncGenerator[SignalDatatypeT, None]: """Subscribe to the value of a signal so it can be iterated from. @@ -439,9 +441,17 @@ async def observe_value( done_status: If this status is complete, stop observing and make the iterator return. If it raises an exception then this exception will be raised by the iterator. + done_timeout: + If given, the maximum time to watch a signal, in seconds. If the loop is still + being watched after this length, raise asyncio.TimeoutError. This should be used + instead of on an 'asyncio.wait_for' timeout Notes ----- + Due to a rare condition with busy signals, it is not recommended to use this + function with asyncio.timeout, including in an 'asyncio.wait_for' loop. Instead, + this timeout should be given to the done_timeout parameter. + Example usage:: async for value in observe_value(sig): @@ -449,15 +459,26 @@ async def observe_value( """ async for _, value in observe_signals_value( - signal, timeout=timeout, done_status=done_status + signal, + timeout=timeout, + done_status=done_status, + done_timeout=done_timeout, ): yield value +def _get_iteration_timeout( + timeout: float | None, overall_deadline: float | None +) -> float | None: + overall_deadline = overall_deadline - time.monotonic() if overall_deadline else None + return min([x for x in [overall_deadline, timeout] if x is not None], default=None) + + async def observe_signals_value( *signals: SignalR[SignalDatatypeT], timeout: float | None = None, done_status: Status | None = None, + done_timeout: float | None = None, ) -> AsyncGenerator[tuple[SignalR[SignalDatatypeT], SignalDatatypeT], None]: """Subscribe to the value of a signal so it can be iterated from. @@ -472,6 +493,10 @@ async def observe_signals_value( done_status: If this status is complete, stop observing and make the iterator return. If it raises an exception then this exception will be raised by the iterator. + done_timeout: + If given, the maximum time to watch a signal, in seconds. If the loop is still + being watched after this length, raise asyncio.TimeoutError. This should be used + instead of on an 'asyncio.wait_for' timeout Notes ----- @@ -486,12 +511,6 @@ async def observe_signals_value( q: asyncio.Queue[tuple[SignalR[SignalDatatypeT], SignalDatatypeT] | Status] = ( asyncio.Queue() ) - if timeout is None: - get_value = q.get - else: - - async def get_value(): - return await asyncio.wait_for(q.get(), timeout) cbs: dict[SignalR, Callback] = {} for signal in signals: @@ -504,13 +523,17 @@ def queue_value(value: SignalDatatypeT, signal=signal): if done_status is not None: done_status.add_callback(q.put_nowait) - + overall_deadline = time.monotonic() + done_timeout if done_timeout else None try: while True: - # yield here in case something else is filling the queue - # like in test_observe_value_times_out_with_no_external_task() - await asyncio.sleep(0) - item = await get_value() + if overall_deadline and time.monotonic() >= overall_deadline: + raise asyncio.TimeoutError( + f"observe_value was still observing signals " + f"{[signal.source for signal in signals]} after " + f"timeout {done_timeout}s" + ) + iteration_timeout = _get_iteration_timeout(timeout, overall_deadline) + item = await asyncio.wait_for(q.get(), iteration_timeout) if done_status and item is done_status: if exc := done_status.exception(): raise exc diff --git a/src/ophyd_async/epics/testing/test_records.db b/src/ophyd_async/epics/testing/test_records.db index fbd81b02c2..ff45a6350c 100644 --- a/src/ophyd_async/epics/testing/test_records.db +++ b/src/ophyd_async/epics/testing/test_records.db @@ -150,3 +150,9 @@ record(lsi, "$(device)longstr2") { field(INP, {const:"a string that is just longer than forty characters"}) field(PINI, "YES") } + +record(calc, "$(device)ticking") { + field(INPA, "$(device)ticking") + field(CALC, "A+1") + field(SCAN, ".1 second") +} diff --git a/src/ophyd_async/testing/__init__.py b/src/ophyd_async/testing/__init__.py new file mode 100644 index 0000000000..d3efd849bd --- /dev/null +++ b/src/ophyd_async/testing/__init__.py @@ -0,0 +1,22 @@ +import asyncio + + +async def wait_for_pending_wakeups(max_yields=20, raise_if_exceeded=True): + """Allow any ready asyncio tasks to be woken up. + + Used in: + + - Tests to allow tasks like ``set()`` to start so that signal + puts can be tested + - `observe_value` to allow it to be wrapped in `asyncio.wait_for` + with a timeout + """ + loop = asyncio.get_event_loop() + # If anything has called loop.call_soon or is scheduled a wakeup + # then let it run + for _ in range(max_yields): + await asyncio.sleep(0) + if not loop._ready: # type: ignore # noqa: SLF001 + return + if raise_if_exceeded: + raise RuntimeError(f"Tasks still scheduling wakeups after {max_yields} yields") diff --git a/tests/core/test_observe.py b/tests/core/test_observe.py index 14b9443ac2..50e0a167b5 100644 --- a/tests/core/test_observe.py +++ b/tests/core/test_observe.py @@ -60,14 +60,14 @@ async def tick(): recv = [] async def watch(): - async for val in observe_value(sig): + async for val in observe_value(sig, done_timeout=0.2): recv.append(val) t = asyncio.create_task(tick()) start = time.time() try: with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(watch(), timeout=0.2) + await watch() assert recv == [0, 1] assert time.time() - start == pytest.approx(0.2, abs=0.05) finally: @@ -85,7 +85,7 @@ async def tick(): recv = [] async def watch(): - async for val in observe_value(sig): + async for val in observe_value(sig, done_timeout=0.2): time.sleep(0.15) recv.append(val) @@ -93,7 +93,7 @@ async def watch(): start = time.time() try: with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(watch(), timeout=0.2) + await watch() assert recv == [0, 1] assert time.time() - start == pytest.approx(0.3, abs=0.05) finally: @@ -105,13 +105,26 @@ async def test_observe_value_times_out_with_no_external_task(): recv = [] - async def watch(): - async for val in observe_value(sig): + async def watch(done_timeout): + async for val in observe_value(sig, done_timeout=done_timeout): recv.append(val) setter(val + 1) start = time.time() with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(watch(), timeout=0.1) + await watch(done_timeout=0.1) assert recv assert time.time() - start == pytest.approx(0.1, abs=0.05) + + +async def test_observe_value_uses_correct_timeout(): + sig, _ = soft_signal_r_and_setter(float) + + async def watch(timeout, done_timeout): + async for _ in observe_value(sig, timeout, done_timeout=done_timeout): + ... + + start = time.time() + with pytest.raises(asyncio.TimeoutError): + await watch(timeout=0.3, done_timeout=0.15) + assert time.time() - start == pytest.approx(0.15, abs=0.05) diff --git a/tests/epics/adcore/test_drivers.py b/tests/epics/adcore/test_drivers.py index f2b2dcadc9..6f006ba0b7 100644 --- a/tests/epics/adcore/test_drivers.py +++ b/tests/epics/adcore/test_drivers.py @@ -70,7 +70,7 @@ async def test_start_acquiring_driver_and_ensure_status_flags_immediate_failure( ): set_mock_value(driver.detector_state, adcore.DetectorState.ERROR) acquiring = await adcore.start_acquiring_driver_and_ensure_status( - driver, timeout=0.01 + driver, timeout=0.05 ) with pytest.raises(ValueError): await acquiring diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index 333d1a78a7..fba02106c5 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -21,10 +21,7 @@ set_mock_value, ) from ophyd_async.epics import demo - -# Long enough for multiple asyncio event loop cycles to run so -# all the tasks have a chance to run -A_WHILE = 0.001 +from ophyd_async.testing import wait_for_pending_wakeups @pytest.fixture @@ -141,7 +138,7 @@ async def test_mover_moving_well(mock_mover: demo.Mover) -> None: time_elapsed=pytest.approx(0.1, abs=0.05), ) set_mock_value(mock_mover.readback, 0.5499999) - await asyncio.sleep(A_WHILE) + await wait_for_pending_wakeups() assert s.done assert s.success done.assert_called_once_with(s) diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index b44221b192..12295aac99 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -26,6 +26,7 @@ T, Table, load_from_yaml, + observe_value, save_to_yaml, ) from ophyd_async.epics.core import ( @@ -36,7 +37,9 @@ epics_signal_w, epics_signal_x, ) -from ophyd_async.epics.core._signal import _epics_signal_backend # noqa: PLC2701 +from ophyd_async.epics.core._signal import ( + _epics_signal_backend, # noqa: PLC2701 +) from ophyd_async.epics.testing import ( ExampleCaDevice, ExampleEnum, @@ -932,3 +935,25 @@ def my_plan(): def test_signal_module_emits_deprecation_warning(): with pytest.deprecated_call(): import ophyd_async.epics.signal # noqa: F401 + + +@PARAMETERISE_PROTOCOLS +async def test_observe_ticking_signal_with_busy_loop(ioc, protocol): + sig = epics_signal_rw(int, f"{protocol}://{get_prefix(ioc, protocol)}ticking") + await sig.connect() + + recv = [] + + async def watch(): + async for val in observe_value(sig, done_timeout=0.4): + time.sleep(0.3) + recv.append(val) + + start = time.time() + + with pytest.raises(asyncio.TimeoutError): + await watch() + assert time.time() - start == pytest.approx(0.6, abs=0.1) + assert len(recv) == 2 + # Don't check values as CA and PVA have different algorithms for + # dropping updates for slow callbacks diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index 055d49cb31..01f34fd785 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -17,17 +17,7 @@ soft_signal_rw, ) from ophyd_async.epics import motor - - -async def wait_for_wakeups(max_yields=10): - loop = asyncio.get_event_loop() - # If anything has called loop.call_soon or is scheduled a wakeup - # then let it run - for _ in range(max_yields): - await asyncio.sleep(0) - if not loop._ready: - return - raise RuntimeError(f"Tasks still scheduling wakeups after {max_yields} yields") +from ophyd_async.testing import wait_for_pending_wakeups @pytest.fixture @@ -44,7 +34,7 @@ async def sim_motor(): async def wait_for_eq(item, attribute, comparison, timeout): timeout_time = time.monotonic() + timeout while getattr(item, attribute) != comparison: - await wait_for_wakeups() + await wait_for_pending_wakeups() if time.monotonic() > timeout_time: raise TimeoutError @@ -56,7 +46,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: s.watch(watcher) done = Mock() s.add_callback(done) - await wait_for_wakeups() + await wait_for_pending_wakeups() await wait_for_eq(watcher, "call_count", 1, 1) assert watcher.call_args == call( name="sim_motor", @@ -86,7 +76,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: set_mock_value(sim_motor.motor_done_move, True) set_mock_value(sim_motor.user_readback, 0.55) set_mock_put_proceeds(sim_motor.user_setpoint, True) - await wait_for_wakeups() + await wait_for_pending_wakeups() await wait_for_eq(s, "done", True, 1) done.assert_called_once_with(s) @@ -98,7 +88,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: s.watch(watcher) done = Mock() s.add_callback(done) - await wait_for_wakeups() + await wait_for_pending_wakeups() assert watcher.call_count == 1 assert watcher.call_args == call( name="sim_motor", @@ -126,7 +116,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: time_elapsed=pytest.approx(0.1, abs=0.2), ) set_mock_put_proceeds(sim_motor.user_setpoint, True) - await wait_for_wakeups() + await wait_for_pending_wakeups() assert s.done done.assert_called_once_with(s) @@ -165,7 +155,7 @@ async def test_motor_moving_stopped(sim_motor: motor.Motor): assert not s.done await sim_motor.stop() set_mock_put_proceeds(sim_motor.user_setpoint, True) - await wait_for_wakeups() + await wait_for_pending_wakeups() assert s.done assert s.success is False