Skip to content

Commit

Permalink
fix: locking primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Dec 3, 2024
1 parent f584b2f commit 7ae853e
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 39 deletions.
11 changes: 9 additions & 2 deletions include/dpp/discordclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
#include <dpp/event.h>
#include <queue>
#include <thread>
#include <memory>
#include <deque>
#include <dpp/etf.h>
#include <mutex>
#include <shared_mutex>

Expand Down Expand Up @@ -270,6 +272,11 @@ class DPP_EXPORT discord_client : public websocket_client
*/
std::shared_mutex queue_mutex;

/**
* @brief Mutex for zlib pointer
*/
std::mutex zlib_mutex;

/**
* @brief Queue of outbound messages
*/
Expand All @@ -283,7 +290,7 @@ class DPP_EXPORT discord_client : public websocket_client
/**
* @brief ZLib decompression buffer
*/
unsigned char* decomp_buffer;
std::vector<unsigned char*> decomp_buffer;

/**
* @brief Decompressed string
Expand Down Expand Up @@ -316,7 +323,7 @@ class DPP_EXPORT discord_client : public websocket_client
/**
* @brief ETF parser for when in ws_etf mode
*/
class etf_parser* etf;
std::unique_ptr<etf_parser> etf;

/**
* @brief Convert a JSON object to string.
Expand Down
3 changes: 3 additions & 0 deletions include/dpp/sslclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <string>
#include <functional>
#include <ctime>
#include <mutex>
#include <dpp/socket.h>
#include <cstdint>
#include <dpp/timer.h>
Expand Down Expand Up @@ -90,6 +91,8 @@ class DPP_EXPORT ssl_client
*/
void cleanup();

std::mutex ssl_mutex;

/**
* @brief Start offset into internal ring buffer for client to server IO
*/
Expand Down
45 changes: 15 additions & 30 deletions src/dpp/discordclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ thread_local static std::string last_ping_message;
discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint32_t _max_shards, const std::string &_token, uint32_t _intents, bool comp, websocket_protocol_t ws_proto)
: websocket_client(_cluster, _cluster->default_gateway, "443", comp ? (ws_proto == ws_json ? PATH_COMPRESSED_JSON : PATH_COMPRESSED_ETF) : (ws_proto == ws_json ? PATH_UNCOMPRESSED_JSON : PATH_UNCOMPRESSED_ETF)),
compressed(comp),
decomp_buffer(nullptr),
zlib(nullptr),
decompressed_total(0),
connect_time(0),
Expand All @@ -86,41 +85,23 @@ discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint3
ready(false),
last_heartbeat_ack(time(nullptr)),
protocol(ws_proto),
resume_gateway_url(_cluster->default_gateway)
resume_gateway_url(_cluster->default_gateway)
{
etf = std::make_unique<etf_parser>(etf_parser());
start_connecting();
}

void discord_client::start_connecting() {
try {
zlib = new zlibcontext();
etf = new etf_parser();
}
catch (std::bad_alloc&) {
cleanup();
/* Clean up and rethrow to caller */
throw std::bad_alloc();
}
try {
this->connect();
}
catch (std::exception&) {
cleanup();
throw;
}
this->connect();
}

void discord_client::cleanup()
{
delete etf;
delete zlib;
etf = nullptr;
zlib = nullptr;
}

discord_client::~discord_client()
{
cleanup();
end_zlib();
}

