From 315046ce7b916d5de1c80bd773498a561f7f5dca Mon Sep 17 00:00:00 2001 From: Craig Edwards Date: Mon, 2 Dec 2024 11:11:48 +0000 Subject: [PATCH] tidy up shard startup --- include/dpp/cluster_coro_calls.h | 8 ++++---- include/dpp/discordclient.h | 5 ----- include/dpp/exception.h | 1 + mlspp/include/namespace.h | 2 +- src/dpp/cluster.cpp | 35 ++++++++++++++++++++++++-------- src/dpp/discordclient.cpp | 6 +----- 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/include/dpp/cluster_coro_calls.h b/include/dpp/cluster_coro_calls.h index 9077e9efa6..ffeefbc1c0 100644 --- a/include/dpp/cluster_coro_calls.h +++ b/include/dpp/cluster_coro_calls.h @@ -450,8 +450,8 @@ * @note This method supports audit log reasons set by the cluster::set_audit_reason() method. * @param c Channel to set permissions for * @param overwrite_id Overwrite to change (a user or role ID) - * @param allow allow permissions bitmask - * @param deny deny permissions bitmask + * @param allow Bitmask of allowed permissions (refer to enum dpp::permissions) + * @param deny Bitmask of denied permissions (refer to enum dpp::permissions) * @param member true if the overwrite_id is a user id, false if it is a channel id * @return confirmation returned object on completion * \memberof dpp::cluster @@ -466,8 +466,8 @@ * @note This method supports audit log reasons set by the cluster::set_audit_reason() method. * @param channel_id ID of the channel to set permissions for * @param overwrite_id Overwrite to change (a user or role ID) - * @param allow allow permissions bitmask - * @param deny deny permissions bitmask + * @param allow Bitmask of allowed permissions (refer to enum dpp::permissions) + * @param deny Bitmask of denied permissions (refer to enum dpp::permissions) * @param member true if the overwrite_id is a user id, false if it is a channel id * @return confirmation returned object on completion * \memberof dpp::cluster diff --git a/include/dpp/discordclient.h b/include/dpp/discordclient.h index bb7c4a0974..74a609fa0f 100644 --- a/include/dpp/discordclient.h +++ b/include/dpp/discordclient.h @@ -165,11 +165,6 @@ class DPP_EXPORT discord_client : public websocket_client */ friend class dpp::cluster; - /** - * @brief True if the shard is terminating - */ - bool terminating; - /** * @brief Disconnect from the connected voice channel on a guild * diff --git a/include/dpp/exception.h b/include/dpp/exception.h index 3d5f41ff08..9d6dd2d69d 100644 --- a/include/dpp/exception.h +++ b/include/dpp/exception.h @@ -98,6 +98,7 @@ enum exception_error_code { err_no_voice_support = 29, err_invalid_voice_packet_length = 30, err_opus = 31, + err_cant_start_shard = 32, err_etf = 33, err_cache = 34, err_icon_size = 35, diff --git a/mlspp/include/namespace.h b/mlspp/include/namespace.h index 43a5121ccb..d07ba5ee94 100755 --- a/mlspp/include/namespace.h +++ b/mlspp/include/namespace.h @@ -1,4 +1,4 @@ #pragma once // Configurable top-level MLS namespace -#define MLS_NAMESPACE mls +#define MLS_NAMESPACE ../include/dpp/mlspp/mls diff --git a/src/dpp/cluster.cpp b/src/dpp/cluster.cpp index 21947d0679..7d4be31810 100644 --- a/src/dpp/cluster.cpp +++ b/src/dpp/cluster.cpp @@ -223,25 +223,39 @@ void cluster::start(start_type return_after) { } /* Start up all shards */ - get_gateway_bot([this](const auto& response) { + get_gateway_bot([this, return_after](const auto& response) { + + auto throw_if_not_threaded = [this, return_after](exception_error_code error_id, const std::string& msg) { + log(ll_critical, msg); + if (return_after == st_wait) { + throw dpp::connection_exception(error_id, msg); + } + }; + if (response.is_error()) { - // TODO: Check for 401 unauthorized - // throw dpp::invalid_token_exception(err_unauthorized, "Invalid bot token (401: Unauthorized when getting gateway shard count)"); + if (response.http_info.status == 401) { + throw_if_not_threaded(err_unauthorized, "Invalid bot token (401: Unauthorized when getting gateway shard count)"); + } else { + throw_if_not_threaded(err_auto_shard, "get_gateway_bot: " + response.http_info.body); + } return; } auto g = std::get(response.value); log(ll_debug, "Cluster: " + std::to_string(g.session_start_remaining) + " of " + std::to_string(g.session_start_total) + " session starts remaining"); if (g.session_start_remaining < g.shards || g.shards == 0) { - throw dpp::connection_exception(err_no_sessions_left, "Discord indicates you cannot start enough sessions to boot this cluster! Cluster startup aborted. Try again later."); + throw_if_not_threaded(err_no_sessions_left, "Discord indicates you cannot start enough sessions to boot this cluster! Cluster startup aborted. Try again later."); + return; } else if (g. session_start_max_concurrency == 0) { - throw dpp::connection_exception(err_auto_shard, "Cluster: Could not determine concurrency, startup aborted!"); + throw_if_not_threaded(err_auto_shard, "Cluster: Could not determine concurrency, startup aborted!"); + return; } else if (g.session_start_max_concurrency > 1) { log(ll_debug, "Cluster: Large bot sharding; Using session concurrency: " + std::to_string(g.session_start_max_concurrency)); } else if (numshards == 0) { if (g.shards) { log(ll_info, "Auto Shard: Bot requires " + std::to_string(g.shards) + std::string(" shard") + ((g.shards > 1) ? "s" : "")); } else { - throw dpp::connection_exception(err_auto_shard, "Auto Shard: Cannot determine number of shards. Cluster startup aborted. Check your connection."); + throw_if_not_threaded(err_auto_shard, "Auto Shard: Cannot determine number of shards. Cluster startup aborted. Check your connection."); + return; } numshards = g.shards; } @@ -257,7 +271,8 @@ void cluster::start(start_type return_after) { this->shards[s]->run(); } catch (const std::exception &e) { - log(dpp::ll_critical, "Could not start shard " + std::to_string(s) + ": " + std::string(e.what())); + throw_if_not_threaded(err_cant_start_shard, "Could not start shard " + std::to_string(s) + ": " + std::string(e.what())); + return; } /* Stagger the shard startups, pausing every 'session_start_max_concurrency' shards for 5 seconds. * This means that for bots that don't have large bot sharding, any number % 1 is always 0, @@ -287,6 +302,10 @@ void cluster::start(start_type return_after) { /* Get all active DM channels and map them to user id -> dm id */ current_user_get_dms([this](const dpp::confirmation_callback_t& completion) { + if (completion.is_error()) { + log(dpp::ll_debug, "Failed to get bot DM list"); + return; + } dpp::channel_map dmchannels = std::get(completion.value); for (auto & c : dmchannels) { for (auto & u : c.second.recipients) { @@ -298,7 +317,7 @@ void cluster::start(start_type return_after) { log(ll_debug, "Shards started."); }); - if (return_after) { + if (return_after == st_return) { engine_thread = std::thread([event_loop]() { dpp::utility::set_thread_name("event_loop"); event_loop(); diff --git a/src/dpp/discordclient.cpp b/src/dpp/discordclient.cpp index c698c8cb7f..37d703d59e 100644 --- a/src/dpp/discordclient.cpp +++ b/src/dpp/discordclient.cpp @@ -65,7 +65,6 @@ thread_local static std::string last_ping_message; discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint32_t _max_shards, const std::string &_token, uint32_t _intents, bool comp, websocket_protocol_t ws_proto) : websocket_client(_cluster, _cluster->default_gateway, "443", comp ? (ws_proto == ws_json ? PATH_COMPRESSED_JSON : PATH_COMPRESSED_ETF) : (ws_proto == ws_json ? PATH_UNCOMPRESSED_JSON : PATH_UNCOMPRESSED_ETF)), - terminating(false), compressed(comp), decomp_buffer(nullptr), zlib(nullptr), @@ -98,8 +97,7 @@ void discord_client::start_connecting() { etf = new etf_parser(); } catch (std::bad_alloc&) { - delete zlib; - delete etf; + cleanup(); /* Clean up and rethrow to caller */ throw std::bad_alloc(); } @@ -114,7 +112,6 @@ void discord_client::start_connecting() { void discord_client::cleanup() { - terminating = true; delete etf; delete zlib; } @@ -134,7 +131,6 @@ void discord_client::on_disconnect() log(dpp::ll_debug, "Reconnecting shard " + std::to_string(shard_id) + " to wss://" + hostname + "..."); owner->stop_timer(handle); cleanup(); - terminating = false; if (timer_handle) { owner->stop_timer(timer_handle); timer_handle = 0;