Skip to content

Commit

Permalink
Merge pull request #12 from simonsobs/dev
Browse files Browse the repository at this point in the history
Turn the Waiting state into a sub-state of Auto
  • Loading branch information
TaiSakuma authored Apr 26, 2024
2 parents 759b8dd + 1b9718d commit 37e1169
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 47 deletions.
42 changes: 25 additions & 17 deletions src/nextline_schedule/auto/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable, Coroutine

from nextline import Nextline
from nextline.plugin.spec import Context, hookimpl

from nextline_schedule.types import Statement

Expand All @@ -16,9 +17,28 @@ def build_auto_mode_state_machine(
callback = Callback(nextline=nextline, request_statement=request_statement)
machine = AutoModeStateMachine(callback=callback)
callback.auto_mode = machine
on_call = OnCall(auto_mode=machine)
nextline.register(on_call)
return machine


class OnCall:
def __init__(self, auto_mode: AutoModeStateMachine):
self.auto_mode = auto_mode

@hookimpl
async def on_initialize_run(self) -> None:
await self.auto_mode.on_initialized() # type: ignore

@hookimpl
async def on_finished(self, context: Context) -> None:
nextline = context.nextline
if nextline.format_exception():
await self.auto_mode.on_raised() # type: ignore
return
await self.auto_mode.on_finished() # type: ignore


class Callback:
def __init__(
self,
Expand All @@ -27,23 +47,15 @@ def __init__(
):
self._nextline = nextline
self._request_statement = request_statement

self._logger = getLogger(__name__)

self.auto_mode: AutoModeStateMachine # to be set

async def wait(self) -> None:
try:
async for state in self._nextline.subscribe_state():
if state == 'initialized':
await self.auto_mode.on_initialized() # type: ignore
break
if state == 'finished':
await self.auto_mode.on_finished() # type: ignore
break
except asyncio.CancelledError:
self._logger.info(f'{self.__class__.__name__}.wait() cancelled')
return
match self._nextline.state:
case 'initialized':
await self.auto_mode.on_initialized() # type: ignore
case 'finished':
await self.auto_mode.on_finished() # type: ignore

async def pull(self) -> None:
try:
Expand All @@ -61,9 +73,5 @@ async def pull(self) -> None:
async def run(self, started: asyncio.Event) -> None:
try:
await self._nextline.run_continue_and_wait(started)
if self._nextline.format_exception():
await self.auto_mode.on_raised() # type: ignore
return
await self.auto_mode.on_finished() # type: ignore
except asyncio.CancelledError:
self._logger.info(f'{self.__class__.__name__}.run() cancelled')
36 changes: 18 additions & 18 deletions src/nextline_schedule/auto/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) ->
.-------------->| Off |<--------------.
| '-------------' |
| | turn_on() |
turn_off() | |
| v |
| .-------------. |
|---------------| Waiting | on_raised()
| '-------------' |
| | on_initialized() |
| | on_finished() |
turn_off() | on_raised()
| | |
| | |
| .------------------+------------------. |
| | Auto | | |
| | v | |
| | .-------------. | |
| | | Waiting | | |
| | '-------------' | |
| | | on_initialized() | |
| | | on_finished() | |
| | v | |
| | .-------------. | |
| | | Pulling | | |
'---| '-------------' |---'
| run() | ^ |
Expand All @@ -44,11 +45,11 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) ->
'-------------------------------------'
>>> class Model:
... def on_enter_waiting(self):
... def on_enter_auto_waiting(self):
... print('enter the waiting state')
... self.on_finished()
...
... def on_exit_waiting(self):
... def on_exit_auto_waiting(self):
... print('exit the waiting state')
...
... def on_enter_auto_pulling(self):
Expand Down Expand Up @@ -88,14 +89,16 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) ->

auto_state_conf = {
'name': 'auto',
'states': ['pulling', 'running'],
'states': ['waiting', 'pulling', 'running'],
'transitions': [
['on_initialized', 'waiting', 'pulling'],
['on_finished', 'waiting', 'pulling'],
['run', 'pulling', 'running'],
['on_finished', 'running', 'pulling'],
],
'initial': 'pulling',
'initial': 'waiting',
'queued': True,
# 'ignore_invalid_triggers': True,
'ignore_invalid_triggers': True,
}

