Skip to content

Commit

Permalink
Merge pull request #43 from Auterion/callback-cleanup-on-timeout
Browse files Browse the repository at this point in the history
Correct cleanup on repeated waiting for never arriving messages
  • Loading branch information
ThomasDebrunner authored Feb 29, 2024
2 parents a9f8403 + f81e224 commit 67b819b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 9 deletions.
32 changes: 24 additions & 8 deletions include/mav/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ namespace mav {
using Expectation = std::shared_ptr<std::promise<Message>>;

private:
using ExpectationWeakRef = std::weak_ptr<std::promise<Message>>;

struct FunctionCallback {
std::function<void(const Message &message)> callback;
std::function<void(const std::exception_ptr& exception)> error_callback;
};

struct PromiseCallback {
Expectation promise;
ExpectationWeakRef promise;
std::function<bool(const Message &message)> selector;
};

Expand All @@ -88,6 +89,11 @@ namespace mav {

public:

size_t callbackCount() {
std::scoped_lock<std::mutex> lock(_message_callback_mtx);
return _message_callbacks.size();
}

void removeAllCallbacks() {
std::scoped_lock<std::mutex> lock(_message_callback_mtx);
_message_callbacks.clear();
Expand All @@ -102,7 +108,7 @@ namespace mav {
return _partner;
}

void consumeMessageFromNetwork(const Message& message) {
void consumeMessageFromNetwork(const Message& message) noexcept {
// in case we received a heartbeat, update last heartbeat time, to keep the connection alive.
_last_received_ms = millis();

Expand All @@ -121,19 +127,24 @@ namespace mav {
}
it++;
} else if constexpr (std::is_same_v<T, PromiseCallback>) {
if (arg.selector(message)) {
arg.promise->set_value(message);
auto promise = arg.promise.lock();
if (!promise) {
it = _message_callbacks.erase(it);
} else {
it++;
if (arg.selector(message)) {
promise->set_value(message);
it = _message_callbacks.erase(it);
} else {
it++;
}
}
}
}, callback);
}
}
}

void consumeNetworkExceptionFromNetwork(const std::exception_ptr& exception) {
void consumeNetworkExceptionFromNetwork(const std::exception_ptr& exception) noexcept {
_underlying_network_fault = true;
std::scoped_lock<std::mutex> lock(_message_callback_mtx);
auto it = _message_callbacks.begin();
Expand All @@ -147,8 +158,13 @@ namespace mav {
}
it++;
} else if constexpr (std::is_same_v<T, PromiseCallback>) {
arg.promise->set_exception(exception);
it = _message_callbacks.erase(it);
auto promise = arg.promise.lock();
if (!promise) {
it = _message_callbacks.erase(it);
} else {
promise->set_exception(exception);
it = _message_callbacks.erase(it);
}
}
}, callback);
}
Expand Down
44 changes: 43 additions & 1 deletion tests/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,24 @@ TEST_CASE("Create network runtime") {
CHECK_THROWS_AS(auto message = connection->receive(expectation, 100), TimeoutException);
}

SUBCASE("Receive throws a NetworkError if the interface fails") {
SUBCASE("Receive throws a NetworkError if the interface fails, error callback gets called") {
interface.reset();

// add a callback using the callback API. The error should then call the error callback
auto error_callback_called_promise = std::promise<void>();
connection->addMessageCallback([](const Message &message) {
// do nothing
}, [&error_callback_called_promise](const std::exception_ptr& exception) {
error_callback_called_promise.set_value();
CHECK_THROWS_AS(std::rethrow_exception(exception), NetworkError);
});

auto expectation = connection->expect("TEST_MESSAGE");
interface.makeFailOnNextReceive();
// Receive on the sync api. The receive should then throw an exception
CHECK_THROWS_AS(auto message = connection->receive(expectation), NetworkError);
CHECK((error_callback_called_promise.get_future().wait_for(std::chrono::seconds(2)) != std::future_status::timeout));
connection->removeAllCallbacks();
}

SUBCASE("Connection recycled on recover after fail") {
Expand Down Expand Up @@ -265,4 +278,33 @@ TEST_CASE("Create network runtime") {
interface_partner));
CHECK(found);
}

SUBCASE("Correct callback called when message is received") {
interface.reset();
std::promise<void> callback_called_promise;
auto callback_called_future = callback_called_promise.get_future();

connection->addMessageCallback([&callback_called_promise](const Message &message) {
if (message.name() == "TEST_MESSAGE") {
callback_called_promise.set_value();
}
});

interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x61\x61\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x53\xd9"s, interface_partner);
CHECK((callback_called_future.wait_for(std::chrono::seconds(2)) != std::future_status::timeout));
connection->removeAllCallbacks();
}

SUBCASE("Callbacks are cleaned up on receive timeout") {
interface.reset();
for (int i = 0; i < 10; i++) {
auto expectation = connection->expect("TEST_MESSAGE");
CHECK_THROWS_AS(auto message = connection->receive(expectation, 100), TimeoutException);
}
// send a heartbeat. Any message will clear expired expectations
interface.addToReceiveQueue("\xfd\x09\x00\x00\x00\xfd\x01\x00\x00\x00\x04\x00\x00\x00\x01\x02\x03\x05\x06\x77\x53"s, interface_partner);
// wait for the heartbeat to be received, to make sure timing is correct in test
connection->receive("HEARTBEAT");
CHECK_EQ(connection->callbackCount(), 0);
}
}

0 comments on commit 67b819b

Please sign in to comment.