Skip to content

Commit

Permalink
state machine callbacks created
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Oct 30, 2024
1 parent 5f10d9a commit e98f5f8
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 19 deletions.
33 changes: 32 additions & 1 deletion yasmin/include/yasmin/state_machine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#ifndef YASMIN_STATE_MACHINE_HPP
#define YASMIN_STATE_MACHINE_HPP

#include <functional>
#include <map>
#include <memory>
#include <mutex>
Expand All @@ -29,6 +30,17 @@ namespace yasmin {

class StateMachine : public State {

using StartCallbackType = std::function<void(
std::shared_ptr<yasmin::blackboard::Blackboard>, const std::string &,
const std::vector<std::string> &)>;
using TransitionCallbackType = std::function<void(
std::shared_ptr<yasmin::blackboard::Blackboard>, const std::string &,
const std::string &, const std::string &,
const std::vector<std::string> &)>;
using EndCallbackType = std::function<void(
std::shared_ptr<yasmin::blackboard::Blackboard>, const std::string &,
const std::vector<std::string> &)>;

public:
StateMachine(std::vector<std::string> outcomes);

Expand All @@ -46,21 +58,40 @@ class StateMachine : public State {
get_transitions();
std::string get_current_state();

void add_start_cb(StartCallbackType cb, std::vector<std::string> args = {});
void add_transition_cb(TransitionCallbackType cb,
std::vector<std::string> args = {});
void add_end_cb(EndCallbackType cb, std::vector<std::string> args = {});
void
call_start_cbs(std::shared_ptr<yasmin::blackboard::Blackboard> blackboard,
const std::string &start_state);
void call_transition_cbs(
std::shared_ptr<yasmin::blackboard::Blackboard> blackboard,
const std::string &from_state, const std::string &to_state,
const std::string &outcome);
void call_end_cbs(std::shared_ptr<yasmin::blackboard::Blackboard> blackboard,
const std::string &outcome);

void validate();
std::string
execute(std::shared_ptr<blackboard::Blackboard> blackboard) override;
std::string execute();
std::string operator()();
using State::operator();

std::string to_string();
void validate();

private:
std::map<std::string, std::shared_ptr<State>> states;
std::map<std::string, std::map<std::string, std::string>> transitions;
std::string start_state;
std::string current_state;
std::unique_ptr<std::mutex> current_state_mutex;

std::vector<std::pair<StartCallbackType, std::vector<std::string>>> start_cbs;
std::vector<std::pair<TransitionCallbackType, std::vector<std::string>>>
transition_cbs;
std::vector<std::pair<EndCallbackType, std::vector<std::string>>> end_cbs;
};

} // namespace yasmin
Expand Down
2 changes: 1 addition & 1 deletion yasmin/src/yasmin/cb_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ CbState::execute(std::shared_ptr<blackboard::Blackboard> blackboard) {
return this->callback(blackboard);
}

std::string CbState::to_string() { return "CbState"; }
std::string CbState::to_string() { return "CbState"; }
81 changes: 77 additions & 4 deletions yasmin/src/yasmin/state_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,68 @@ std::string StateMachine::get_current_state() {
return this->current_state;
}

void StateMachine::add_start_cb(EndCallbackType cb,
std::vector<std::string> args) {
this->start_cbs.emplace_back(cb, args);
}

void StateMachine::add_transition_cb(TransitionCallbackType cb,
std::vector<std::string> args) {
this->transition_cbs.emplace_back(cb, args);
}

void StateMachine::add_end_cb(EndCallbackType cb,
std::vector<std::string> args) {
this->end_cbs.emplace_back(cb, args);
}

void StateMachine::call_start_cbs(
std::shared_ptr<yasmin::blackboard::Blackboard> blackboard,
const std::string &start_state) {
try {
for (const auto &callback_pair : this->start_cbs) {
const auto &cb = callback_pair.first;
const auto &args = callback_pair.second;
cb(blackboard, start_state, args);
}
} catch (const std::exception &e) {
YASMIN_LOG_ERROR("Could not execute start callback: %s"),
std::string(e.what());
}
}

void StateMachine::call_transition_cbs(
std::shared_ptr<yasmin::blackboard::Blackboard> blackboard,
const std::string &from_state, const std::string &to_state,
const std::string &outcome) {
try {
for (const auto &callback_pair : this->transition_cbs) {
const auto &cb = callback_pair.first;
const auto &args = callback_pair.second;
cb(blackboard, from_state, to_state, outcome, args);
}
} catch (const std::exception &e) {
YASMIN_LOG_ERROR("Could not execute transition callback: %s"),
std::string(e.what());
}
}

void StateMachine::call_end_cbs(
std::shared_ptr<yasmin::blackboard::Blackboard> blackboard,
const std::string &outcome) {
try {

for (const auto &callback_pair : this->end_cbs) {
const auto &cb = callback_pair.first;
const auto &args = callback_pair.second;
cb(blackboard, outcome, args);
}
} catch (const std::exception &e) {
YASMIN_LOG_ERROR("Could not execute end callback: %s"),
std::string(e.what());
}
}

void StateMachine::validate() {

// check initial state
Expand Down Expand Up @@ -196,12 +258,15 @@ StateMachine::execute(std::shared_ptr<blackboard::Blackboard> blackboard) {

this->validate();

this->call_start_cbs(blackboard, this->get_start_state());

this->current_state_mutex->lock();
this->current_state = this->start_state;
this->current_state_mutex->unlock();

std::map<std::string, std::string> transitions;
std::string outcome;
std::string old_outcome;

while (true) {

Expand All @@ -212,12 +277,13 @@ StateMachine::execute(std::shared_ptr<blackboard::Blackboard> blackboard) {
this->current_state_mutex->unlock();

outcome = (*state.get())(blackboard);
old_outcome = std::string(outcome);

// check outcome belongs to state
if (std::find(state->get_outcomes().begin(), state->get_outcomes().end(),
outcome) == this->outcomes.end()) {
throw std::logic_error("Outcome (" + outcome +
") is not register in state " +
throw std::logic_error("Outcome '" + outcome +
"' is not register in state " +
this->current_state);
}

Expand All @@ -238,6 +304,8 @@ StateMachine::execute(std::shared_ptr<blackboard::Blackboard> blackboard) {
this->current_state.clear();
this->current_state_mutex->unlock();

this->call_end_cbs(blackboard, outcome);

return outcome;

// outcome is a state
Expand All @@ -247,9 +315,14 @@ StateMachine::execute(std::shared_ptr<blackboard::Blackboard> blackboard) {
this->current_state = outcome;
this->current_state_mutex->unlock();

this->call_transition_cbs(blackboard, this->get_start_state(), outcome,
old_outcome);

// outcome is not in the sm
} else {
throw std::logic_error("Outcome (" + outcome + ") without transition");
throw std::logic_error(
"Outcome '" + outcome +
"' is not a state neither a state machine outcome");
}
}

Expand Down Expand Up @@ -284,4 +357,4 @@ std::string StateMachine::to_string() {
}

return result;
}
}
87 changes: 74 additions & 13 deletions yasmin/yasmin/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.


from typing import Dict, List, Union
from typing import Dict, List, Any, Union, Callable
from threading import Lock

import yasmin
Expand All @@ -31,6 +31,9 @@ def __init__(self, outcomes: List[str]) -> None:
self._start_state = None
self.__current_state = None
self.__current_state_lock = Lock()
self.__start_cbs = []
self.__transition_cbs = []
self.__end_cbs = []

def add_state(
self,
Expand Down Expand Up @@ -81,6 +84,58 @@ def cancel_state(self) -> None:
if self.__current_state:
self._states[self.__current_state]["state"].cancel_state()

def get_states(self) -> Dict[str, Union[State, Dict[str, str]]]:
return self._states

def get_current_state(self) -> str:
with self.__current_state_lock:
if self.__current_state:
return self.__current_state

return ""

def add_start_cb(self, cb: Callable, args: List[Any] = None) -> None:
if args is None:
args = []
self.__start_cbs.append((cb, args))

def add_transition_cb(self, cb: Callable, args: List[Any] = None) -> None:
if args is None:
args = []
self.__transition_cbs.append((cb, args))

def add_end_cb(self, cb: Callable, args: List[Any] = None) -> None:
if args is None:
args = []
self.__end_cbs.append((cb, args))

def _call_start_cbs(self, blackboard: Blackboard, start_state: str) -> None:
try:
for cb, args in self.__start_cbs:
cb(blackboard, start_state, *args)
except Exception as e:
yasmin.YASMIN_LOG_ERROR(f"Could not execute start callback: {e}")

def _call_transition_cbs(
self,
blackboard: Blackboard,
from_state: str,
to_state: str,
outcome: str,
) -> None:
try:
for cb, args in self.__transition_cbs:
cb(blackboard, from_state, to_state, outcome, *args)
except Exception as e:
yasmin.YASMIN_LOG_ERROR(f"Could not execute transition callback: {e}")

def _call_end_cbs(self, blackboard: Blackboard, outcome: str) -> None:
try:
for cb, args in self.__end_cbs:
cb(blackboard, outcome, *args)
except Exception as e:
yasmin.YASMIN_LOG_ERROR(f"Could not execute end callback: {e}")

def validate(self) -> None:

# check initial state
Expand Down Expand Up @@ -134,6 +189,8 @@ def execute(self, blackboard: Blackboard) -> str:

self.validate()

self._call_start_cbs(blackboard, self._start_state)

with self.__current_state_lock:
self.__current_state = self._start_state

Expand All @@ -143,11 +200,12 @@ def execute(self, blackboard: Blackboard) -> str:
state = self._states[self.__current_state]

outcome = state["state"](blackboard)
old_come = outcome

# check outcome belongs to state
if outcome not in state["state"].get_outcomes():
raise KeyError(
f"Outcome ({outcome}) is not register in state {self.__current_state}"
f"Outcome '{outcome}' is not register in state {self.__current_state}"
)

# translate outcome using transitions
Expand All @@ -161,26 +219,29 @@ def execute(self, blackboard: Blackboard) -> str:
if outcome in self.get_outcomes():
with self.__current_state_lock:
self.__current_state = None

yasmin.YASMIN_LOG_INFO(f"State machine ends with outcome '{outcome}'")
self._call_end_cbs(blackboard, outcome)
return outcome

# outcome is a state
elif outcome in self._states:

self._call_transition_cbs(
blackboard,
self.get_current_state(),
outcome,
old_come,
)

with self.__current_state_lock:
self.__current_state = outcome

# outcome is not in the sm
else:
raise KeyError(f"Outcome ({outcome}) without transition")

def get_states(self) -> Dict[str, Union[State, Dict[str, str]]]:
return self._states

def get_current_state(self) -> str:
with self.__current_state_lock:
if self.__current_state:
return self.__current_state

return ""
raise KeyError(
f"Outcome '{outcome}' is not a state neither a state machine outcome"
)

def __str__(self) -> str:
return f"StateMachine: {self._states}"

0 comments on commit e98f5f8

Please sign in to comment.