diff --git a/include/dpp/discordvoiceclient.h b/include/dpp/discordvoiceclient.h index ad0c5e93fa..b902208d4c 100644 --- a/include/dpp/discordvoiceclient.h +++ b/include/dpp/discordvoiceclient.h @@ -39,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -267,16 +268,6 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ std::deque message_queue; - /** - * @brief Thread this connection is executing on - */ - std::thread* runner; - - /** - * @brief Run shard loop under a thread - */ - void thread_run(); - /** * @brief Last connect time of voice session */ @@ -446,6 +437,16 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ bool sent_stop_frames; + /** + * @brief Number of times we have tried to reconnect in the last few seconds + */ + size_t times_looped{0}; + + /** + * @brief Last time we reconnected + */ + time_t last_loop_time{0}; + #ifdef HAVE_VOICE /** * @brief libopus encoder @@ -623,25 +624,7 @@ class DPP_EXPORT discord_voice_client : public websocket_client int udp_recv(char* data, size_t max_length); /** - * @brief This hooks the ssl_client, returning the file - * descriptor if we want to send buffered data, or - * -1 if there is nothing to send - * - * @return int file descriptor or -1 - */ - dpp::socket want_write(); - - /** - * @brief This hooks the ssl_client, returning the file - * descriptor if we want to receive buffered data, or - * -1 if we are not wanting to receive - * - * @return int file descriptor or -1 - */ - dpp::socket want_read(); - - /** - * @brief Called by ssl_client when the socket is ready + * @brief Called by socketengine when the socket is ready * for writing, at this point we pick the head item off * the buffer and send it. So long as it doesn't error * completely, we pop it off the head of the queue. @@ -649,7 +632,7 @@ class DPP_EXPORT discord_voice_client : public websocket_client void write_ready(); /** - * @brief Called by ssl_client when there is data to be + * @brief Called by socketengine when there is data to be * read. At this point we insert that data into the * input queue. * @throw dpp::voice_exception if voice support is not compiled into D++ @@ -710,6 +693,16 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ void update_ratchets(bool force = false); + /** + * @brief Called in constructor and on reconnection of websocket + */ + void setup(); + + /** + * @brief Events for UDP Socket IO + */ + dpp::socket_events udp_events; + public: /** @@ -747,11 +740,6 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ time_t last_heartbeat; - /** - * @brief Thread ID - */ - std::thread::native_handle_type thread_id; - /** * @brief Discord voice session token */ @@ -1269,6 +1257,11 @@ class DPP_EXPORT discord_voice_client : public websocket_client * @param rmap Roster map */ void process_mls_group_rosters(const std::map>& rmap); + + /** + * @brief Called on websocket disconnection + */ + void on_disconnect(); }; } diff --git a/include/dpp/socketengine.h b/include/dpp/socketengine.h index 7d2d20e164..18a8b2d98b 100644 --- a/include/dpp/socketengine.h +++ b/include/dpp/socketengine.h @@ -129,6 +129,10 @@ struct DPP_EXPORT socket_events { socket_events(dpp::socket socket_fd, uint8_t _flags, const socket_read_event& read_event, const socket_write_event& write_event = {}, const socket_error_event& error_event = {}) : fd(socket_fd), flags(_flags), on_read(read_event), on_write(write_event), on_error(error_event) { } + /** + * @brief Default constructor + */ + socket_events() = default; }; /** diff --git a/src/davetest/dave.cpp b/src/davetest/dave.cpp index 94812e0d89..72bac9231b 100644 --- a/src/davetest/dave.cpp +++ b/src/davetest/dave.cpp @@ -80,7 +80,7 @@ int main() { dave_test.on_guild_create([&](const dpp::guild_create_t & event) { if (event.created->id == TEST_GUILD_ID) { dpp::discord_client* s = dave_test.get_shard(0); - bool muted = false, deaf = false, enable_dave = true; + bool muted = false, deaf = false, enable_dave = false; s->connect_voice(TEST_GUILD_ID, TEST_VC_ID, muted, deaf, enable_dave); } }); diff --git a/src/dpp/voice/enabled/cleanup.cpp b/src/dpp/voice/enabled/cleanup.cpp index 5ae1d9c7e2..6d7b594b1d 100644 --- a/src/dpp/voice/enabled/cleanup.cpp +++ b/src/dpp/voice/enabled/cleanup.cpp @@ -33,12 +33,6 @@ namespace dpp { void discord_voice_client::cleanup() { - if (runner) { - this->terminating = true; - runner->join(); - delete runner; - runner = nullptr; - } if (encoder) { opus_encoder_destroy(encoder); encoder = nullptr; @@ -55,6 +49,9 @@ void discord_voice_client::cleanup() voice_courier_shared_state.signal_iteration.notify_one(); voice_courier.join(); } + if (fd != INVALID_SOCKET) { + owner->socketengine->delete_socket(fd); + } } } diff --git a/src/dpp/voice/enabled/constructor.cpp b/src/dpp/voice/enabled/constructor.cpp index 939648fc9a..54f33753ae 100644 --- a/src/dpp/voice/enabled/constructor.cpp +++ b/src/dpp/voice/enabled/constructor.cpp @@ -34,13 +34,14 @@ namespace dpp { discord_voice_client::discord_voice_client(dpp::cluster* _cluster, snowflake _channel_id, snowflake _server_id, const std::string &_token, const std::string &_session_id, const std::string &_host, bool enable_dave) : websocket_client(_cluster, _host.substr(0, _host.find(':')), _host.substr(_host.find(':') + 1, _host.length()), "/?v=" + std::to_string(voice_protocol_version), OP_TEXT), - runner(nullptr), connect_time(0), mixer(std::make_unique()), port(0), ssrc(0), timescale(1000000), paused(false), + sent_stop_frames(false), + last_loop_time(time(nullptr)), encoder(nullptr), repacketizer(nullptr), fd(INVALID_SOCKET), @@ -60,6 +61,11 @@ discord_voice_client::discord_voice_client(dpp::cluster* _cluster, snowflake _ch sessionid(_session_id), server_id(_server_id), channel_id(_channel_id) +{ + setup(); +} + +void discord_voice_client::setup() { int opusError = 0; encoder = opus_encoder_create(opus_sample_rate_hz, opus_channel_count, OPUS_APPLICATION_VOIP, &opusError); diff --git a/src/dpp/voice/enabled/handle_frame.cpp b/src/dpp/voice/enabled/handle_frame.cpp index bfd85da758..4521f7a375 100644 --- a/src/dpp/voice/enabled/handle_frame.cpp +++ b/src/dpp/voice/enabled/handle_frame.cpp @@ -461,10 +461,17 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod /* Hook poll() in the ssl_client to add a new file descriptor */ this->fd = newfd; - this->custom_writeable_fd = [this] { return want_write(); }; - this->custom_readable_fd = [this] { return want_read(); }; - this->custom_writeable_ready = [this] { write_ready(); }; - this->custom_readable_ready = [this] { read_ready(); }; + + udp_events = dpp::socket_events( + fd, + WANT_READ | WANT_WRITE | WANT_ERROR, + [this](socket fd, const struct socket_events &e) { read_ready(); }, + [this](socket fd, const struct socket_events &e) { write_ready(); }, + [this](socket fd, const struct socket_events &e, int error_code) { + this->close(); + } + ); + owner->socketengine->register_socket(udp_events); int bound_port = address_t().get_port(this->fd); this->write(json({ diff --git a/src/dpp/voice/enabled/read_write.cpp b/src/dpp/voice/enabled/read_write.cpp index 52a09d5a39..e131caf784 100644 --- a/src/dpp/voice/enabled/read_write.cpp +++ b/src/dpp/voice/enabled/read_write.cpp @@ -28,20 +28,6 @@ namespace dpp { -dpp::socket discord_voice_client::want_write() { - std::lock_guard lock(this->stream_mutex); - if (!this->sent_stop_frames && !outbuf.empty()) { - return fd; - } - return INVALID_SOCKET; - -} - -dpp::socket discord_voice_client::want_read() { - return fd; -} - - void discord_voice_client::send(const char* packet, size_t len, uint64_t duration, bool send_now) { if (!send_now) [[likely]] { voice_out_packet frame; @@ -50,6 +36,10 @@ void discord_voice_client::send(const char* packet, size_t len, uint64_t duratio std::lock_guard lock(this->stream_mutex); outbuf.emplace_back(frame); + if (!this->sent_stop_frames) { + udp_events.flags = WANT_READ | WANT_WRITE | WANT_ERROR; + owner->socketengine->update_socket(udp_events); + } } else [[unlikely]] { this->udp_send(packet, len); } diff --git a/src/dpp/voice/enabled/thread.cpp b/src/dpp/voice/enabled/thread.cpp index 655e0b6a64..2d5372977a 100644 --- a/src/dpp/voice/enabled/thread.cpp +++ b/src/dpp/voice/enabled/thread.cpp @@ -20,7 +20,6 @@ * ************************************************************************************/ -#include #include #include #include @@ -29,60 +28,42 @@ namespace dpp { -void discord_voice_client::thread_run() -{ - utility::set_thread_name(std::string("vc/") + std::to_string(server_id)); +void discord_voice_client::on_disconnect() { - size_t times_looped = 0; - time_t last_loop_time = time(nullptr); + time_t current_time = time(nullptr); - do { - bool error = false; - ssl_client::read_loop(); - ssl_client::close(); + /* Here, we check if it's been longer than 3 seconds since the previous loop, + * this gives us time to see if it's an actual disconnect, or an error. + * This will prevent us from looping too much, meaning error codes do not cause an infinite loop. + */ + if (current_time - last_loop_time >= 3) { + times_looped = 0; + } - time_t current_time = time(nullptr); - /* Here, we check if it's been longer than 3 seconds since the previous loop, - * this gives us time to see if it's an actual disconnect, or an error. - * This will prevent us from looping too much, meaning error codes do not cause an infinite loop. - */ - if (current_time - last_loop_time >= 3) { - times_looped = 0; - } + /* This does mean we'll always have times_looped at a minimum of 1, this is intended. */ + times_looped++; - /* This does mean we'll always have times_looped at a minimum of 1, this is intended. */ - times_looped++; - /* If we've looped 5 or more times, abort the loop. */ - if (times_looped >= 5) { - log(dpp::ll_warning, "Reached max loops whilst attempting to read from the websocket. Aborting websocket."); - break; - } + /* If we've looped 5 or more times, abort the loop. */ + if (terminating || times_looped >= 5) { + log(dpp::ll_warning, "Reached max loops whilst attempting to read from the websocket. Aborting websocket."); + return; + } + last_loop_time = current_time; - last_loop_time = current_time; - - if (!terminating) { - log(dpp::ll_debug, "Attempting to reconnect the websocket..."); - do { - try { - ssl_client::connect(); - websocket_client::connect(); - } - catch (const std::exception &e) { - log(dpp::ll_error, std::string("Error establishing voice websocket connection, retry in 5 seconds: ") + e.what()); - ssl_client::close(); - std::this_thread::sleep_for(std::chrono::seconds(5)); - error = true; - } - } while (error && !terminating); - } - } while(!terminating); + log(dpp::ll_debug, "Attempting to reconnect the websocket..."); + owner->start_timer([this](auto handle) { + owner->stop_timer(handle); + cleanup(); + setup(); + terminating = false; + ssl_client::connect(); + websocket_client::connect(); + run(); + }, 1); } -void discord_voice_client::run() -{ - this->runner = new std::thread(&discord_voice_client::thread_run, this); - this->thread_id = runner->native_handle(); +void discord_voice_client::run() { + ssl_client::read_loop(); } - -} +} \ No newline at end of file diff --git a/src/dpp/voice/enabled/write_ready.cpp b/src/dpp/voice/enabled/write_ready.cpp index ade6b20716..385388c356 100644 --- a/src/dpp/voice/enabled/write_ready.cpp +++ b/src/dpp/voice/enabled/write_ready.cpp @@ -59,6 +59,10 @@ void discord_voice_client::write_ready() { bufsize = outbuf[0].packet.length(); outbuf.erase(outbuf.begin()); } + if (!outbuf.empty()) { + udp_events.flags = WANT_READ | WANT_WRITE | WANT_ERROR; + owner->socketengine->update_socket(udp_events); + } } } } @@ -123,5 +127,4 @@ void discord_voice_client::write_ready() { } } - } diff --git a/src/dpp/voice/stub/stubs.cpp b/src/dpp/voice/stub/stubs.cpp index 52bb2d5fe2..226b8643db 100644 --- a/src/dpp/voice/stub/stubs.cpp +++ b/src/dpp/voice/stub/stubs.cpp @@ -36,15 +36,12 @@ namespace dpp { void discord_voice_client::voice_courier_loop(discord_voice_client& client, courier_shared_state_t& shared_state) { } - void discord_voice_client::cleanup(){ + void discord_voice_client::cleanup() { } void discord_voice_client::run() { } - void discord_voice_client::thread_run() { - } - bool discord_voice_client::voice_payload::operator<(const voice_payload& other) const { return false; } @@ -71,15 +68,6 @@ namespace dpp { return *this; } - dpp::socket discord_voice_client::want_write() { - return INVALID_SOCKET; - } - - dpp::socket discord_voice_client::want_read() { - return INVALID_SOCKET; - } - - void discord_voice_client::send(const char* packet, size_t len, uint64_t duration, bool send_now) { } @@ -99,4 +87,7 @@ namespace dpp { return ""; } + void discord_voice_client::setup() { + } + }