Skip to content

Commit

Permalink
initial validation for sate machines in python
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Oct 27, 2024
1 parent c7013b3 commit 2cb1a08
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 9 deletions.
109 changes: 101 additions & 8 deletions yasmin/tests/python/test_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def execute(self, blackboard):
self.counter += 1
blackboard["foo_str"] = f"Counter: {self.counter}"
return "outcome1"

else:
return "outcome2"

Expand All @@ -42,30 +43,122 @@ def execute(self, blackboard):

class TestStateMachine(unittest.TestCase):

def setUp(self):
maxDiff = None

self.sm = StateMachine(outcomes=["outcome4", "outcome5"])
def setUp(self):
self.sm = StateMachine(outcomes=["outcome4"])

self.sm.add_state(
"FOO", FooState(), transitions={"outcome1": "BAR", "outcome2": "outcome4"}
"FOO",
FooState(),
transitions={
"outcome1": "BAR",
"outcome2": "outcome4",
},
)
self.sm.add_state(
"BAR",
BarState(),
transitions={"outcome2": "FOO"},
)
self.sm.add_state("BAR", BarState(), transitions={"outcome2": "FOO"})

def test_state_machine_get_states(self):

self.assertTrue(isinstance(self.sm.get_states()["FOO"]["state"], FooState))
self.assertTrue(isinstance(self.sm.get_states()["BAR"]["state"], BarState))

def test_state_machine_get_start_state(self):

self.assertEqual("FOO", self.sm.get_start_state())
self.sm.set_start_state("BAR")
self.assertEqual("BAR", self.sm.get_start_state())

def test_state_machine_get_current_state(self):

self.assertEqual("", self.sm.get_current_state())

def test_state_call(self):

self.assertEqual("outcome4", self.sm())

def test_add_repeated_state(self):
with self.assertRaises(Exception) as context:
self.sm.add_state(
"FOO",
FooState(),
transitions={
"outcome1": "BAR",
},
)
self.assertEqual(
str(context.exception),
"State 'FOO' already registered in the state machine",
)

def test_add_state_with_wrong_outcome(self):
with self.assertRaises(Exception) as context:
self.sm.add_state(
"FOO1",
FooState(),
transitions={
"outcome9": "BAR",
},
)
self.assertEqual(
str(context.exception),
"State 'FOO1' references unregistered outcomes: 'outcome9', available outcomes are: ['outcome1', 'outcome2']",
)

def test_add_wrong_source_transition(self):
with self.assertRaises(Exception) as context:
self.sm.add_state(
"FOO1",
FooState(),
transitions={
"": "BAR",
},
)
self.assertEqual(
str(context.exception), "Transitions with empty source in state 'FOO1'"
)

def test_add_wrong_target_transition(self):
with self.assertRaises(Exception) as context:
self.sm.add_state(
"FOO1",
FooState(),
transitions={
"outcome1": "",
},
)
self.assertEqual(
str(context.exception), "Transitions with empty target in state 'FOO1'"
)

def test_validate_state_machine(self):

sm_1 = StateMachine(outcomes=["outcome4"])

sm_2 = StateMachine(outcomes=["outcome4", "outcome5"])
sm_1.add_state("FSM", sm_2)

sm_2.add_state(
"FOO",
FooState(),
transitions={
"outcome1": "BAR",
},
)
with self.assertRaises(Exception) as context:
sm_1.validate()
self.assertEqual(
str(context.exception),
(
f"{'*' * 100}\nState machine failed validation check:"
"\n\tState 'FSM' outcome 'outcome5' not registered in transitions"
f"\n\tState machine 'FSM' failed validation check\n{'*' * 100}\n"
"State machine failed validation check:"
"\n\tState 'FOO' outcome 'outcome2' not registered in transitions"
"\n\tTarget outcome 'outcome4' not registered in transitions"
"\n\tTarget outcome 'outcome5' not registered in transitions"
"\n\tState machine outcome 'BAR' not registered as outcome neither state"
f"\n\n\tAvailable states: ['FOO']\n{'*' * 100}"
f"\n\n\tAvailable states: ['FSM']\n{'*' * 100}"
),
)
77 changes: 76 additions & 1 deletion yasmin/yasmin/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,30 @@ def __init__(self, outcomes: List[str]) -> None:
self.__current_state_lock = Lock()

def add_state(
self, name: str, state: State, transitions: Dict[str, str] = None
self,
name: str,
state: State,
transitions: Dict[str, str] = None,
) -> None:

if not transitions:
transitions = {}

if name in self._states:
raise Exception(f"State '{name}' already registered in the state machine")

for key in transitions:
if not key:
raise Exception(f"Transitions with empty source in state '{name}'")

if not transitions[key]:
raise Exception(f"Transitions with empty target in state '{name}'")

if key not in state.get_outcomes():
raise Exception(
f"State '{name}' references unregistered outcomes: '{key}', available outcomes are: {state.get_outcomes()}"
)

self._states[name] = {"state": state, "transitions": transitions}

if not self._start_state:
Expand All @@ -55,8 +73,65 @@ def cancel_state(self) -> None:
if self.__current_state:
self._states[self.__current_state]["state"].cancel_state()

def validate(self, raise_exception: bool = True) -> str:
errors = ""

# check initial state
if self._start_state is None:
errors += "\n\tNo initial state set."
elif self._start_state not in self._states:
errors += f"\n\tInitial state label: '{self._start_state}' is not in the state machine."

terminal_outcomes = []

# check all states
for state_name in self._states:

state: State = self._states[state_name]["state"]
transitions: Dict[str, str] = self._states[state_name]["transitions"]

outcomes = state.get_outcomes()

# check if all state outcomes are in transitions
for o in outcomes:
if o not in set(list(transitions.keys()) + self.get_outcomes()):
errors += f"\n\tState '{state_name}' outcome '{o}' not registered in transitions"

# state outcomes that are in state machines out do not need transitions
elif o in self.get_outcomes():
terminal_outcomes.append(o)

# if sate is a state machine, validate it
if isinstance(state, StateMachine):
aux_errors = state.validate(False)
if aux_errors:
errors += f"\n\tState machine '{state_name}' failed validation check\n{aux_errors}"

# add terminal outcomes
terminal_outcomes.extend([transitions[key] for key in transitions])

# check terminal outcomes for the state machine
terminal_outcomes = set(terminal_outcomes)
for o in self.get_outcomes():
if o not in terminal_outcomes:
errors += f"\n\tTarget outcome '{o}' not registered in transitions"

for o in terminal_outcomes:
if o not in set(list(self._states.keys()) + self.get_outcomes()):
errors += f"\n\tState machine outcome '{o}' not registered as outcome neither state"

if errors:
errors = f"{'*' * 100}\nState machine failed validation check:{errors}\n\n\tAvailable states: {list(self._states.keys())}\n{'*' * 100}"

if raise_exception and errors:
raise Exception(errors)

return errors

def execute(self, blackboard: Blackboard) -> str:

self.validate()

with self.__current_state_lock:
self.__current_state = self._start_state

Expand Down

0 comments on commit 2cb1a08

Please sign in to comment.