diff --git a/include/ocpp/common/websocket/websocket_base.hpp b/include/ocpp/common/websocket/websocket_base.hpp index 997aa46c4..2881dc996 100644 --- a/include/ocpp/common/websocket/websocket_base.hpp +++ b/include/ocpp/common/websocket/websocket_base.hpp @@ -47,7 +47,7 @@ enum class ConnectionFailedReason { /// class WebsocketBase { protected: - bool m_is_connected; + std::atomic_bool m_is_connected; WebsocketConnectionOptions connection_options; std::function connected_callback; std::function disconnected_callback; @@ -59,11 +59,11 @@ class WebsocketBase { websocketpp::connection_hdl handle; std::mutex reconnect_mutex; std::mutex connection_mutex; - long reconnect_backoff_ms; + std::atomic_int reconnect_backoff_ms; websocketpp::transport::timer_handler reconnect_callback; - int connection_attempts; - bool shutting_down; - bool reconnecting; + std::atomic_int connection_attempts; + std::atomic_bool shutting_down; + std::atomic_bool reconnecting; /// \brief Indicates if the required callbacks are registered /// \returns true if the websocket is properly initialized diff --git a/include/ocpp/common/websocket/websocket_tls_tpm.hpp b/include/ocpp/common/websocket/websocket_tls_tpm.hpp index 67709a382..655f0b146 100644 --- a/include/ocpp/common/websocket/websocket_tls_tpm.hpp +++ b/include/ocpp/common/websocket/websocket_tls_tpm.hpp @@ -67,7 +67,7 @@ class WebsocketTlsTPM final : public WebsocketBase { void request_write(); - void poll_message(const std::shared_ptr& msg, bool wait_sendaf); + void poll_message(const std::shared_ptr& msg); private: std::shared_ptr evse_security; @@ -79,8 +79,10 @@ class WebsocketTlsTPM final : public WebsocketBase { std::condition_variable conn_cv; std::mutex queue_mutex; + std::queue> message_queue; std::condition_variable msg_send_cv; + std::mutex msg_send_cv_mutex; std::unique_ptr recv_message_thread; std::mutex recv_mutex; diff --git a/lib/ocpp/common/websocket/websocket_tls_tpm.cpp b/lib/ocpp/common/websocket/websocket_tls_tpm.cpp index d410bdf7a..08f2161ee 100644 --- a/lib/ocpp/common/websocket/websocket_tls_tpm.cpp +++ b/lib/ocpp/common/websocket/websocket_tls_tpm.cpp @@ -87,7 +87,7 @@ struct ConnectionData { } bool is_connecting() { - return (state == EConnectionState::CONNECTING); + return (state.load() == EConnectionState::CONNECTING); } bool is_close_requested() { @@ -95,33 +95,39 @@ struct ConnectionData { } auto get_state() { - return state; + return state.load(); } lws* get_conn() { return wsi; } - lws_context* get_ctx() { - return lws_ctx.get(); + WebsocketTlsTPM* get_owner() { + return owner.load(); + } + + void set_owner(WebsocketTlsTPM* o) { + owner = o; } public: + // This public block will only be used from client loop thread, no locking needed // Openssl context, must be destroyed in this order std::unique_ptr sec_context; std::unique_ptr sec_lib_context; - // libwebsockets state std::unique_ptr lws_ctx; - lws* wsi; - WebsocketTlsTPM* owner; + lws* wsi; private: + std::atomic owner; + std::thread::id lws_thread_id; - bool is_running; - bool is_marked_close; - EConnectionState state; + + std::atomic_bool is_running; + std::atomic_bool is_marked_close; + std::atomic state; }; struct WebsocketMessage { @@ -140,7 +146,7 @@ struct WebsocketMessage { // just that these were sent to libwebsockets size_t sent_bytes; // If libwebsockets has sent all the bytes through the wire - volatile bool message_sent; + std::atomic_bool message_sent; }; WebsocketTlsTPM::WebsocketTlsTPM(const WebsocketConnectionOptions& connection_options, @@ -189,7 +195,12 @@ static int callback_minimal(struct lws* wsi, enum lws_callback_reasons reason, v // Get user safely, since on some callbacks (void *user) can be different than what we set if (wsi != nullptr) { if (ConnectionData* data = reinterpret_cast(lws_wsi_user(wsi))) { - return data->owner->process_callback(wsi, static_cast(reason), user, in, len); + auto owner = data->get_owner(); + if (owner not_eq nullptr) { + return data->get_owner()->process_callback(wsi, static_cast(reason), user, in, len); + } else { + EVLOG_error << "callback_minimal called, but data->owner is nullptr"; + } } } @@ -331,11 +342,14 @@ void WebsocketTlsTPM::recv_loop() { while (false == data->is_interupted()) { // Process all messages - while (false == recv_message_queue.empty()) { + while (true) { std::string message{}; { std::lock_guard lk(this->recv_mutex); + if (recv_message_queue.empty()) + break; + message = std::move(recv_message_queue.front()); recv_message_queue.pop(); } @@ -459,9 +473,14 @@ void WebsocketTlsTPM::client_loop() { while (n >= 0 && (false == data->is_interupted())) { // Set to -1 for continuous servicing, of required, not recommended - n = lws_service(data->get_ctx(), 0); + n = lws_service(data->lws_ctx.get(), 0); - if (false == message_queue.empty()) { + bool message_queue_empty; + { + std::lock_guard lock(this->queue_mutex); + message_queue_empty = message_queue.empty(); + } + if (false == message_queue_empty) { lws_callback_on_writable(data->get_conn()); } } @@ -470,6 +489,7 @@ void WebsocketTlsTPM::client_loop() { EVLOG_debug << "Exit client loop with ID: " << std::this_thread::get_id(); } +// Will be called from external threads as well bool WebsocketTlsTPM::connect() { if (!this->initialized()) { return false; @@ -485,19 +505,42 @@ bool WebsocketTlsTPM::connect() { } auto conn_data = new ConnectionData(); - conn_data->owner = this; + conn_data->set_owner(this); this->conn_data.reset(conn_data); // Wait old thread for a clean state if (this->websocket_thread) { + // Awake libwebsockets thread to quickly exit + request_write(); this->websocket_thread->join(); } if (this->recv_message_thread) { + // Awake the receiving message thread to finish + recv_message_cv.notify_one(); this->recv_message_thread->join(); } + // Stop any pending reconnect timer + { + std::lock_guard lk(this->reconnect_mutex); + this->reconnect_timer_tpm.stop(); + } + + // Clear any pending messages on a new connection + { + std::lock_guard lock(queue_mutex); + std::queue> empty; + empty.swap(message_queue); + } + + { + std::lock_guard lock(recv_mutex); + std::queue empty; + empty.swap(recv_message_queue); + } + // Bind reconnect callback this->reconnect_callback = [this](const websocketpp::lib::error_code& ec) { EVLOG_info << "Reconnecting to TLS websocket at uri: " << this->connection_options.csms_uri.string() @@ -509,11 +552,6 @@ bool WebsocketTlsTPM::connect() { this->close(websocketpp::close::status::abnormal_close, "Reconnect"); } - { - std::lock_guard lk(this->reconnect_mutex); - this->reconnect_timer_tpm.stop(); - } - this->connect(); }; @@ -557,7 +595,6 @@ void WebsocketTlsTPM::reconnect(std::error_code reason, long delay) { return; } - std::lock_guard lk(this->reconnect_mutex); if (this->m_is_connected) { EVLOG_info << "Closing websocket connection before reconnecting"; this->close(websocketpp::close::status::abnormal_close, "Reconnect"); @@ -566,8 +603,11 @@ void WebsocketTlsTPM::reconnect(std::error_code reason, long delay) { EVLOG_info << "Reconnecting in: " << delay << "ms" << ", attempt: " << this->connection_attempts; - this->reconnect_timer_tpm.timeout([this]() { this->reconnect_callback(websocketpp::lib::error_code()); }, - std::chrono::milliseconds(delay)); + { + std::lock_guard lk(this->reconnect_mutex); + this->reconnect_timer_tpm.timeout([this]() { this->reconnect_callback(websocketpp::lib::error_code()); }, + std::chrono::milliseconds(delay)); + } } void WebsocketTlsTPM::close(websocketpp::close::status::value code, const std::string& reason) { @@ -590,7 +630,8 @@ void WebsocketTlsTPM::close(websocketpp::close::status::value code, const std::s } this->m_is_connected = false; - this->closed_callback(websocketpp::close::status::normal); + std::thread closing([this]() { this->closed_callback(websocketpp::close::status::normal); }); + closing.detach(); } void WebsocketTlsTPM::on_conn_connected() { @@ -600,7 +641,8 @@ void WebsocketTlsTPM::on_conn_connected() { this->m_is_connected = true; this->reconnecting = false; - this->connected_callback(this->connection_options.security_profile); + std::thread connected([this]() { this->connected_callback(this->connection_options.security_profile); }); + connected.detach(); } void WebsocketTlsTPM::on_conn_close() { @@ -611,7 +653,8 @@ void WebsocketTlsTPM::on_conn_close() { this->disconnected_callback(); this->cancel_reconnect_timer(); - this->closed_callback(websocketpp::close::status::normal); + std::thread closing([this]() { this->closed_callback(websocketpp::close::status::normal); }); + closing.detach(); } void WebsocketTlsTPM::on_conn_fail() { @@ -619,7 +662,8 @@ void WebsocketTlsTPM::on_conn_fail() { std::lock_guard lk(this->connection_mutex); if (this->m_is_connected) { - this->disconnected_callback(); + std::thread disconnect([this]() { this->disconnected_callback(); }); + disconnect.detach(); } this->m_is_connected = false; @@ -717,11 +761,17 @@ void WebsocketTlsTPM::on_writable() { return; } - while (false == message_queue.empty()) { + // Execute while we have messages that were polled + while (true) { WebsocketMessage* message = nullptr; { std::lock_guard lock(this->queue_mutex); + + // Break if we have en empty queue + if (message_queue.empty()) + break; + message = message_queue.front().get(); } @@ -729,7 +779,7 @@ void WebsocketTlsTPM::on_writable() { EVLOG_AND_THROW(std::runtime_error("Null message in queue, fatal error!")); } - // Pop all sent messages + // This message was polled in a previous iteration if (message->sent_bytes >= message->payload.length()) { EVLOG_info << "Websocket message fully written, popping processing thread from queue!"; @@ -744,20 +794,47 @@ void WebsocketTlsTPM::on_writable() { EVLOG_debug << "Notifying waiting thread!"; // Notify any waiting thread to check it's state - msg_send_cv.notify_one(); + msg_send_cv.notify_all(); } else { - EVLOG_debug << "Client writable, sending message part!"; + // If the message was not polled, we reached the first unpolled and break + break; + } + } + + // If we still have message ONLY poll a single one that can be processed in the invoke of the function + // libwebsockets is designed so that when a message is sent to the wire from the internal buffer it + // will invoke 'on_writable' again and we can execute the code above + bool any_message_polled; + { + std::lock_guard lock(this->queue_mutex); + any_message_polled = not message_queue.empty(); + } - // Continue sending message part, for a single message only - bool sent = send_internal(data->get_conn(), message); + // Poll a single message + if (any_message_polled) { + EVLOG_debug << "Client writable, sending message part!"; - // If we failed, attempt again later - if (false == sent) { - message->sent_bytes = 0; - } + WebsocketMessage* message = nullptr; - // Break loop - break; + { + std::lock_guard lock(this->queue_mutex); + message = message_queue.front().get(); + } + + if (message == nullptr) { + EVLOG_AND_THROW(std::runtime_error("Null message in queue, fatal error!")); + } + + if (message->sent_bytes >= message->payload.length()) { + EVLOG_AND_THROW(std::runtime_error("Already polled message should be handled above, fatal error!")); + } + + // Continue sending message part, for a single message only + bool sent = send_internal(data->get_conn(), message); + + // If we failed, attempt again later + if (false == sent) { + message->sent_bytes = 0; } } } @@ -766,8 +843,9 @@ void WebsocketTlsTPM::request_write() { if (this->m_is_connected) { if (auto* data = conn_data.get()) { if (data->get_conn()) { - // Notify waiting processing thread to wake up - lws_cancel_service(data->get_ctx()); + // Notify waiting processing thread to wake up. According to docs it is ok to call from another + // thread. + lws_cancel_service(data->lws_ctx.get()); } } } else { @@ -775,7 +853,8 @@ void WebsocketTlsTPM::request_write() { } } -void WebsocketTlsTPM::poll_message(const std::shared_ptr& msg, bool wait_send) { +void WebsocketTlsTPM::poll_message(const std::shared_ptr& msg) { + if (std::this_thread::get_id() == conn_data->get_lws_thread_id()) { EVLOG_AND_THROW(std::runtime_error("Deadlock detected, polling send from client lws thread!")); } @@ -790,17 +869,17 @@ void WebsocketTlsTPM::poll_message(const std::shared_ptr& msg, // Request a write callback request_write(); - if (wait_send) { - std::unique_lock lock(this->queue_mutex); - msg_send_cv.wait_for(lock, std::chrono::seconds(10), [&] { return (true == msg->message_sent); }); + { + std::unique_lock lock(this->msg_send_cv_mutex); + if (msg_send_cv.wait_for(lock, std::chrono::seconds(20), [&] { return (true == msg->message_sent); })) { + EVLOG_info << "Successfully sent last message over TLS websocket!"; + } else { + EVLOG_warning << "Could not send last message over TLS websocket!"; + } } - - if (msg->message_sent) - EVLOG_info << "Successfully sent last message over TLS websocket!"; - else - EVLOG_warning << "Could not send last message over TLS websocket!"; } +// Will be called from external threads bool WebsocketTlsTPM::send(const std::string& message) { if (!this->initialized()) { EVLOG_error << "Could not send message because websocket is not properly initialized."; @@ -811,7 +890,7 @@ bool WebsocketTlsTPM::send(const std::string& message) { msg->payload = std::move(message); msg->protocol = LWS_WRITE_TEXT; - poll_message(msg, true); + poll_message(msg); return msg->message_sent; } @@ -825,7 +904,7 @@ void WebsocketTlsTPM::ping() { msg->payload = this->connection_options.ping_payload; msg->protocol = LWS_WRITE_PING; - poll_message(msg, true); + poll_message(msg); } int WebsocketTlsTPM::process_callback(void* wsi_ptr, int callback_reason, void* user, void* in, size_t len) { @@ -949,31 +1028,53 @@ int WebsocketTlsTPM::process_callback(void* wsi_ptr, int callback_reason, void* case LWS_CALLBACK_CLIENT_WRITEABLE: on_writable(); - - if (false == message_queue.empty()) { - lws_callback_on_writable(wsi); + { + bool message_queue_empty; + { + std::lock_guard lock(this->queue_mutex); + message_queue_empty = message_queue.empty(); + } + if (false == message_queue_empty) { + lws_callback_on_writable(wsi); + } } break; - case LWS_CALLBACK_CLIENT_RECEIVE_PONG: - if (false == message_queue.empty()) { + case LWS_CALLBACK_CLIENT_RECEIVE_PONG: { + bool message_queue_empty; + { + std::lock_guard lock(this->queue_mutex); + message_queue_empty = message_queue.empty(); + } + if (false == message_queue_empty) { lws_callback_on_writable(data->get_conn()); } - break; + } break; case LWS_CALLBACK_CLIENT_RECEIVE: on_message(in, len); - - if (false == message_queue.empty()) { - lws_callback_on_writable(data->get_conn()); + { + bool message_queue_empty; + { + std::lock_guard lock(this->queue_mutex); + message_queue_empty = message_queue.empty(); + } + if (false == message_queue_empty) { + lws_callback_on_writable(data->get_conn()); + } } break; - case LWS_CALLBACK_EVENT_WAIT_CANCELLED: - if (false == message_queue.empty()) { + case LWS_CALLBACK_EVENT_WAIT_CANCELLED: { + bool message_queue_empty; + { + std::lock_guard lock(this->queue_mutex); + message_queue_empty = message_queue.empty(); + } + if (false == message_queue_empty) { lws_callback_on_writable(data->get_conn()); } - break; + } break; default: EVLOG_info << "Callback with unhandled reason: " << reason;