diff --git a/include/mav/Connection.h b/include/mav/Connection.h index bc3af7e..2cdb497 100644 --- a/include/mav/Connection.h +++ b/include/mav/Connection.h @@ -55,6 +55,7 @@ namespace mav { using Expectation = std::shared_ptr>; private: + using ExpectationWeakRef = std::weak_ptr>; struct FunctionCallback { std::function callback; @@ -62,7 +63,7 @@ namespace mav { }; struct PromiseCallback { - Expectation promise; + ExpectationWeakRef promise; std::function selector; }; @@ -88,6 +89,11 @@ namespace mav { public: + size_t callbackCount() { + std::scoped_lock lock(_message_callback_mtx); + return _message_callbacks.size(); + } + void removeAllCallbacks() { std::scoped_lock lock(_message_callback_mtx); _message_callbacks.clear(); @@ -102,7 +108,7 @@ namespace mav { return _partner; } - void consumeMessageFromNetwork(const Message& message) { + void consumeMessageFromNetwork(const Message& message) noexcept { // in case we received a heartbeat, update last heartbeat time, to keep the connection alive. _last_received_ms = millis(); @@ -121,11 +127,16 @@ namespace mav { } it++; } else if constexpr (std::is_same_v) { - if (arg.selector(message)) { - arg.promise->set_value(message); + auto promise = arg.promise.lock(); + if (!promise) { it = _message_callbacks.erase(it); } else { - it++; + if (arg.selector(message)) { + promise->set_value(message); + it = _message_callbacks.erase(it); + } else { + it++; + } } } }, callback); @@ -133,7 +144,7 @@ namespace mav { } } - void consumeNetworkExceptionFromNetwork(const std::exception_ptr& exception) { + void consumeNetworkExceptionFromNetwork(const std::exception_ptr& exception) noexcept { _underlying_network_fault = true; std::scoped_lock lock(_message_callback_mtx); auto it = _message_callbacks.begin(); @@ -147,8 +158,13 @@ namespace mav { } it++; } else if constexpr (std::is_same_v) { - arg.promise->set_exception(exception); - it = _message_callbacks.erase(it); + auto promise = arg.promise.lock(); + if (!promise) { + it = _message_callbacks.erase(it); + } else { + promise->set_exception(exception); + it = _message_callbacks.erase(it); + } } }, callback); } diff --git a/tests/Network.cpp b/tests/Network.cpp index 18b4e03..7aac93f 100644 --- a/tests/Network.cpp +++ b/tests/Network.cpp @@ -199,11 +199,24 @@ TEST_CASE("Create network runtime") { CHECK_THROWS_AS(auto message = connection->receive(expectation, 100), TimeoutException); } - SUBCASE("Receive throws a NetworkError if the interface fails") { + SUBCASE("Receive throws a NetworkError if the interface fails, error callback gets called") { interface.reset(); + + // add a callback using the callback API. The error should then call the error callback + auto error_callback_called_promise = std::promise(); + connection->addMessageCallback([](const Message &message) { + // do nothing + }, [&error_callback_called_promise](const std::exception_ptr& exception) { + error_callback_called_promise.set_value(); + CHECK_THROWS_AS(std::rethrow_exception(exception), NetworkError); + }); + auto expectation = connection->expect("TEST_MESSAGE"); interface.makeFailOnNextReceive(); + // Receive on the sync api. The receive should then throw an exception CHECK_THROWS_AS(auto message = connection->receive(expectation), NetworkError); + CHECK((error_callback_called_promise.get_future().wait_for(std::chrono::seconds(2)) != std::future_status::timeout)); + connection->removeAllCallbacks(); } SUBCASE("Connection recycled on recover after fail") { @@ -265,4 +278,33 @@ TEST_CASE("Create network runtime") { interface_partner)); CHECK(found); } + + SUBCASE("Correct callback called when message is received") { + interface.reset(); + std::promise callback_called_promise; + auto callback_called_future = callback_called_promise.get_future(); + + connection->addMessageCallback([&callback_called_promise](const Message &message) { + if (message.name() == "TEST_MESSAGE") { + callback_called_promise.set_value(); + } + }); + + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x61\x61\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x53\xd9"s, interface_partner); + CHECK((callback_called_future.wait_for(std::chrono::seconds(2)) != std::future_status::timeout)); + connection->removeAllCallbacks(); + } + + SUBCASE("Callbacks are cleaned up on receive timeout") { + interface.reset(); + for (int i = 0; i < 10; i++) { + auto expectation = connection->expect("TEST_MESSAGE"); + CHECK_THROWS_AS(auto message = connection->receive(expectation, 100), TimeoutException); + } + // send a heartbeat. Any message will clear expired expectations + interface.addToReceiveQueue("\xfd\x09\x00\x00\x00\xfd\x01\x00\x00\x00\x04\x00\x00\x00\x01\x02\x03\x05\x06\x77\x53"s, interface_partner); + // wait for the heartbeat to be received, to make sure timing is correct in test + connection->receive("HEARTBEAT"); + CHECK_EQ(connection->callbackCount(), 0); + } }