From 3ed8acefc05eefff9b28464b8133ffab6230d05e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Mon, 4 Nov 2024 20:35:41 +0100 Subject: [PATCH] strict_mode added to validation --- yasmin/include/yasmin/state_machine.hpp | 4 +- yasmin/src/yasmin/state_machine.cpp | 47 +++++++++++++--------- yasmin/tests/python/test_blackboard.py | 8 ++-- yasmin/tests/python/test_state.py | 10 ++--- yasmin/tests/python/test_state_machine.py | 30 ++++++++------ yasmin/yasmin/state_machine.py | 42 ++++++++++--------- yasmin_ros/tests/python/test_yasmin_ros.py | 16 ++++---- 7 files changed, 88 insertions(+), 69 deletions(-) diff --git a/yasmin/include/yasmin/state_machine.hpp b/yasmin/include/yasmin/state_machine.hpp index fd64463..3b45688 100644 --- a/yasmin/include/yasmin/state_machine.hpp +++ b/yasmin/include/yasmin/state_machine.hpp @@ -184,9 +184,11 @@ class StateMachine : public State { /** * @brief Validates the state machine configuration. * + * @param strict Whether the validation is strict, which means checking if all + * state outcomes are used and all state machine outcomes are reached. * @throws std::runtime_error If the state machine is misconfigured. */ - void validate(); + void validate(bool strict_mode = false); /** * @brief Executes the state machine. diff --git a/yasmin/src/yasmin/state_machine.cpp b/yasmin/src/yasmin/state_machine.cpp index 4b2678d..068f54b 100644 --- a/yasmin/src/yasmin/state_machine.cpp +++ b/yasmin/src/yasmin/state_machine.cpp @@ -196,11 +196,11 @@ void StateMachine::call_end_cbs( } } -void StateMachine::validate() { +void StateMachine::validate(bool strict_mode) { YASMIN_LOG_DEBUG("Validating state machine '%s'", this->to_string().c_str()); - if (this->validated.load()) { + if (this->validated.load() && !strict_mode) { YASMIN_LOG_DEBUG("State machine '%s' has already been validated", this->to_string().c_str()); } @@ -210,6 +210,7 @@ void StateMachine::validate() { throw std::runtime_error("No initial state set"); } + // Terminal outcomes from all transitions std::set terminal_outcomes; // Check all states @@ -222,26 +223,30 @@ void StateMachine::validate() { std::set outcomes = state->get_outcomes(); - // Check if all state outcomes are in transitions - for (const std::string &o : outcomes) { + if (strict_mode) { + // Check if all outcomes of the state are in transitions + for (const std::string &o : outcomes) { - if (transitions.find(o) == transitions.end() && - std::find(this->get_outcomes().begin(), this->get_outcomes().end(), - o) == this->get_outcomes().end()) { + if (transitions.find(o) == transitions.end() && + std::find(this->get_outcomes().begin(), this->get_outcomes().end(), + o) == this->get_outcomes().end()) { - throw std::runtime_error("State '" + state_name + "' outcome '" + o + - "' not registered in transitions"); + throw std::runtime_error("State '" + state_name + "' outcome '" + o + + "' not registered in transitions"); - } else if (std::find(this->get_outcomes().begin(), - this->get_outcomes().end(), - o) != this->get_outcomes().end()) { - terminal_outcomes.insert(o); + // Outcomes of the state that are in outcomes of the state machine + // do not need transitions + } else if (std::find(this->get_outcomes().begin(), + this->get_outcomes().end(), + o) != this->get_outcomes().end()) { + terminal_outcomes.insert(o); + } } } // If state is a state machine, validate it if (std::dynamic_pointer_cast(state)) { - std::dynamic_pointer_cast(state)->validate(); + std::dynamic_pointer_cast(state)->validate(strict_mode); } // Add terminal outcomes @@ -255,15 +260,17 @@ void StateMachine::validate() { std::set sm_outcomes(this->get_outcomes().begin(), this->get_outcomes().end()); - // Check if all state machine outcomes are in the terminal outcomes - for (const std::string &o : this->get_outcomes()) { - if (terminal_outcomes.find(o) == terminal_outcomes.end()) { - throw std::runtime_error("Target outcome '" + o + - "' not registered in transitions"); + if (strict_mode) { + // Check if all outcomes from the state machine are in the terminal outcomes + for (const std::string &o : this->get_outcomes()) { + if (terminal_outcomes.find(o) == terminal_outcomes.end()) { + throw std::runtime_error("Target outcome '" + o + + "' not registered in transitions"); + } } } - // Check if all terminal outcomes are states or state machine outcomes + // Check if all terminal outcomes are states or outcomes of the state machine for (const std::string &o : terminal_outcomes) { if (this->states.find(o) == this->states.end() && sm_outcomes.find(o) == sm_outcomes.end()) { diff --git a/yasmin/tests/python/test_blackboard.py b/yasmin/tests/python/test_blackboard.py index 028f139..1c45b88 100644 --- a/yasmin/tests/python/test_blackboard.py +++ b/yasmin/tests/python/test_blackboard.py @@ -23,19 +23,19 @@ class TestBlackboard(unittest.TestCase): def setUp(self): self.blackboard = Blackboard() - def test_blackboard_get(self): + def test_get(self): self.blackboard["foo"] = "foo" self.assertEqual("foo", self.blackboard["foo"]) - def test_blackboard_delete(self): + def test_delete(self): self.blackboard["foo"] = "foo" del self.blackboard["foo"] self.assertFalse("foo" in self.blackboard) - def test_blackboard_contains(self): + def test_contains(self): self.blackboard["foo"] = "foo" self.assertTrue("foo" in self.blackboard) - def test_blackboard_len(self): + def test_len(self): self.blackboard["foo"] = "foo" self.assertEqual(1, len(self.blackboard)) diff --git a/yasmin/tests/python/test_state.py b/yasmin/tests/python/test_state.py index 3344d58..d207e40 100644 --- a/yasmin/tests/python/test_state.py +++ b/yasmin/tests/python/test_state.py @@ -39,19 +39,19 @@ class TestState(unittest.TestCase): def setUp(self): self.state = FooState() - def test_state_call(self): + def test_call(self): self.assertEqual("outcome1", self.state()) - def test_state_cancel(self): + def test_cancel(self): self.assertFalse(self.state.is_canceled()) self.state.cancel_state() self.assertTrue(self.state.is_canceled()) - def test_state_get_outcomes(self): + def test_get_outcomes(self): self.assertEqual("outcome1", list(self.state.get_outcomes())[0]) - def test_state_str(self): + def test_str(self): self.assertEqual("FooState", str(self.state)) - def test_state_init_exception(self): + def test_init_exception(self): self.assertRaises(Exception, BarState) diff --git a/yasmin/tests/python/test_state_machine.py b/yasmin/tests/python/test_state_machine.py index 3db5bde..e5a1301 100644 --- a/yasmin/tests/python/test_state_machine.py +++ b/yasmin/tests/python/test_state_machine.py @@ -35,7 +35,7 @@ def execute(self, blackboard): class BarState(State): def __init__(self): - super().__init__(outcomes=["outcome2"]) + super().__init__(outcomes=["outcome2", "outcome3"]) def execute(self, blackboard): return "outcome2" @@ -46,7 +46,7 @@ class TestStateMachine(unittest.TestCase): maxDiff = None def setUp(self): - self.sm = StateMachine(outcomes=["outcome4"]) + self.sm = StateMachine(outcomes=["outcome4", "outcome5"]) self.sm.add_state( "FOO", @@ -62,19 +62,19 @@ def setUp(self): transitions={"outcome2": "FOO"}, ) - def test_state_machine_str(self): + def test_str(self): self.assertEqual("State Machine [BAR (BarState), FOO (FooState)]", str(self.sm)) - def test_state_machine_get_states(self): + def test_get_states(self): self.assertTrue(isinstance(self.sm.get_states()["FOO"]["state"], FooState)) self.assertTrue(isinstance(self.sm.get_states()["BAR"]["state"], BarState)) - def test_state_machine_get_start_state(self): + def test_get_start_state(self): self.assertEqual("FOO", self.sm.get_start_state()) self.sm.set_start_state("BAR") self.assertEqual("BAR", self.sm.get_start_state()) - def test_state_machine_get_current_state(self): + def test_get_current_state(self): self.assertEqual("", self.sm.get_current_state()) def test_state_call(self): @@ -148,7 +148,7 @@ def test_add_wrong_target_transition(self): str(context.exception), "Transitions with empty target in state 'FOO1'" ) - def test_validate_state_machine_outcome_from_fsm_not_used(self): + def test_validate_outcome_from_fsm_not_used(self): sm_1 = StateMachine(outcomes=["outcome4"]) @@ -163,14 +163,15 @@ def test_validate_state_machine_outcome_from_fsm_not_used(self): "outcome2": "outcome4", }, ) + with self.assertRaises(KeyError) as context: - sm_1.validate() + sm_1.validate(True) self.assertEqual( str(context.exception), "\"State 'FSM' outcome 'outcome5' not registered in transitions\"", ) - def test_validate_state_machine_outcome_from_state_not_used(self): + def test_validate_outcome_from_state_not_used(self): sm_1 = StateMachine(outcomes=["outcome4"]) @@ -184,14 +185,15 @@ def test_validate_state_machine_outcome_from_state_not_used(self): "outcome1": "outcome4", }, ) + with self.assertRaises(KeyError) as context: - sm_1.validate() + sm_1.validate(True) self.assertEqual( str(context.exception), "\"State 'FOO' outcome 'outcome2' not registered in transitions\"", ) - def test_validate_state_machine_fsm_outcome_not_used(self): + def test_validate_fsm_outcome_not_used(self): sm_1 = StateMachine(outcomes=["outcome4"]) @@ -212,14 +214,15 @@ def test_validate_state_machine_fsm_outcome_not_used(self): "outcome2": "outcome4", }, ) + with self.assertRaises(KeyError) as context: - sm_1.validate() + sm_1.validate(True) self.assertEqual( str(context.exception), "\"Target outcome 'outcome5' not registered in transitions\"", ) - def test_validate_state_machine_wrong_state(self): + def test_validate_wrong_state(self): sm_1 = StateMachine(outcomes=["outcome4"]) @@ -240,6 +243,7 @@ def test_validate_state_machine_wrong_state(self): "outcome2": "outcome4", }, ) + with self.assertRaises(KeyError) as context: sm_1.validate() self.assertEqual( diff --git a/yasmin/yasmin/state_machine.py b/yasmin/yasmin/state_machine.py index d89d4bf..8e4ef99 100644 --- a/yasmin/yasmin/state_machine.py +++ b/yasmin/yasmin/state_machine.py @@ -32,6 +32,7 @@ class StateMachine(State): _start_state (str): The name of the initial state of the state machine. __current_state (str): The name of the current state being executed. __current_state_lock (Lock): A threading lock to manage access to the current state. + _validated (bool): A flag indicating whether the state machine has been validated. __start_cbs (List[Tuple[Callable[[Blackboard, str, List[Any]], None], List[Any]]]): A list of callbacks to call when the state machine starts. __transition_cbs (List[Tuple[Callable[[Blackboard, str, List[Any]], None], List[Any]]]): A list of callbacks to call during state transitions. __end_cbs (List[Tuple[Callable[[Blackboard, str, List[Any]], None], List[Any]]]): A list of callbacks to call when the state machine ends. @@ -276,20 +277,23 @@ def _call_end_cbs(self, blackboard: Blackboard, outcome: str) -> None: except Exception as e: yasmin.YASMIN_LOG_ERROR(f"Could not execute end callback: {e}") - def validate(self) -> None: + def validate(self, strict_mode: bool = False) -> None: """ Validates the state machine to ensure all states and transitions are correct. + Parameters: + strict_mode (bool): Whether the validation is strict, which means checking if all state outcomes are used and all state machine outcomes are reached. + Raises: RuntimeError: If no initial state is set. KeyError: If there are any unregistered outcomes or transitions. """ yasmin.YASMIN_LOG_DEBUG(f"Validating state machine '{self}'") - if self._validated: + if self._validated and not strict_mode: yasmin.YASMIN_LOG_DEBUG("State machine '{self}' has already been validated") - # terminal outcomes + # Terminal outcomes from all transitions terminal_outcomes = [] # Check initial state @@ -303,20 +307,21 @@ def validate(self) -> None: outcomes = state.get_outcomes() - # Check if all state outcomes are in transitions - for o in outcomes: - if o not in set(list(transitions.keys()) + list(self.get_outcomes())): - raise KeyError( - f"State '{state_name}' outcome '{o}' not registered in transitions" - ) + if strict_mode: + # Check if all outcomes of the state are in transitions + for o in outcomes: + if o not in set(list(transitions.keys()) + list(self.get_outcomes())): + raise KeyError( + f"State '{state_name}' outcome '{o}' not registered in transitions" + ) - # State outcomes that are in state machine outcomes do not need transitions - elif o in self.get_outcomes(): - terminal_outcomes.append(o) + # Outcomes of the state that are in outcomes of the state machine do not need transitions + elif o in self.get_outcomes(): + terminal_outcomes.append(o) # If state is a state machine, validate it if isinstance(state, StateMachine): - state.validate() + state.validate(strict_mode) # Add terminal outcomes terminal_outcomes.extend([transitions[key] for key in transitions]) @@ -324,12 +329,13 @@ def validate(self) -> None: # Check terminal outcomes for the state machine terminal_outcomes = set(terminal_outcomes) - # Check if all state machine outcomes are in the terminal outcomes - for o in self.get_outcomes(): - if o not in terminal_outcomes: - raise KeyError(f"Target outcome '{o}' not registered in transitions") + if strict_mode: + # Check if all outcomes of the state machine are in the terminal outcomes + for o in self.get_outcomes(): + if o not in terminal_outcomes: + raise KeyError(f"Target outcome '{o}' not registered in transitions") - # Check if all terminal outcomes are states or state machine outcomes + # Check if all terminal outcomes are states or outcomes of the state machine for o in terminal_outcomes: if o not in set(list(self._states.keys()) + list(self.get_outcomes())): raise KeyError( diff --git a/yasmin_ros/tests/python/test_yasmin_ros.py b/yasmin_ros/tests/python/test_yasmin_ros.py index 91bd997..6a3ebfc 100644 --- a/yasmin_ros/tests/python/test_yasmin_ros.py +++ b/yasmin_ros/tests/python/test_yasmin_ros.py @@ -14,8 +14,8 @@ # along with this program. If not, see . -import unittest import time +import unittest from threading import Thread from yasmin_ros import ActionState, ServiceState, MonitorState @@ -109,7 +109,7 @@ def setUpClass(cls): def tearDownClass(cls): rclpy.shutdown() - def test_yasmin_ros_action_succeed(self): + def test_action_succeed(self): def create_goal_cb(blackboard): goal = Fibonacci.Goal() @@ -119,7 +119,7 @@ def create_goal_cb(blackboard): state = ActionState(Fibonacci, "test", create_goal_cb) self.assertEqual(SUCCEED, state()) - def test_yasmin_ros_action_result_handler(self): + def test_action_result_handler(self): def create_goal_cb(blackboard): goal = Fibonacci.Goal() @@ -134,7 +134,7 @@ def result_handler(blackboard, result): ) self.assertEqual("new_outcome", state()) - def test_yasmin_ros_action_cancel(self): + def test_action_cancel(self): def create_goal_cb(blackboard): goal = Fibonacci.Goal() @@ -157,7 +157,7 @@ def cancel_state(state, seconds): self.assertEqual(CANCEL, state()) thread.join() - def test_yasmin_ros_action_abort(self): + def test_action_abort(self): def create_goal_cb(blackboard): goal = Fibonacci.Goal() @@ -167,7 +167,7 @@ def create_goal_cb(blackboard): state = ActionState(Fibonacci, "test", create_goal_cb) self.assertEqual(ABORT, state()) - def test_yasmin_ros_service(self): + def test_service(self): def create_request_cb(blackboard): request = AddTwoInts.Request() @@ -178,7 +178,7 @@ def create_request_cb(blackboard): state = ServiceState(AddTwoInts, "test", create_request_cb) self.assertEqual(SUCCEED, state()) - def test_yasmin_ros_service_response_handler(self): + def test_service_response_handler(self): def create_request_cb(blackboard): request = AddTwoInts.Request() @@ -194,7 +194,7 @@ def response_handler(blackboard, response): ) self.assertEqual("new_outcome", state()) - def test_yasmin_ros_monitor_timeout(self): + def test_monitor_timeout(self): def monitor_handler(blackboard, msg): return SUCCEED