# Ideally, we would be able to pass the auto_state_conf dict directly to
Expand All @@ -109,25 +112,22 @@ def build_state_machine(model=None, graph=False, asyncio=True, markup=False) ->
'states': [
'created',
'off',
'waiting',
{'name': 'auto', 'children': auto_state},
],
'transitions': [
['start', 'created', 'off'],
['turn_on', 'off', 'waiting'],
['on_initialized', 'waiting', 'auto'],
['on_finished', 'waiting', 'auto'],
['turn_on', 'off', 'auto'],
['on_raised', 'auto', 'off'],
{
'trigger': 'turn_off',
'source': ['waiting', 'auto'],
'source': 'auto',
'dest': 'off',
'before': 'cancel_task',
},
],
'initial': 'created',
'queued': True,
# 'ignore_invalid_triggers': True,
'ignore_invalid_triggers': True,
}

machine = MachineClass(model=model, **state_conf) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion src/nextline_schedule/auto/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, callback: CallbackType):
def subscribe_state(self) -> AsyncIterator[str]:
return self._pubsub_state.subscribe()

async def on_enter_waiting(self) -> None:
async def on_enter_auto_waiting(self) -> None:
task = asyncio.create_task(self._callback.wait())
self._task = task
self._tasks.add(task)
Expand Down
2 changes: 1 addition & 1 deletion tests/auto/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def test_one() -> None:

expected = [
'off',
'waiting',
'auto_waiting',
'auto_pulling',
'auto_running',
'auto_pulling',
Expand Down
4 changes: 2 additions & 2 deletions tests/auto/test_auto_on_raised.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def request_statement():
if state == 'off':
break

expected = ['off', 'waiting', 'auto_pulling', 'off']
expected = ['off', 'auto_waiting', 'auto_pulling', 'off']
assert expected == await states


Expand All @@ -61,7 +61,7 @@ async def request_statement():
if state == 'off':
break

expected = ['off', 'waiting', 'auto_pulling', 'auto_running', 'off']
expected = ['off', 'auto_waiting', 'auto_pulling', 'auto_running', 'off']
assert expected == await states


Expand Down
2 changes: 1 addition & 1 deletion tests/auto/test_auto_turn_off.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def test_turn_off_while_waiting():
trace_no=prompt.trace_no,
)

expected = ['off', 'waiting', 'off']
expected = ['off', 'auto_waiting', 'off']
assert expected == await states


Expand Down
14 changes: 7 additions & 7 deletions tests/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from hypothesis import given
from hypothesis import strategies as st
from transitions import Machine, MachineError
from transitions import Machine
from transitions.extensions.markup import HierarchicalMarkupMachine

from nextline_schedule.auto.factory import build_state_machine
Expand Down Expand Up @@ -46,12 +46,13 @@ def st_paths(draw: st.DrawFn):
'start': {'dest': 'off'},
},
'off': {
'turn_on': {'dest': 'waiting'},
'turn_on': {'dest': 'auto_waiting'},
},
'waiting': {
'auto_waiting': {
'turn_off': {'dest': 'off', 'before': 'cancel_task'},
'on_initialized': {'dest': 'auto_pulling'},
'on_finished': {'dest': 'auto_pulling'},
'on_raised': {'dest': 'off'},
},
'auto_pulling': {
'run': {'dest': 'auto_running'},
Expand Down Expand Up @@ -84,7 +85,7 @@ def st_paths(draw: st.DrawFn):
paths.append((trigger, trigger_map[trigger]))
state = trigger_map[trigger]['dest']
else:
paths.append((trigger, {'error': MachineError}))
paths.append((trigger, {'invalid': True}))

while state not in final_states:
trigger_map = state_map_reduced[state]
Expand All @@ -102,9 +103,8 @@ async def test_transitions_hypothesis(paths: list[tuple[str, dict[str, Any]]]):
assert machine.is_created()

for method, map in paths:
if error := map.get('error'):
with pytest.raises(error):
await getattr(machine, method)()
if map.get('invalid'):
await getattr(machine, method)()
continue

if before := map.get('before'):
Expand Down

0 comments on commit 37e1169

Please sign in to comment.