diff --git a/src/nextline_schedule/auto/callback.py b/src/nextline_schedule/auto/callback.py index 2560e21..547be6e 100644 --- a/src/nextline_schedule/auto/callback.py +++ b/src/nextline_schedule/auto/callback.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Coroutine from nextline import Nextline +from nextline.plugin.spec import Context, hookimpl from nextline_schedule.types import Statement @@ -16,9 +17,28 @@ def build_auto_mode_state_machine( callback = Callback(nextline=nextline, request_statement=request_statement) machine = AutoModeStateMachine(callback=callback) callback.auto_mode = machine + on_call = OnCall(auto_mode=machine) + nextline.register(on_call) return machine +class OnCall: + def __init__(self, auto_mode: AutoModeStateMachine): + self.auto_mode = auto_mode + + @hookimpl + async def on_initialize_run(self) -> None: + await self.auto_mode.on_initialized() # type: ignore + + @hookimpl + async def on_finished(self, context: Context) -> None: + nextline = context.nextline + if nextline.format_exception(): + await self.auto_mode.on_raised() # type: ignore + return + await self.auto_mode.on_finished() # type: ignore + + class Callback: def __init__( self, @@ -27,23 +47,15 @@ def __init__( ): self._nextline = nextline self._request_statement = request_statement - self._logger = getLogger(__name__) - self.auto_mode: AutoModeStateMachine # to be set async def wait(self) -> None: - try: - async for state in self._nextline.subscribe_state(): - if state == 'initialized': - await self.auto_mode.on_initialized() # type: ignore - break - if state == 'finished': - await self.auto_mode.on_finished() # type: ignore - break - except asyncio.CancelledError: - self._logger.info(f'{self.__class__.__name__}.wait() cancelled') - return + match self._nextline.state: + case 'initialized': + await self.auto_mode.on_initialized() # type: ignore + case 'finished': + await self.auto_mode.on_finished() # type: ignore async def pull(self) -> None: try: @@ -61,9 +73,5 @@ async def pull(self) -> None: async def run(self, started: asyncio.Event) -> None: try: await self._nextline.run_continue_and_wait(started) - if self._nextline.format_exception(): - await self.auto_mode.on_raised() # type: ignore - return - await self.auto_mode.on_finished() # type: ignore except asyncio.CancelledError: self._logger.info(f'{self.__class__.__name__}.run() cancelled') diff --git a/src/nextline_schedule/auto/factory.py b/src/nextline_schedule/auto/factory.py index 1fc2a22..5864b22 100644 --- a/src/nextline_schedule/auto/factory.py +++ b/src/nextline_schedule/auto/factory.py @@ -20,18 +20,19 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) -> .-------------->| Off |<--------------. | '-------------' | | | turn_on() | - turn_off() | | - | v | - | .-------------. | - |---------------| Waiting | on_raised() - | '-------------' | - | | on_initialized() | - | | on_finished() | + turn_off() | on_raised() + | | | | | | | .------------------+------------------. | | | Auto | | | | | v | | | | .-------------. | | + | | | Waiting | | | + | | '-------------' | | + | | | on_initialized() | | + | | | on_finished() | | + | | v | | + | | .-------------. | | | | | Pulling | | | '---| '-------------' |---' | run() | ^ | @@ -44,11 +45,11 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) -> '-------------------------------------' >>> class Model: - ... def on_enter_waiting(self): + ... def on_enter_auto_waiting(self): ... print('enter the waiting state') ... self.on_finished() ... - ... def on_exit_waiting(self): + ... def on_exit_auto_waiting(self): ... print('exit the waiting state') ... ... def on_enter_auto_pulling(self): @@ -88,14 +89,16 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) -> auto_state_conf = { 'name': 'auto', - 'states': ['pulling', 'running'], + 'states': ['waiting', 'pulling', 'running'], 'transitions': [ + ['on_initialized', 'waiting', 'pulling'], + ['on_finished', 'waiting', 'pulling'], ['run', 'pulling', 'running'], ['on_finished', 'running', 'pulling'], ], - 'initial': 'pulling', + 'initial': 'waiting', 'queued': True, - # 'ignore_invalid_triggers': True, + 'ignore_invalid_triggers': True, } # Ideally, we would be able to pass the auto_state_conf dict directly to @@ -109,25 +112,22 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) -> 'states': [ 'created', 'off', - 'waiting', {'name': 'auto', 'children': auto_state}, ], 'transitions': [ ['start', 'created', 'off'], - ['turn_on', 'off', 'waiting'], - ['on_initialized', 'waiting', 'auto'], - ['on_finished', 'waiting', 'auto'], + ['turn_on', 'off', 'auto'], ['on_raised', 'auto', 'off'], { 'trigger': 'turn_off', - 'source': ['waiting', 'auto'], + 'source': 'auto', 'dest': 'off', 'before': 'cancel_task', }, ], 'initial': 'created', 'queued': True, - # 'ignore_invalid_triggers': True, + 'ignore_invalid_triggers': True, } machine = MachineClass(model=model, **state_conf) # type: ignore diff --git a/src/nextline_schedule/auto/machine.py b/src/nextline_schedule/auto/machine.py index 0b49d2c..f5d9041 100644 --- a/src/nextline_schedule/auto/machine.py +++ b/src/nextline_schedule/auto/machine.py @@ -33,7 +33,7 @@ def __init__(self, callback: CallbackType): def subscribe_state(self) -> AsyncIterator[str]: return self._pubsub_state.subscribe() - async def on_enter_waiting(self) -> None: + async def on_enter_auto_waiting(self) -> None: task = asyncio.create_task(self._callback.wait()) self._task = task self._tasks.add(task) diff --git a/tests/auto/test_auto.py b/tests/auto/test_auto.py index 24ea321..84be900 100644 --- a/tests/auto/test_auto.py +++ b/tests/auto/test_auto.py @@ -36,7 +36,7 @@ async def test_one() -> None: expected = [ 'off', - 'waiting', + 'auto_waiting', 'auto_pulling', 'auto_running', 'auto_pulling', diff --git a/tests/auto/test_auto_on_raised.py b/tests/auto/test_auto_on_raised.py index 7faa09a..778bb1b 100644 --- a/tests/auto/test_auto_on_raised.py +++ b/tests/auto/test_auto_on_raised.py @@ -38,7 +38,7 @@ async def request_statement(): if state == 'off': break - expected = ['off', 'waiting', 'auto_pulling', 'off'] + expected = ['off', 'auto_waiting', 'auto_pulling', 'off'] assert expected == await states @@ -61,7 +61,7 @@ async def request_statement(): if state == 'off': break - expected = ['off', 'waiting', 'auto_pulling', 'auto_running', 'off'] + expected = ['off', 'auto_waiting', 'auto_pulling', 'auto_running', 'off'] assert expected == await states diff --git a/tests/auto/test_auto_turn_off.py b/tests/auto/test_auto_turn_off.py index 65938a4..18041a2 100644 --- a/tests/auto/test_auto_turn_off.py +++ b/tests/auto/test_auto_turn_off.py @@ -42,7 +42,7 @@ async def test_turn_off_while_waiting(): trace_no=prompt.trace_no, ) - expected = ['off', 'waiting', 'off'] + expected = ['off', 'auto_waiting', 'off'] assert expected == await states diff --git a/tests/test_fsm.py b/tests/test_fsm.py index dc01b5d..fb3bd85 100644 --- a/tests/test_fsm.py +++ b/tests/test_fsm.py @@ -5,7 +5,7 @@ import pytest from hypothesis import given from hypothesis import strategies as st -from transitions import Machine, MachineError +from transitions import Machine from transitions.extensions.markup import HierarchicalMarkupMachine from nextline_schedule.auto.factory import build_state_machine @@ -46,12 +46,13 @@ def st_paths(draw: st.DrawFn): 'start': {'dest': 'off'}, }, 'off': { - 'turn_on': {'dest': 'waiting'}, + 'turn_on': {'dest': 'auto_waiting'}, }, - 'waiting': { + 'auto_waiting': { 'turn_off': {'dest': 'off', 'before': 'cancel_task'}, 'on_initialized': {'dest': 'auto_pulling'}, 'on_finished': {'dest': 'auto_pulling'}, + 'on_raised': {'dest': 'off'}, }, 'auto_pulling': { 'run': {'dest': 'auto_running'}, @@ -84,7 +85,7 @@ def st_paths(draw: st.DrawFn): paths.append((trigger, trigger_map[trigger])) state = trigger_map[trigger]['dest'] else: - paths.append((trigger, {'error': MachineError})) + paths.append((trigger, {'invalid': True})) while state not in final_states: trigger_map = state_map_reduced[state] @@ -102,9 +103,8 @@ async def test_transitions_hypothesis(paths: list[tuple[str, dict[str, Any]]]): assert machine.is_created() for method, map in paths: - if error := map.get('error'): - with pytest.raises(error): - await getattr(machine, method)() + if map.get('invalid'): + await getattr(machine, method)() continue if before := map.get('before'):