From f80a97510a7f9d3a3b3822e6495de56975ef34c5 Mon Sep 17 00:00:00 2001 From: Craig Edwards Date: Thu, 21 Nov 2024 00:55:19 +0000 Subject: [PATCH] properly mutex everything --- doxygen-awesome-css | 2 +- include/dpp/socketengine.h | 7 +++ src/dpp/socketengine.cpp | 4 ++ src/dpp/socketengines/poll.cpp | 102 ++++++++++++++++++--------------- src/unittest/test.cpp | 75 ++++++++++++------------ src/unittest/test.h | 2 + src/unittest/unittest.cpp | 6 +- 7 files changed, 113 insertions(+), 85 deletions(-) diff --git a/doxygen-awesome-css b/doxygen-awesome-css index af1d9030b3..c6568ebc70 160000 --- a/doxygen-awesome-css +++ b/doxygen-awesome-css @@ -1 +1 @@ -Subproject commit af1d9030b3ffa7b483fa9997a7272fb12af6af4c +Subproject commit c6568ebc70adf9fb0fb6c1745737ae6945576813 diff --git a/include/dpp/socketengine.h b/include/dpp/socketengine.h index c7040845c4..7d2d20e164 100644 --- a/include/dpp/socketengine.h +++ b/include/dpp/socketengine.h @@ -25,6 +25,7 @@ #include #include #include +#include #include namespace dpp { @@ -144,6 +145,12 @@ using socket_container = std::unordered_map(e)); return true; @@ -39,6 +40,7 @@ bool socket_engine_base::register_socket(const socket_events &e) { } bool socket_engine_base::update_socket(const socket_events &e) { + std::unique_lock lock(fds_mutex); if (e.fd != INVALID_SOCKET && fds.find(e.fd) != fds.end()) { auto iter = fds.find(e.fd); *(iter->second) = e; @@ -68,6 +70,7 @@ time_t last_time = time(nullptr); void socket_engine_base::prune() { if (to_delete_count > 0) { + std::unique_lock lock(fds_mutex); for (auto it = fds.cbegin(); it != fds.cend();) { if ((it->second->flags & WANT_DELETION) != 0L) { remove_socket(it->second->fd); @@ -98,6 +101,7 @@ void socket_engine_base::prune() { } bool socket_engine_base::delete_socket(dpp::socket fd) { + std::unique_lock lock(fds_mutex); auto iter = fds.find(fd); if (iter == fds.end() || ((iter->second->flags & WANT_DELETION) != 0L)) { return false; diff --git a/src/dpp/socketengines/poll.cpp b/src/dpp/socketengines/poll.cpp index d2e0da1a82..dd89c6efda 100644 --- a/src/dpp/socketengines/poll.cpp +++ b/src/dpp/socketengines/poll.cpp @@ -21,6 +21,7 @@ #include #include +#include #ifdef _WIN32 /* Windows-specific sockets includes */ #include @@ -54,70 +55,78 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { */ std::vector poll_set; pollfd out_set[FD_SETSIZE]{0}; + std::shared_mutex poll_set_mutex; void process_events() final { const int poll_delay = 1000; - if (poll_set.empty()) { - /* On many platforms, it is not possible to wait on an empty set */ - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } else { - if (poll_set.size() > FD_SETSIZE) { - throw dpp::connection_exception("poll() does not support more than FD_SETSIZE active sockets at once!"); + prune(); + { + std::shared_lock lock(poll_set_mutex); + if (poll_set.empty()) { + /* On many platforms, it is not possible to wait on an empty set */ + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + return; + } else { + if (poll_set.size() > FD_SETSIZE) { + throw dpp::connection_exception("poll() does not support more than FD_SETSIZE active sockets at once!"); + } + /** + * We must make a copy of the poll_set, because it would cause thread locking/contention + * issues if we had it locked for read during poll/iteration of the returned set. + */ + std::copy(poll_set.begin(), poll_set.end(), out_set); } + } - std::copy(poll_set.begin(), poll_set.end(), out_set); + int i = poll(out_set, static_cast(poll_set.size()), poll_delay); + int processed = 0; - int i = poll(out_set, static_cast(poll_set.size()), poll_delay); - int processed = 0; + for (size_t index = 0; index < poll_set.size() && processed < i; index++) { + const int fd = out_set[index].fd; + const short revents = out_set[index].revents; - for (size_t index = 0; index < poll_set.size() && processed < i; index++) { - const int fd = out_set[index].fd; - const short revents = out_set[index].revents; + if (revents > 0) { + processed++; + } - if (revents > 0) { - processed++; - } + auto iter = fds.find(fd); + if (iter == fds.end()) { + continue; + } + socket_events *eh = iter->second.get(); + + try { - auto iter = fds.find(fd); - if (iter == fds.end()) { + if ((revents & POLLHUP) != 0) { + eh->on_error(fd, *eh, 0); continue; } - socket_events *eh = iter->second.get(); - - try { - - if ((revents & POLLHUP) != 0) { - eh->on_error(fd, *eh, 0); - continue; - } - - if ((revents & POLLERR) != 0) { - socklen_t codesize = sizeof(int); - int errcode{}; - if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *) &errcode, &codesize) < 0) { - errcode = errno; - } - eh->on_error(fd, *eh, errcode); - continue; - } - if ((revents & POLLIN) != 0) { - eh->on_read(fd, *eh); + if ((revents & POLLERR) != 0) { + socklen_t codesize = sizeof(int); + int errcode{}; + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *) &errcode, &codesize) < 0) { + errcode = errno; } + eh->on_error(fd, *eh, errcode); + continue; + } - if ((revents & POLLOUT) != 0) { - eh->flags &= ~WANT_WRITE; - update_socket(*eh); - eh->on_write(fd, *eh); - } + if ((revents & POLLIN) != 0) { + eh->on_read(fd, *eh); + } - } catch (const std::exception &e) { - eh->on_error(fd, *eh, 0); + if ((revents & POLLOUT) != 0) { + eh->flags &= ~WANT_WRITE; + update_socket(*eh); + eh->on_write(fd, *eh); } + + } catch (const std::exception &e) { + eh->on_error(fd, *eh, 0); } } - prune(); } #if _WIN32 @@ -129,6 +138,7 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { bool register_socket(const socket_events& e) final { bool r = socket_engine_base::register_socket(e); if (r) { + std::unique_lock lock(poll_set_mutex); pollfd fd_info{}; fd_info.fd = e.fd; fd_info.events = 0; @@ -146,6 +156,7 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { bool update_socket(const socket_events& e) final { bool r = socket_engine_base::update_socket(e); if (r) { + std::unique_lock lock(poll_set_mutex); /* We know this will succeed */ for (pollfd& fd_info : poll_set) { if (fd_info.fd != e.fd) { @@ -171,6 +182,7 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { bool remove_socket(dpp::socket fd) final { bool r = socket_engine_base::remove_socket(fd); if (r) { + std::unique_lock lock(poll_set_mutex); for (auto i = poll_set.begin(); i != poll_set.end(); ++i) { if (i->fd == fd) { poll_set.erase(i); diff --git a/src/unittest/test.cpp b/src/unittest/test.cpp index 19dd1970c9..d0119e7b16 100644 --- a/src/unittest/test.cpp +++ b/src/unittest/test.cpp @@ -26,6 +26,11 @@ #include #include +/** + * @brief global lock for log output + */ +std::mutex loglock; + /** * @brief Type trait to check if a certain type has a build_json method * @@ -926,6 +931,9 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b coro_offline_tests(); } + std::promise ready_promise; + std::future ready_future = ready_promise.get_future(); + std::vector dpp_logo = load_data("DPP-Logo.png"); set_test(PRESENCE, false); @@ -971,9 +979,6 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b /* This ensures we test both protocols, as voice is json and shard is etf */ bot.set_websocket_protocol(dpp::ws_etf); - bot.on_form_submit([&](const dpp::form_submit_t & event) { - }); - /* This is near impossible to test without a 'clean room' voice channel. * We attach this event just so that the decoder events are fired while we * are sending audio later, this way if the audio receive code is plain unstable @@ -982,7 +987,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b bot.on_voice_receive_combined([&](const auto& event) { }); - bot.on_guild_create([&](const dpp::guild_create_t& event) { + bot.on_guild_create([dpp_logo,&bot](const dpp::guild_create_t& event) { dpp::guild *g = event.created; if (g->id == TEST_GUILD_ID) { @@ -1016,9 +1021,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } }); - std::promise ready_promise; - std::future ready_future = ready_promise.get_future(); - bot.on_ready([&](const dpp::ready_t & event) { + bot.on_ready([&ready_promise,&bot,dpp_logo](const dpp::ready_t & event) { set_test(CONNECTION, true); ready_promise.set_value(); @@ -1076,10 +1079,9 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b }); }); - std::mutex loglock; - bot.on_log([&](const dpp::log_t & event) { - std::lock_guard locker(loglock); + bot.on_log([](const dpp::log_t & event) { if (event.severity > dpp::ll_trace) { + std::lock_guard locker(loglock); std::cout << "[" << std::fixed << std::setprecision(3) << (dpp::utility::time_f() - get_start_time()) << "]: [\u001b[36m" << dpp::utility::loglevel(event.severity) << "\u001b[0m] " << event.message << "\n"; } if (event.message == "Test log message") { @@ -1096,7 +1098,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } set_test(RUNONCE, (runs == 1)); - bot.on_voice_ready([&](const dpp::voice_ready_t & event) { + bot.on_voice_ready([&testaudio](const dpp::voice_ready_t & event) { set_test(VOICECONN, true); dpp::discord_voice_client* v = event.voice_client; set_test(VOICESEND, false); @@ -1121,7 +1123,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } }); - bot.on_voice_buffer_send([&](const dpp::voice_buffer_send_t & event) { + bot.on_voice_buffer_send([](const dpp::voice_buffer_send_t & event) { static bool sent_some_data = false; if (event.buffer_size > 0) { @@ -1133,7 +1135,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } }); - bot.on_guild_create([&](const dpp::guild_create_t & event) { + bot.on_guild_create([](const dpp::guild_create_t & event) { if (event.created->id == TEST_GUILD_ID) { set_test(GUILDCREATE, true); if (event.presences.size() && event.presences.begin()->second.user_id > 0) { @@ -1230,21 +1232,21 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b return false; } }; - message.attachments[0].download([&](const dpp::http_request_completion_t &callback) { + message.attachments[0].download([this](const dpp::http_request_completion_t &callback) { std::lock_guard lock(mutex); if (callback.status == 200 && callback.body == "test") { files_success[0] = true; } set_file_tested(0); }); - message.attachments[1].download([&](const dpp::http_request_completion_t &callback) { + message.attachments[1].download([this](const dpp::http_request_completion_t &callback) { std::lock_guard lock(mutex); if (callback.status == 200 && check_mimetype(callback.headers, "text/plain") && callback.body == "test") { files_success[1] = true; } set_file_tested(1); }); - message.attachments[2].download([&](const dpp::http_request_completion_t &callback) { + message.attachments[2].download([this](const dpp::http_request_completion_t &callback) { std::lock_guard lock(mutex); // do not check the contents here because discord can change compression if (callback.status == 200 && check_mimetype(callback.headers, "image/png")) { @@ -1594,7 +1596,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b thread_test_helper thread_helper(bot); - bot.on_thread_create([&](const dpp::thread_create_t &event) { + bot.on_thread_create([&thread_helper](const dpp::thread_create_t &event) { if (event.created.name == "thread test") { set_test(THREAD_CREATE_EVENT, true); thread_helper.run(event.created); @@ -1694,7 +1696,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } }); - bot.on_thread_update([&](const dpp::thread_update_t &event) { + bot.on_thread_update([&thread_helper](const dpp::thread_update_t &event) { if (event.updating_guild->id == TEST_GUILD_ID && event.updated.id == thread_helper.thread_id && event.updated.name == "edited") { set_test(THREAD_UPDATE_EVENT, true); } @@ -2281,13 +2283,17 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b set_test(BOTSTART, false); } + dpp::https_client *c{}; + dpp::https_client *c2{}; + dpp::https_client *c3{}; + set_test(HTTPS, false); if (!offline) { dpp::multipart_content multipart = dpp::https_client::build_multipart( "{\"content\":\"test\"}", {"test.txt", "blob.blob"}, {"ABCDEFGHI", "BLOB!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"}, {"text/plain", "application/octet-stream"} ); try { - dpp::https_client c(&bot, "discord.com", 443, "/api/channels/" + std::to_string(TEST_TEXT_CHANNEL_ID) + "/messages", "POST", multipart.body, + c = new dpp::https_client(&bot, "discord.com", 443, "/api/channels/" + std::to_string(TEST_TEXT_CHANNEL_ID) + "/messages", "POST", multipart.body, { {"Content-Type", multipart.mimetype}, {"Authorization", "Bot " + token} @@ -2297,49 +2303,43 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b set_test(HTTPS, hdr1 == "cloudflare" && c->get_status() == 200); } ); - std::this_thread::sleep_for(std::chrono::seconds(6)); } catch (const dpp::exception& e) { - std::cout << e.what() << "\n"; - set_test(HTTPS, false); + set_status(HTTPS, ts_failed, e.what()); } set_test(HTTP, false); try { - dpp::https_client c2(&bot, "github.com", 80, "/", "GET", "", {}, true, 5, "1.1", [](dpp::https_client* c2) { + c2 = new dpp::https_client(&bot, "github.com", 80, "/", "GET", "", {}, true, 5, "1.1", [](dpp::https_client *c2) { std::string hdr2 = c2->get_header("location"); std::string content2 = c2->get_content(); set_test(HTTP, hdr2 == "https://github.com/" && c2->get_status() == 301); }); - std::this_thread::sleep_for(std::chrono::seconds(6)); } catch (const dpp::exception& e) { - std::cout << e.what() << "\n"; - set_test(HTTP, false); + set_status(HTTP, ts_failed, e.what()); } set_test(MULTIHEADER, false); try { - dpp::https_client c2(&bot, "dl.dpp.dev", 443, "/cookietest.php", "GET", "", {}, true, 5, "1.1", [](dpp::https_client* c2) { + c3 = new dpp::https_client(&bot, "dl.dpp.dev", 443, "/cookietest.php", "GET", "", {}, true, 5, "1.1", [](dpp::https_client *c2) { size_t count = c2->get_header_count("set-cookie"); size_t count_list = c2->get_header_list("set-cookie").size(); // This test script sets a bunch of cookies when we request it. set_test(MULTIHEADER, c2->get_status() == 200 && count > 1 && count == count_list); }); - std::this_thread::sleep_for(std::chrono::seconds(6)); } catch (const dpp::exception& e) { - std::cout << e.what() << "\n"; - set_test(MULTIHEADER, false); + set_status(MULTIHEADER, ts_failed, e.what()); } } set_test(TIMERSTART, false); - uint32_t ticks = 0; - dpp::timer th = bot.start_timer([&](dpp::timer timer_handle) { - if (ticks == 5) { + static uint32_t ticks = 0; + dpp::timer th = bot.start_timer([](dpp::timer timer_handle) { + if (ticks == 2) { /* The simple test timer ticks every second. - * If we get to 5 seconds, we know the timer is working. + * If we get to 2 seconds, we know the timer is working. */ set_test(TIMERSTART, true); } @@ -2451,10 +2451,13 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b wait_for_tests(); + delete c; + delete c2; + delete c3; + } catch (const std::exception &e) { - std::cout << e.what() << "\n"; - set_test(CLUSTER, false); + set_status(CLUSTER, ts_failed, e.what()); } /* Return value = number of failed tests, exit code 0 = success */ diff --git a/src/unittest/test.h b/src/unittest/test.h index 56e8d2cbb7..cf9f966a28 100644 --- a/src/unittest/test.h +++ b/src/unittest/test.h @@ -39,6 +39,8 @@ _Pragma("warning( disable : 5105 )"); // 4251 warns when we export classes or st using json = nlohmann::json; +extern std::mutex loglock; + enum test_flags_t { tf_offline = 0, /* A test that requires discord connectivity */ diff --git a/src/unittest/unittest.cpp b/src/unittest/unittest.cpp index 65059ad3fc..ec58d888c2 100644 --- a/src/unittest/unittest.cpp +++ b/src/unittest/unittest.cpp @@ -38,8 +38,7 @@ test_t::test_t(std::string_view testname, std::string_view testdesc, int testfla } void set_status(test_t &test, test_status_t newstatus, std::string_view message) { - static std::mutex m; - std::scoped_lock lock{m}; + std::lock_guard locker(loglock); if (is_skipped(test) || newstatus == test.status) { return; @@ -99,6 +98,7 @@ double get_time() { int test_summary() { /* Report on all test cases */ + std::lock_guard locker(loglock); int failed = 0, passed = 0, skipped = 0; std::cout << "\u001b[37;1m\n\nUNIT TEST SUMMARY\n==================\n\u001b[0m"; for (auto & t : tests) { @@ -193,7 +193,7 @@ void wait_for_tests() { } } if (finished == tests.size()) { - std::this_thread::sleep_for(std::chrono::seconds(10)); + std::this_thread::sleep_for(std::chrono::seconds(5)); return; } std::this_thread::sleep_for(std::chrono::seconds(1));