From 93fc814ecd44c5a66690637342e4ef424621728a Mon Sep 17 00:00:00 2001 From: Craig Edwards Date: Wed, 20 Nov 2024 21:22:39 +0000 Subject: [PATCH] fixes for poll on windows --- doxygen-awesome-css | 2 +- include/dpp/dns.h | 6 +- include/dpp/socketengine.h | 6 +- include/dpp/sslclient.h | 4 +- include/dpp/thread_pool.h | 16 +++-- mlspp/include/namespace.h | 2 +- src/dpp/dns.cpp | 3 +- src/dpp/socketengine.cpp | 8 +-- src/dpp/socketengines/epoll.cpp | 4 +- src/dpp/socketengines/kqueue.cpp | 4 +- src/dpp/socketengines/poll.cpp | 92 ++++++++++++++------------- src/dpp/thread_pool.cpp | 12 ++-- src/dpp/voice/enabled/discover_ip.cpp | 4 ++ src/sockettest/socket.cpp | 20 +++++- 14 files changed, 106 insertions(+), 77 deletions(-) diff --git a/doxygen-awesome-css b/doxygen-awesome-css index af1d9030b3..c6568ebc70 160000 --- a/doxygen-awesome-css +++ b/doxygen-awesome-css @@ -1 +1 @@ -Subproject commit af1d9030b3ffa7b483fa9997a7272fb12af6af4c +Subproject commit c6568ebc70adf9fb0fb6c1745737ae6945576813 diff --git a/include/dpp/dns.h b/include/dpp/dns.h index 48cebcd563..b1a6814065 100644 --- a/include/dpp/dns.h +++ b/include/dpp/dns.h @@ -40,7 +40,7 @@ namespace dpp { * @brief Represents a cached DNS result. * Used by the ssl_client class to store cached copies of dns lookups. */ - struct dns_cache_entry { + struct DPP_EXPORT dns_cache_entry { /** * @brief Resolved address metadata */ @@ -93,5 +93,5 @@ namespace dpp { * @return dns_cache_entry* First IP address associated with the hostname DNS record * @throw dpp::connection_exception On failure to resolve hostname */ - const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port); -} + DPP_EXPORT const dns_cache_entry *resolve_hostname(const std::string &hostname, const std::string &port); + } diff --git a/include/dpp/socketengine.h b/include/dpp/socketengine.h index b793cc098e..c7040845c4 100644 --- a/include/dpp/socketengine.h +++ b/include/dpp/socketengine.h @@ -83,7 +83,7 @@ using socket_error_event = std::function create_socket_engine(class cluster* creator); +DPP_EXPORT std::unique_ptr create_socket_engine(class cluster *creator); #ifndef _WIN32 void set_signal_handler(int signal); diff --git a/include/dpp/sslclient.h b/include/dpp/sslclient.h index 5266e70ab0..0d1bb693e7 100644 --- a/include/dpp/sslclient.h +++ b/include/dpp/sslclient.h @@ -54,7 +54,7 @@ typedef std::function socket_notification_t; * @param sfd Socket to close * @return false on error, true on success */ -bool close_socket(dpp::socket sfd); +DPP_EXPORT bool close_socket(dpp::socket sfd); /** * @brief Set a socket to blocking or non-blocking IO @@ -63,7 +63,7 @@ bool close_socket(dpp::socket sfd); * @param non_blocking should socket be non-blocking? * @return false on error, true on success */ -bool set_nonblocking(dpp::socket sockfd, bool non_blocking); +DPP_EXPORT bool set_nonblocking(dpp::socket sockfd, bool non_blocking); /* You'd think that we would get better performance with a bigger buffer, but SSL frames are 16k each. * SSL_read in non-blocking mode will only read 16k at a time. There's no point in a bigger buffer as diff --git a/include/dpp/thread_pool.h b/include/dpp/thread_pool.h index 76398d1daf..ef15ab33ea 100644 --- a/include/dpp/thread_pool.h +++ b/include/dpp/thread_pool.h @@ -28,27 +28,29 @@ #include #include +namespace dpp { + using work_unit = std::function; /** * A task within a thread pool. A simple lambda that accepts no parameters and returns void. */ -struct thread_pool_task { +struct DPP_EXPORT thread_pool_task { int priority; work_unit function; }; -struct thread_pool_task_comparator { - bool operator()(const thread_pool_task &a, const thread_pool_task &b) { - return a.priority < b.priority; - }; +struct DPP_EXPORT thread_pool_task_comparator { + bool operator()(const thread_pool_task &a, const thread_pool_task &b) { + return a.priority < b.priority; + }; }; /** * @brief A thread pool contains 1 or more worker threads which accept thread_pool_task lambadas * into a queue, which is processed in-order by whichever thread is free. */ -struct thread_pool { +struct DPP_EXPORT thread_pool { std::vector threads; std::priority_queue, thread_pool_task_comparator> tasks; std::mutex queue_mutex; @@ -59,3 +61,5 @@ struct thread_pool { ~thread_pool(); void enqueue(thread_pool_task task); }; + +} \ No newline at end of file diff --git a/mlspp/include/namespace.h b/mlspp/include/namespace.h index d07ba5ee94..43a5121ccb 100755 --- a/mlspp/include/namespace.h +++ b/mlspp/include/namespace.h @@ -1,4 +1,4 @@ #pragma once // Configurable top-level MLS namespace -#define MLS_NAMESPACE ../include/dpp/mlspp/mls +#define MLS_NAMESPACE mls diff --git a/src/dpp/dns.cpp b/src/dpp/dns.cpp index 42aed3ce8f..a927d68747 100644 --- a/src/dpp/dns.cpp +++ b/src/dpp/dns.cpp @@ -54,8 +54,7 @@ socket dns_cache_entry::make_connecting_socket() const { return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); } -const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port) -{ +const dns_cache_entry *resolve_hostname(const std::string &hostname, const std::string &port) { addrinfo hints, *addrs; dns_cache_t::const_iterator iter; time_t now = time(nullptr); diff --git a/src/dpp/socketengine.cpp b/src/dpp/socketengine.cpp index 10301d4ec6..4cd7c7b8f6 100644 --- a/src/dpp/socketengine.cpp +++ b/src/dpp/socketengine.cpp @@ -31,7 +31,7 @@ namespace dpp { bool socket_engine_base::register_socket(const socket_events &e) { - if (e.fd > INVALID_SOCKET && fds.find(e.fd) == fds.end()) { + if (e.fd != INVALID_SOCKET && fds.find(e.fd) == fds.end()) { fds.emplace(e.fd, std::make_unique(e)); return true; } @@ -39,7 +39,7 @@ bool socket_engine_base::register_socket(const socket_events &e) { } bool socket_engine_base::update_socket(const socket_events &e) { - if (e.fd > INVALID_SOCKET && fds.find(e.fd) != fds.end()) { + if (e.fd != INVALID_SOCKET && fds.find(e.fd) != fds.end()) { auto iter = fds.find(e.fd); *(iter->second) = e; return true; @@ -48,7 +48,7 @@ bool socket_engine_base::update_socket(const socket_events &e) { } socket_engine_base::socket_engine_base(cluster* creator) : owner(creator) { -#ifndef WIN32 +#ifndef _WIN32 set_signal_handler(SIGALRM); set_signal_handler(SIGXFSZ); set_signal_handler(SIGCHLD); @@ -108,7 +108,7 @@ bool socket_engine_base::delete_socket(dpp::socket fd) { } bool socket_engine_base::remove_socket(dpp::socket fd) { - return false; + return true; } } diff --git a/src/dpp/socketengines/epoll.cpp b/src/dpp/socketengines/epoll.cpp index 77b2b34c71..4b7ee1a7c7 100644 --- a/src/dpp/socketengines/epoll.cpp +++ b/src/dpp/socketengines/epoll.cpp @@ -49,7 +49,7 @@ int modify_event(int epoll_handle, socket_events* eh, int new_events) { return new_events; } -struct socket_engine_epoll : public socket_engine_base { +struct DPP_EXPORT socket_engine_epoll : public socket_engine_base { int epoll_handle{INVALID_SOCKET}; static const int epoll_hint = 128; @@ -196,7 +196,7 @@ struct socket_engine_epoll : public socket_engine_base { } }; -std::unique_ptr create_socket_engine(cluster* creator) { +DPP_EXPORT std::unique_ptr create_socket_engine(cluster *creator) { return std::make_unique(creator); } diff --git a/src/dpp/socketengines/kqueue.cpp b/src/dpp/socketengines/kqueue.cpp index d15b2574f5..bc93272714 100644 --- a/src/dpp/socketengines/kqueue.cpp +++ b/src/dpp/socketengines/kqueue.cpp @@ -31,7 +31,7 @@ namespace dpp { -struct socket_engine_kqueue : public socket_engine_base { +struct DPP_EXPORT socket_engine_kqueue : public socket_engine_base { int kqueue_handle{INVALID_SOCKET}; std::array ke_list; @@ -147,7 +147,7 @@ struct socket_engine_kqueue : public socket_engine_base { } }; -std::unique_ptr create_socket_engine(cluster* creator) { +DPP_EXPORT std::unique_ptr create_socket_engine(cluster *creator) { return std::make_unique(creator); } diff --git a/src/dpp/socketengines/poll.cpp b/src/dpp/socketengines/poll.cpp index 148e01a716..d2e0da1a82 100644 --- a/src/dpp/socketengines/poll.cpp +++ b/src/dpp/socketengines/poll.cpp @@ -44,7 +44,7 @@ namespace dpp { -struct socket_engine_poll : public socket_engine_base { +struct DPP_EXPORT socket_engine_poll : public socket_engine_base { /* We store the pollfds as a vector. This means that insertion, deletion and updating * are comparatively slow O(n), but these operations don't happen too often. Obtaining the @@ -53,56 +53,68 @@ struct socket_engine_poll : public socket_engine_base { * anyway. */ std::vector poll_set; + pollfd out_set[FD_SETSIZE]{0}; void process_events() final { const int poll_delay = 1000; - int i = poll(poll_set.data(), static_cast(poll_set.size()), poll_delay); - int processed = 0; - for (size_t index = 0; index < poll_set.size() && processed < i; index++) { - const int fd = poll_set[index].fd; - const short revents = poll_set[index].revents; - - if (revents > 0) { - processed++; + if (poll_set.empty()) { + /* On many platforms, it is not possible to wait on an empty set */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } else { + if (poll_set.size() > FD_SETSIZE) { + throw dpp::connection_exception("poll() does not support more than FD_SETSIZE active sockets at once!"); } - auto iter = fds.find(fd); - if (iter == fds.end()) { - continue; - } - socket_events* eh = iter->second.get(); + std::copy(poll_set.begin(), poll_set.end(), out_set); - try { + int i = poll(out_set, static_cast(poll_set.size()), poll_delay); + int processed = 0; - if ((revents & POLLHUP) != 0) { - eh->on_error(fd, *eh, 0); - continue; + for (size_t index = 0; index < poll_set.size() && processed < i; index++) { + const int fd = out_set[index].fd; + const short revents = out_set[index].revents; + + if (revents > 0) { + processed++; } - if ((revents & POLLERR) != 0) { - socklen_t codesize = sizeof(int); - int errcode{}; - if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char*)&errcode, &codesize) < 0) { - errcode = errno; - } - eh->on_error(fd, *eh, errcode); + auto iter = fds.find(fd); + if (iter == fds.end()) { continue; } + socket_events *eh = iter->second.get(); - if ((revents & POLLIN) != 0) { - eh->on_read(fd, *eh); - } + try { - if ((revents & POLLOUT) != 0) { - int mask = eh->flags; - mask &= ~WANT_WRITE; - eh->flags = mask; - eh->on_write(fd, *eh); - } + if ((revents & POLLHUP) != 0) { + eh->on_error(fd, *eh, 0); + continue; + } - } catch (const std::exception& e) { - eh->on_error(fd, *eh, 0); + if ((revents & POLLERR) != 0) { + socklen_t codesize = sizeof(int); + int errcode{}; + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *) &errcode, &codesize) < 0) { + errcode = errno; + } + eh->on_error(fd, *eh, errcode); + continue; + } + + if ((revents & POLLIN) != 0) { + eh->on_read(fd, *eh); + } + + if ((revents & POLLOUT) != 0) { + eh->flags &= ~WANT_WRITE; + update_socket(*eh); + eh->on_write(fd, *eh); + } + + } catch (const std::exception &e) { + eh->on_error(fd, *eh, 0); + } } } prune(); @@ -126,9 +138,6 @@ struct socket_engine_poll : public socket_engine_base { if ((e.flags & WANT_WRITE) != 0) { fd_info.events |= POLLOUT; } - if ((e.flags & WANT_ERROR) != 0) { - fd_info.events |= POLLERR; - } poll_set.push_back(fd_info); } return r; @@ -149,9 +158,6 @@ struct socket_engine_poll : public socket_engine_base { if ((e.flags & WANT_WRITE) != 0) { fd_info.events |= POLLOUT; } - if ((e.flags & WANT_ERROR) != 0) { - fd_info.events |= POLLERR; - } break; } } @@ -176,7 +182,7 @@ struct socket_engine_poll : public socket_engine_base { } }; -std::unique_ptr create_socket_engine(cluster* creator) { +DPP_EXPORT std::unique_ptr create_socket_engine(cluster* creator) { return std::make_unique(creator); } diff --git a/src/dpp/thread_pool.cpp b/src/dpp/thread_pool.cpp index 7d32baf19d..2ee5624117 100644 --- a/src/dpp/thread_pool.cpp +++ b/src/dpp/thread_pool.cpp @@ -24,6 +24,8 @@ #include #include +namespace dpp { + thread_pool::thread_pool(size_t num_threads) { for (size_t i = 0; i < num_threads; ++i) { threads.emplace_back([this, i]() { @@ -51,24 +53,24 @@ thread_pool::thread_pool(size_t num_threads) { } } -thread_pool::~thread_pool() -{ +thread_pool::~thread_pool() { { std::unique_lock lock(queue_mutex); stop = true; } cv.notify_all(); - for (auto& thread : threads) { + for (auto &thread: threads) { thread.join(); } } -void thread_pool::enqueue(thread_pool_task task) -{ +void thread_pool::enqueue(thread_pool_task task) { { std::unique_lock lock(queue_mutex); tasks.emplace(std::move(task)); } cv.notify_one(); } + +} \ No newline at end of file diff --git a/src/dpp/voice/enabled/discover_ip.cpp b/src/dpp/voice/enabled/discover_ip.cpp index 7061c20f08..96ab091a13 100644 --- a/src/dpp/voice/enabled/discover_ip.cpp +++ b/src/dpp/voice/enabled/discover_ip.cpp @@ -148,7 +148,11 @@ std::string discord_voice_client::discover_ip() { return ""; } address_t bind_port(this->ip, this->port); +#ifndef _WIN32 if (::connect(socket.fd, bind_port.get_socket_address(), bind_port.size()) < 0) { +#else + if (WSAConnect(socket.fd, bind_port.get_socket_address(), bind_port.size(), nullptr, nullptr, nullptr, nullptr) < 0) { +#endif log(ll_warning, "Could not connect socket for IP discovery"); return ""; } diff --git a/src/sockettest/socket.cpp b/src/sockettest/socket.cpp index 08ce6b12da..057b7b112e 100644 --- a/src/sockettest/socket.cpp +++ b/src/sockettest/socket.cpp @@ -22,9 +22,18 @@ #include #include -#include #include #include +#ifndef _WIN32 + #include +#else + /* Windows-specific sockets includes */ + #include + #include + #include + /* Windows sockets library */ + #pragma comment(lib, "ws2_32") +#endif int main() { dpp::cluster cl("no-token"); @@ -37,7 +46,11 @@ int main() { if (sfd == INVALID_SOCKET) { std::cerr << "Couldn't create outbound socket on port 80\n"; exit(1); +#ifndef _WIN32 } else if (::connect(sfd, destination.get_socket_address(), destination.size()) != 0) { +#else + } else if (::WSAConnect(sfd, destination.get_socket_address(), destination.size(), nullptr, nullptr, nullptr, nullptr) != 0) { +#endif dpp::close_socket(sfd); std::cerr << "Couldn't connect outbound socket on port 80\n"; exit(1); @@ -50,7 +63,7 @@ int main() { int r = 0; do { char buf[128]{0}; - r = ::read(e.fd, buf, sizeof(buf)); + r = ::recv(e.fd, buf, sizeof(buf), 0); if (r > 0) { buf[127] = 0; std::cout << buf; @@ -65,7 +78,8 @@ int main() { [](dpp::socket fd, const struct dpp::socket_events& e) { std::cout << "WANT_WRITE event on socket " << fd << "\n"; constexpr std::string_view request{"GET / HTTP/1.0\r\nConnection: close\r\n\r\n"}; - auto written = ::write(e.fd, request.data(), request.length()); + std::cout << "Writing: " << request.data() << "\n"; + auto written = ::send(e.fd, request.data(), request.length(), 0); std::cout << "Written: " << written << "\n"; }, [](dpp::socket fd, const struct dpp::socket_events&, int error_code) {