void discord_client::on_disconnect()
Expand All @@ -137,7 +118,6 @@ void discord_client::on_disconnect()
reconnect_timer = owner->start_timer([this](auto handle) {
log(dpp::ll_debug, "Reconnecting shard " + std::to_string(shard_id) + " to wss://" + hostname + "...");
try {
cleanup();
if (timer_handle) {
owner->stop_timer(timer_handle);
timer_handle = 0;
Expand Down Expand Up @@ -165,26 +145,31 @@ uint64_t discord_client::get_decompressed_bytes_in()

void discord_client::setup_zlib()
{
std::lock_guard<std::mutex> lock(zlib_mutex);
if (compressed) {
if (zlib == nullptr) {
zlib = new zlibcontext();
}
zlib->d_stream.zalloc = (alloc_func)0;
zlib->d_stream.zfree = (free_func)0;
zlib->d_stream.opaque = (voidpf)0;
int error = inflateInit(&(zlib->d_stream));
if (error != Z_OK) {
throw dpp::connection_exception((exception_error_code)error, "Can't initialise stream compression!");
}
this->decomp_buffer = new unsigned char[DECOMP_BUFFER_SIZE];
decomp_buffer.resize(DECOMP_BUFFER_SIZE);
}

}

void discord_client::end_zlib()
{
if (compressed) {
std::lock_guard<std::mutex> lock(zlib_mutex);
if (compressed && zlib) {
inflateEnd(&(zlib->d_stream));
delete[] this->decomp_buffer;
this->decomp_buffer = nullptr;
}
delete zlib;
zlib = nullptr;
}

void discord_client::set_resume_hostname()
Expand Down Expand Up @@ -214,7 +199,7 @@ bool discord_client::handle_frame(const std::string &buffer, ws_opcode opcode)
zlib->d_stream.next_in = (Bytef *)buffer.c_str();
zlib->d_stream.avail_in = (uInt)buffer.size();
do {
zlib->d_stream.next_out = (Bytef*)decomp_buffer;
zlib->d_stream.next_out = (Bytef*)decomp_buffer.data();
zlib->d_stream.avail_out = DECOMP_BUFFER_SIZE;
int ret = inflate(&(zlib->d_stream), Z_NO_FLUSH);
int have = DECOMP_BUFFER_SIZE - zlib->d_stream.avail_out;
Expand All @@ -234,7 +219,7 @@ bool discord_client::handle_frame(const std::string &buffer, ws_opcode opcode)
this->close();
return true;
case Z_OK:
this->decompressed.append((const char*)decomp_buffer, have);
this->decompressed.append((const char*)decomp_buffer.data(), have);
this->decompressed_total += have;
break;
default:
Expand Down
2 changes: 2 additions & 0 deletions src/dpp/sslclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ void ssl_client::on_write(socket fd, const struct socket_events& e) {
}
if (!ssl->ssl) {
/* Create SSL session */
std::lock_guard<std::mutex> lock(ssl_mutex);
ssl->ssl = SSL_new(openssl_context.get());
if (ssl->ssl == nullptr) {
throw dpp::connection_exception(err_ssl_new, "SSL_new failed!");
Expand Down Expand Up @@ -539,6 +540,7 @@ bool ssl_client::handle_buffer(std::string &buffer)
void ssl_client::close()
{
if (!plaintext && ssl->ssl) {
std::lock_guard<std::mutex> lock(ssl_mutex);
SSL_free(ssl->ssl);
ssl->ssl = nullptr;
}
Expand Down
19 changes: 12 additions & 7 deletions src/soaktest/soak.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
#include <dpp/dpp.h>
#include <iostream>
#include <thread>
#include <atomic>
#ifndef _WIN32
#include <csignal>
#endif

dpp::cluster* s{nullptr};
std::atomic_bool signalled{false};

int main() {
using namespace std::chrono_literals;
Expand All @@ -42,19 +44,22 @@ int main() {
soak_test.start(dpp::st_return);

#ifndef _WIN32
signal(SIGINT, [](int sig) {
dpp::discord_client* dc = s->get_shard(0);
if (dc != nullptr) {
dc->close();
}
signal(SIGUSR1, [](int sig) {
signalled = true;
});
#endif

while (true) {
std::this_thread::sleep_for(60s);
std::this_thread::sleep_for(1s);
dpp::discord_client* dc = soak_test.get_shard(0);
if (dc != nullptr) {
std::cout << "Websocket latency: " << std::fixed << dc->websocket_ping << " Guilds: " << dpp::get_guild_count() << " Users: " << dpp::get_user_count() << "\n";
if (time(nullptr) % 60 == 0) {
std::cout << "Websocket latency: " << std::fixed << dc->websocket_ping << " Guilds: " << dpp::get_guild_count() << " Users: " << dpp::get_user_count() << "\n";
}
if (signalled) {
signalled = false;
dc->close();
}
}
}
}
Expand Down

0 comments on commit 7ae853e

Please sign in to comment.