From 7ae853ed1b061ffd7fae0b6537c1d49704dd0e57 Mon Sep 17 00:00:00 2001 From: Craig Edwards Date: Tue, 3 Dec 2024 13:43:01 +0000 Subject: [PATCH] fix: locking primitives --- include/dpp/discordclient.h | 11 +++++++-- include/dpp/sslclient.h | 3 +++ src/dpp/discordclient.cpp | 45 +++++++++++++------------------------ src/dpp/sslclient.cpp | 2 ++ src/soaktest/soak.cpp | 19 ++++++++++------ 5 files changed, 41 insertions(+), 39 deletions(-) diff --git a/include/dpp/discordclient.h b/include/dpp/discordclient.h index c5427e5f68..116b80f354 100644 --- a/include/dpp/discordclient.h +++ b/include/dpp/discordclient.h @@ -31,7 +31,9 @@ #include #include #include +#include #include +#include #include #include @@ -270,6 +272,11 @@ class DPP_EXPORT discord_client : public websocket_client */ std::shared_mutex queue_mutex; + /** + * @brief Mutex for zlib pointer + */ + std::mutex zlib_mutex; + /** * @brief Queue of outbound messages */ @@ -283,7 +290,7 @@ class DPP_EXPORT discord_client : public websocket_client /** * @brief ZLib decompression buffer */ - unsigned char* decomp_buffer; + std::vector decomp_buffer; /** * @brief Decompressed string @@ -316,7 +323,7 @@ class DPP_EXPORT discord_client : public websocket_client /** * @brief ETF parser for when in ws_etf mode */ - class etf_parser* etf; + std::unique_ptr etf; /** * @brief Convert a JSON object to string. diff --git a/include/dpp/sslclient.h b/include/dpp/sslclient.h index 552c677361..3d3c5a8566 100644 --- a/include/dpp/sslclient.h +++ b/include/dpp/sslclient.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -90,6 +91,8 @@ class DPP_EXPORT ssl_client */ void cleanup(); + std::mutex ssl_mutex; + /** * @brief Start offset into internal ring buffer for client to server IO */ diff --git a/src/dpp/discordclient.cpp b/src/dpp/discordclient.cpp index 3948132ab9..0d9bb6ac3b 100644 --- a/src/dpp/discordclient.cpp +++ b/src/dpp/discordclient.cpp @@ -66,7 +66,6 @@ thread_local static std::string last_ping_message; discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint32_t _max_shards, const std::string &_token, uint32_t _intents, bool comp, websocket_protocol_t ws_proto) : websocket_client(_cluster, _cluster->default_gateway, "443", comp ? (ws_proto == ws_json ? PATH_COMPRESSED_JSON : PATH_COMPRESSED_ETF) : (ws_proto == ws_json ? PATH_UNCOMPRESSED_JSON : PATH_UNCOMPRESSED_ETF)), compressed(comp), - decomp_buffer(nullptr), zlib(nullptr), decompressed_total(0), connect_time(0), @@ -86,41 +85,23 @@ discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint3 ready(false), last_heartbeat_ack(time(nullptr)), protocol(ws_proto), - resume_gateway_url(_cluster->default_gateway) + resume_gateway_url(_cluster->default_gateway) { + etf = std::make_unique(etf_parser()); start_connecting(); } void discord_client::start_connecting() { - try { - zlib = new zlibcontext(); - etf = new etf_parser(); - } - catch (std::bad_alloc&) { - cleanup(); - /* Clean up and rethrow to caller */ - throw std::bad_alloc(); - } - try { - this->connect(); - } - catch (std::exception&) { - cleanup(); - throw; - } + this->connect(); } void discord_client::cleanup() { - delete etf; - delete zlib; - etf = nullptr; - zlib = nullptr; } discord_client::~discord_client() { - cleanup(); + end_zlib(); } void discord_client::on_disconnect() @@ -137,7 +118,6 @@ void discord_client::on_disconnect() reconnect_timer = owner->start_timer([this](auto handle) { log(dpp::ll_debug, "Reconnecting shard " + std::to_string(shard_id) + " to wss://" + hostname + "..."); try { - cleanup(); if (timer_handle) { owner->stop_timer(timer_handle); timer_handle = 0; @@ -165,7 +145,11 @@ uint64_t discord_client::get_decompressed_bytes_in() void discord_client::setup_zlib() { + std::lock_guard lock(zlib_mutex); if (compressed) { + if (zlib == nullptr) { + zlib = new zlibcontext(); + } zlib->d_stream.zalloc = (alloc_func)0; zlib->d_stream.zfree = (free_func)0; zlib->d_stream.opaque = (voidpf)0; @@ -173,18 +157,19 @@ void discord_client::setup_zlib() if (error != Z_OK) { throw dpp::connection_exception((exception_error_code)error, "Can't initialise stream compression!"); } - this->decomp_buffer = new unsigned char[DECOMP_BUFFER_SIZE]; + decomp_buffer.resize(DECOMP_BUFFER_SIZE); } } void discord_client::end_zlib() { - if (compressed) { + std::lock_guard lock(zlib_mutex); + if (compressed && zlib) { inflateEnd(&(zlib->d_stream)); - delete[] this->decomp_buffer; - this->decomp_buffer = nullptr; } + delete zlib; + zlib = nullptr; } void discord_client::set_resume_hostname() @@ -214,7 +199,7 @@ bool discord_client::handle_frame(const std::string &buffer, ws_opcode opcode) zlib->d_stream.next_in = (Bytef *)buffer.c_str(); zlib->d_stream.avail_in = (uInt)buffer.size(); do { - zlib->d_stream.next_out = (Bytef*)decomp_buffer; + zlib->d_stream.next_out = (Bytef*)decomp_buffer.data(); zlib->d_stream.avail_out = DECOMP_BUFFER_SIZE; int ret = inflate(&(zlib->d_stream), Z_NO_FLUSH); int have = DECOMP_BUFFER_SIZE - zlib->d_stream.avail_out; @@ -234,7 +219,7 @@ bool discord_client::handle_frame(const std::string &buffer, ws_opcode opcode) this->close(); return true; case Z_OK: - this->decompressed.append((const char*)decomp_buffer, have); + this->decompressed.append((const char*)decomp_buffer.data(), have); this->decompressed_total += have; break; default: diff --git a/src/dpp/sslclient.cpp b/src/dpp/sslclient.cpp index dbe85c1060..cfef6e949b 100644 --- a/src/dpp/sslclient.cpp +++ b/src/dpp/sslclient.cpp @@ -402,6 +402,7 @@ void ssl_client::on_write(socket fd, const struct socket_events& e) { } if (!ssl->ssl) { /* Create SSL session */ + std::lock_guard lock(ssl_mutex); ssl->ssl = SSL_new(openssl_context.get()); if (ssl->ssl == nullptr) { throw dpp::connection_exception(err_ssl_new, "SSL_new failed!"); @@ -539,6 +540,7 @@ bool ssl_client::handle_buffer(std::string &buffer) void ssl_client::close() { if (!plaintext && ssl->ssl) { + std::lock_guard lock(ssl_mutex); SSL_free(ssl->ssl); ssl->ssl = nullptr; } diff --git a/src/soaktest/soak.cpp b/src/soaktest/soak.cpp index 679da011ae..fecebdc853 100644 --- a/src/soaktest/soak.cpp +++ b/src/soaktest/soak.cpp @@ -23,11 +23,13 @@ #include #include #include +#include #ifndef _WIN32 #include #endif dpp::cluster* s{nullptr}; +std::atomic_bool signalled{false}; int main() { using namespace std::chrono_literals; @@ -42,19 +44,22 @@ int main() { soak_test.start(dpp::st_return); #ifndef _WIN32 - signal(SIGINT, [](int sig) { - dpp::discord_client* dc = s->get_shard(0); - if (dc != nullptr) { - dc->close(); - } + signal(SIGUSR1, [](int sig) { + signalled = true; }); #endif while (true) { - std::this_thread::sleep_for(60s); + std::this_thread::sleep_for(1s); dpp::discord_client* dc = soak_test.get_shard(0); if (dc != nullptr) { - std::cout << "Websocket latency: " << std::fixed << dc->websocket_ping << " Guilds: " << dpp::get_guild_count() << " Users: " << dpp::get_user_count() << "\n"; + if (time(nullptr) % 60 == 0) { + std::cout << "Websocket latency: " << std::fixed << dc->websocket_ping << " Guilds: " << dpp::get_guild_count() << " Users: " << dpp::get_user_count() << "\n"; + } + if (signalled) { + signalled = false; + dc->close(); + } } } }