diff --git a/include/mav/Connection.h b/include/mav/Connection.h index 170232b..76125b4 100644 --- a/include/mav/Connection.h +++ b/include/mav/Connection.h @@ -192,8 +192,8 @@ namespace mav { return !_underlying_network_fault && (millis() - _last_received_ms < CONNECTION_TIMEOUT); } - template - CallbackHandle addMessageCallback(const T &on_message, const E &on_error) { + CallbackHandle addMessageCallback(const std::function &on_message, + const std::function on_error) { std::scoped_lock lock(_message_callback_mtx); CallbackHandle handle = _next_handle; _message_callbacks[handle] = FunctionCallback{on_message, on_error}; @@ -201,9 +201,32 @@ namespace mav { return handle; } - template - CallbackHandle addMessageCallback(const T &on_message) { - return addMessageCallback(on_message, nullptr); + CallbackHandle addMessageCallback(const std::function &on_message) { + return addMessageCallback(on_message, std::function{}); + } + + CallbackHandle addMessageCallback(const std::function &selector, + const std::function &on_message, + const std::function &on_error) { + return addMessageCallback([selector, on_message](const Message &message) { + if (selector(message)) { + on_message(message); + } + }, on_error); + } + + CallbackHandle addMessageCallback(int message_id, std::function on_message, + int source_id=mav::ANY_ID, int component_id=mav::ANY_ID) { + return addMessageCallback([message_id, source_id, component_id](const Message &message) { + return message.id() == message_id && + (source_id == mav::ANY_ID || message.header().systemId() == source_id) && + (component_id == mav::ANY_ID || message.header().componentId() == component_id); + }, on_message, std::function{}); + } + + CallbackHandle addMessageCallback(const std::string &message_name, std::function on_message, + int source_id=mav::ANY_ID, int component_id=mav::ANY_ID) { + return addMessageCallback(_message_set.idForMessage(message_name), on_message, source_id, component_id); } void removeMessageCallback(CallbackHandle handle) { diff --git a/tests/Network.cpp b/tests/Network.cpp index 381f43d..dea86a2 100644 --- a/tests/Network.cpp +++ b/tests/Network.cpp @@ -308,4 +308,20 @@ TEST_CASE("Create network runtime") { connection->receive("HEARTBEAT"); CHECK_EQ(connection->callbackCount(), 0); } + + SUBCASE("Message callback for specific message is called when message arrives") { + interface.reset(); + std::promise callback_called_promise; + auto callback_called_future = callback_called_promise.get_future(); + + connection->addMessageCallback("TEST_MESSAGE", [&callback_called_promise](const Message &message) { + if (message.name() == "TEST_MESSAGE") { + callback_called_promise.set_value(); + } + }); + + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x61\x61\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x53\xd9"s, interface_partner); + CHECK((callback_called_future.wait_for(std::chrono::seconds(2)) != std::future_status::timeout)); + connection->removeAllCallbacks(); + } }