diff --git a/include/cppcoro/detail/linux_async_operation.hpp b/include/cppcoro/detail/linux_async_operation.hpp index ec143245..a0ca26f1 100644 --- a/include/cppcoro/detail/linux_async_operation.hpp +++ b/include/cppcoro/detail/linux_async_operation.hpp @@ -86,6 +86,9 @@ namespace cppcoro { auto* operation = static_cast(ioState); operation->m_res = operation->m_completeFunc(); + if (operation->m_res < 0) { + operation->m_res = -errno; + } operation->m_awaitingCoroutine.resume(); } @@ -93,11 +96,11 @@ namespace cppcoro }; + static constexpr int error_operation_cancelled = ECANCELED; template class linux_async_operation_cancellable : protected linux_async_operation_base { - static constexpr int error_operation_cancelled = ECANCELED; protected: @@ -278,6 +281,9 @@ namespace cppcoro auto* operation = static_cast(ioState); operation->m_res = operation->m_completeFunc(); + if (operation->m_res < 0) { + operation->m_res = -errno; + } auto state = operation->m_state.load(std::memory_order_acquire); if (state == state::started) diff --git a/include/cppcoro/net/socket.hpp b/include/cppcoro/net/socket.hpp index a61eac61..efc0de3a 100644 --- a/include/cppcoro/net/socket.hpp +++ b/include/cppcoro/net/socket.hpp @@ -20,6 +20,8 @@ #if CPPCORO_OS_WINNT # include +#elif CPPCORO_OS_LINUX +# include #endif namespace cppcoro @@ -81,6 +83,9 @@ namespace cppcoro static socket create_udpv6(io_service& ioSvc); socket(socket&& other) noexcept; + socket& operator=(socket&& other) noexcept; + socket(const socket& other) noexcept; + socket& operator=(const socket& other) noexcept; /// Closes the socket, releasing any associated resources. /// @@ -90,7 +95,8 @@ namespace cppcoro /// disconnect() and wait until the disconnect operation completes. ~socket(); - socket& operator=(socket&& other) noexcept; + int close(); + #if CPPCORO_OS_WINNT /// Get the Win32 socket handle assocaited with this socket. @@ -104,6 +110,9 @@ namespace cppcoro /// operation completing synchronously or whether it should suspend the coroutine /// and wait until the I/O completion event is dispatched to an I/O thread. bool skip_completion_on_success() noexcept { return m_skipCompletionOnSuccess; } +#elif CPPCORO_OS_LINUX + /// Get the linux fd assocaited with this socket. + cppcoro::detail::linux::fd_t native_handle() noexcept { return m_handle; } #endif /// Get the address and port of the local end-point. @@ -242,20 +251,27 @@ namespace cppcoro void close_send(); void close_recv(); - private: - - friend class socket_accept_operation_impl; - friend class socket_connect_operation_impl; #if CPPCORO_OS_WINNT explicit socket( cppcoro::detail::win32::socket_t handle, bool skipCompletionOnSuccess) noexcept; +#elif CPPCORO_OS_LINUX + explicit socket( + cppcoro::detail::linux::fd_t handle, + cppcoro::detail::linux::message_queue* mq) noexcept; #endif + private: + + friend class socket_accept_operation_impl; + friend class socket_connect_operation_impl; #if CPPCORO_OS_WINNT cppcoro::detail::win32::socket_t m_handle; bool m_skipCompletionOnSuccess; +#elif CPPCORO_OS_LINUX + cppcoro::detail::linux::fd_t m_handle; + cppcoro::detail::linux::message_queue* m_mq; #endif ip_endpoint m_localEndPoint; diff --git a/include/cppcoro/net/socket_accept_operation.hpp b/include/cppcoro/net/socket_accept_operation.hpp index ae966f01..e477fbbc 100644 --- a/include/cppcoro/net/socket_accept_operation.hpp +++ b/include/cppcoro/net/socket_accept_operation.hpp @@ -9,17 +9,23 @@ #include #include +#include +#include + #if CPPCORO_OS_WINNT # include # include +#elif CPPCORO_OS_LINUX +# include +# include +#endif -# include -# include namespace cppcoro { namespace net { +#if CPPCORO_OS_WINNT class socket; class socket_accept_operation_impl @@ -100,9 +106,82 @@ namespace cppcoro socket_accept_operation_impl m_impl; }; +#elif CPPCORO_OS_LINUX + class socket; + + class socket_accept_operation_impl + { + public: + + socket_accept_operation_impl( + socket& listeningSocket, + socket& acceptingSocket) noexcept + : m_listeningSocket(listeningSocket) + , m_acceptingSocket(acceptingSocket) + {} + + bool try_start(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void cancel(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void get_result(cppcoro::detail::linux_async_operation_base& operation); + + private: + socket& m_listeningSocket; + socket& m_acceptingSocket; + alignas(8) std::uint8_t m_addressBuffer[88]; + }; + + class socket_accept_operation + : public cppcoro::detail::linux_async_operation + { + public: + + socket_accept_operation( + socket& listeningSocket, + socket& acceptingSocket, + cppcoro::detail::linux::message_queue* mq) noexcept + : cppcoro::detail::linux_async_operation(mq) + , m_impl(listeningSocket, acceptingSocket) + {} + + private: + + friend class cppcoro::detail::linux_async_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_accept_operation_impl m_impl; + + }; + + class socket_accept_operation_cancellable + : public cppcoro::detail::linux_async_operation_cancellable + { + public: + + socket_accept_operation_cancellable( + socket& listeningSocket, + socket& acceptingSocket, + cppcoro::detail::linux::message_queue* mq, + cancellation_token&& ct) noexcept + : cppcoro::detail::linux_async_operation_cancellable(mq, std::move(ct)) + , m_impl(listeningSocket, acceptingSocket) + {} + + private: + + friend class cppcoro::detail::linux_async_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_accept_operation_impl m_impl; + + }; +#endif } } -#endif // CPPCORO_OS_WINNT #endif diff --git a/include/cppcoro/net/socket_connect_operation.hpp b/include/cppcoro/net/socket_connect_operation.hpp index b7eedd3c..16d98bdb 100644 --- a/include/cppcoro/net/socket_connect_operation.hpp +++ b/include/cppcoro/net/socket_connect_operation.hpp @@ -12,11 +12,16 @@ #if CPPCORO_OS_WINNT # include # include +#elif CPPCORO_OS_LINUX +# include +# include +#endif namespace cppcoro { namespace net { +#if CPPCORO_OS_WINNT class socket; class socket_connect_operation_impl @@ -87,9 +92,83 @@ namespace cppcoro socket_connect_operation_impl m_impl; }; +#elif CPPCORO_OS_LINUX + class socket; + + class socket_connect_operation_impl + { + public: + + socket_connect_operation_impl( + socket& socket, + const ip_endpoint& remoteEndPoint) noexcept + : m_socket(socket) + , m_remoteEndPoint(remoteEndPoint) + {} + + bool try_start(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void cancel(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void get_result(cppcoro::detail::linux_async_operation_base& operation); + + private: + + socket& m_socket; + ip_endpoint m_remoteEndPoint; + + }; + + class socket_connect_operation + : public cppcoro::detail::linux_async_operation + { + public: + + socket_connect_operation( + socket& socket, + const ip_endpoint& remoteEndPoint, + cppcoro::detail::linux::message_queue* mq) noexcept + : cppcoro::detail::linux_async_operation(mq) + , m_impl(socket, remoteEndPoint) + {} + + private: + + friend class cppcoro::detail::linux_async_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + decltype(auto) get_result() { return m_impl.get_result(*this); } + + socket_connect_operation_impl m_impl; + + }; + + class socket_connect_operation_cancellable + : public cppcoro::detail::linux_async_operation_cancellable + { + public: + + socket_connect_operation_cancellable( + socket& socket, + const ip_endpoint& remoteEndPoint, + cppcoro::detail::linux::message_queue* mq, + cancellation_token&& ct) noexcept + : cppcoro::detail::linux_async_operation_cancellable(mq, std::move(ct)) + , m_impl(socket, remoteEndPoint) + {} + + private: + + friend class cppcoro::detail::linux_async_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_connect_operation_impl m_impl; + + }; +#endif } } -#endif // CPPCORO_OS_WINNT #endif diff --git a/include/cppcoro/net/socket_disconnect_operation.hpp b/include/cppcoro/net/socket_disconnect_operation.hpp index 7bdcc03f..e1b0fbaa 100644 --- a/include/cppcoro/net/socket_disconnect_operation.hpp +++ b/include/cppcoro/net/socket_disconnect_operation.hpp @@ -11,11 +11,16 @@ #if CPPCORO_OS_WINNT # include # include +#elif CPPCORO_OS_LINUX +# include +# include +#endif namespace cppcoro { namespace net { +#if CPPCORO_OS_WINNT class socket; class socket_disconnect_operation_impl @@ -77,9 +82,77 @@ namespace cppcoro socket_disconnect_operation_impl m_impl; }; +#elif CPPCORO_OS_LINUX + class socket; + + class socket_disconnect_operation_impl + { + public: + + socket_disconnect_operation_impl(socket& socket) noexcept + : m_socket(socket) + {} + + bool try_start(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void cancel(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void get_result(cppcoro::detail::linux_async_operation_base& operation); + + private: + + socket& m_socket; + + }; + + class socket_disconnect_operation + : public cppcoro::detail::linux_async_operation + { + public: + + socket_disconnect_operation( + socket& socket, + cppcoro::detail::linux::message_queue* mq) noexcept + : cppcoro::detail::linux_async_operation(mq) + , m_impl(socket) + {} + + private: + + friend class cppcoro::detail::linux_async_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_disconnect_operation_impl m_impl; + + }; + + class socket_disconnect_operation_cancellable + : public cppcoro::detail::linux_async_operation_cancellable + { + public: + + socket_disconnect_operation_cancellable( + socket& socket, + cppcoro::detail::linux::message_queue* mq, + cancellation_token&& ct) noexcept + : cppcoro::detail::linux_async_operation_cancellable(mq, std::move(ct)) + , m_impl(socket) + {} + + private: + + friend class cppcoro::detail::linux_async_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + void get_result() { m_impl.get_result(*this); } + + socket_disconnect_operation_impl m_impl; + + }; +#endif } } -#endif // CPPCORO_OS_WINNT #endif diff --git a/include/cppcoro/net/socket_recv_from_operation.hpp b/include/cppcoro/net/socket_recv_from_operation.hpp index 37f2d01a..95d9bb3c 100644 --- a/include/cppcoro/net/socket_recv_from_operation.hpp +++ b/include/cppcoro/net/socket_recv_from_operation.hpp @@ -15,9 +15,14 @@ #if CPPCORO_OS_WINNT # include # include +#elif CPPCORO_OS_LINUX +# include +# include +#endif namespace cppcoro::net { +#if CPPCORO_OS_WINNT class socket; class socket_recv_from_operation_impl @@ -98,9 +103,94 @@ namespace cppcoro::net socket_recv_from_operation_impl m_impl; }; +#elif CPPCORO_OS_LINUX + class socket; + + class socket_recv_from_operation_impl + { + public: + + socket_recv_from_operation_impl( + socket& socket, + void* buffer, + std::size_t byteCount) noexcept + : m_socket(socket) + , m_buffer(buffer) + , m_byteCount(byteCount) + {} + + bool try_start(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void cancel(cppcoro::detail::linux_async_operation_base& operation) noexcept; + std::tuple get_result( + cppcoro::detail::linux_async_operation_base& operation); + + private: + + socket& m_socket; + void* m_buffer; + std::size_t m_byteCount; + + static constexpr std::size_t sockaddrStorageAlignment = 4; + + // Storage suitable for either SOCKADDR_IN or SOCKADDR_IN6 + alignas(sockaddrStorageAlignment) std::uint8_t m_sourceSockaddrStorage[28]; + int m_sourceSockaddrLength; + + }; + + class socket_recv_from_operation + : public cppcoro::detail::linux_async_operation + { + public: + + socket_recv_from_operation( + socket& socket, + void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq) noexcept + : cppcoro::detail::linux_async_operation(mq) + , m_impl(socket, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::linux_async_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + decltype(auto) get_result() { return m_impl.get_result(*this); } + + socket_recv_from_operation_impl m_impl; + + }; + + class socket_recv_from_operation_cancellable + : public cppcoro::detail::linux_async_operation_cancellable + { + public: + + socket_recv_from_operation_cancellable( + socket& socket, + void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq, + cancellation_token&& ct) noexcept + : cppcoro::detail::linux_async_operation_cancellable(mq, std::move(ct)) + , m_impl(socket, buffer, byteCount) + {} + + private: + friend class cppcoro::detail::linux_async_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + decltype(auto) get_result() { return m_impl.get_result(*this); } + + socket_recv_from_operation_impl m_impl; + + }; +#endif } -#endif // CPPCORO_OS_WINNT #endif diff --git a/include/cppcoro/net/socket_recv_operation.hpp b/include/cppcoro/net/socket_recv_operation.hpp index c9dca8b2..2ef91bcb 100644 --- a/include/cppcoro/net/socket_recv_operation.hpp +++ b/include/cppcoro/net/socket_recv_operation.hpp @@ -13,9 +13,14 @@ #if CPPCORO_OS_WINNT # include # include +#elif CPPCORO_OS_LINUX +# include +# include +#endif namespace cppcoro::net { +#if CPPCORO_OS_WINNT class socket; class socket_recv_operation_impl @@ -86,9 +91,84 @@ namespace cppcoro::net socket_recv_operation_impl m_impl; }; +#elif CPPCORO_OS_LINUX + class socket; + + class socket_recv_operation_impl + { + public: + + socket_recv_operation_impl( + socket& s, + void* buffer, + std::size_t byteCount) noexcept + : m_socket(s) + , m_buffer(buffer) + , m_byteCount(byteCount) + {} + + bool try_start(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void cancel(cppcoro::detail::linux_async_operation_base& operation) noexcept; + + private: + + socket& m_socket; + void* m_buffer; + std::size_t m_byteCount; + + }; + + class socket_recv_operation + : public cppcoro::detail::linux_async_operation + { + public: + + socket_recv_operation( + socket& s, + void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq) noexcept + : cppcoro::detail::linux_async_operation(mq) + , m_impl(s, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::linux_async_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + socket_recv_operation_impl m_impl; + + }; + + class socket_recv_operation_cancellable + : public cppcoro::detail::linux_async_operation_cancellable + { + public: + socket_recv_operation_cancellable( + socket& s, + void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq, + cancellation_token&& ct) noexcept + : cppcoro::detail::linux_async_operation_cancellable(mq, std::move(ct)) + , m_impl(s, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::linux_async_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { m_impl.cancel(*this); } + + socket_recv_operation_impl m_impl; + + }; +#endif } -#endif // CPPCORO_OS_WINNT #endif diff --git a/include/cppcoro/net/socket_send_operation.hpp b/include/cppcoro/net/socket_send_operation.hpp index 702d2abd..1ba8e287 100644 --- a/include/cppcoro/net/socket_send_operation.hpp +++ b/include/cppcoro/net/socket_send_operation.hpp @@ -13,9 +13,14 @@ #if CPPCORO_OS_WINNT # include # include +#elif CPPCORO_OS_LINUX +# include +# include +#endif namespace cppcoro::net { +#if CPPCORO_OS_WINNT class socket; class socket_send_operation_impl @@ -86,9 +91,84 @@ namespace cppcoro::net socket_send_operation_impl m_impl; }; +#elif CPPCORO_OS_LINUX + class socket; + + class socket_send_operation_impl + { + public: + + socket_send_operation_impl( + socket& s, + const void* buffer, + std::size_t byteCount) noexcept + : m_socket(s) + , m_buffer(buffer) + , m_byteCount(byteCount) + {} + + bool try_start(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void cancel(cppcoro::detail::linux_async_operation_base& operation) noexcept; + + private: + + socket& m_socket; + const void* m_buffer; + std::size_t m_byteCount; + + }; + + class socket_send_operation + : public cppcoro::detail::linux_async_operation + { + public: + + socket_send_operation( + socket& s, + const void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq) noexcept + : cppcoro::detail::linux_async_operation(mq) + , m_impl(s, buffer, byteCount) + {} + private: + + friend class cppcoro::detail::linux_async_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + socket_send_operation_impl m_impl; + + }; + + class socket_send_operation_cancellable + : public cppcoro::detail::linux_async_operation_cancellable + { + public: + + socket_send_operation_cancellable( + socket& s, + const void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq, + cancellation_token&& ct) noexcept + : cppcoro::detail::linux_async_operation_cancellable(mq, std::move(ct)) + , m_impl(s, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::linux_async_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { return m_impl.cancel(*this); } + + socket_send_operation_impl m_impl; + + }; +#endif } -#endif // CPPCORO_OS_WINNT #endif diff --git a/include/cppcoro/net/socket_send_to_operation.hpp b/include/cppcoro/net/socket_send_to_operation.hpp index 60d51b24..3e2bb212 100644 --- a/include/cppcoro/net/socket_send_to_operation.hpp +++ b/include/cppcoro/net/socket_send_to_operation.hpp @@ -14,9 +14,14 @@ #if CPPCORO_OS_WINNT # include # include +#elif CPPCORO_OS_LINUX +# include +# include +#endif namespace cppcoro::net { +#if CPPCORO_OS_WINNT class socket; class socket_send_to_operation_impl @@ -92,9 +97,89 @@ namespace cppcoro::net socket_send_to_operation_impl m_impl; }; +#elif CPPCORO_OS_LINUX + class socket; + + class socket_send_to_operation_impl + { + public: + + socket_send_to_operation_impl( + socket& s, + const ip_endpoint& destination, + const void* buffer, + std::size_t byteCount) noexcept + : m_socket(s) + , m_destination(destination) + , m_buffer(buffer) + , m_byteCount(byteCount) + {} + + bool try_start(cppcoro::detail::linux_async_operation_base& operation) noexcept; + void cancel(cppcoro::detail::linux_async_operation_base& operation) noexcept; + + private: + + socket& m_socket; + ip_endpoint m_destination; + const void* m_buffer; + std::size_t m_byteCount; + + }; + + class socket_send_to_operation + : public cppcoro::detail::linux_async_operation + { + public: + + socket_send_to_operation( + socket& s, + const ip_endpoint& destination, + const void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq) noexcept + : cppcoro::detail::linux_async_operation(mq) + , m_impl(s, destination, buffer, byteCount) + {} + + private: + + friend class cppcoro::detail::linux_async_operation; + + bool try_start() noexcept { return m_impl.try_start(*this); } + + socket_send_to_operation_impl m_impl; + + }; + + class socket_send_to_operation_cancellable + : public cppcoro::detail::linux_async_operation_cancellable + { + public: + + socket_send_to_operation_cancellable( + socket& s, + const ip_endpoint& destination, + const void* buffer, + std::size_t byteCount, + cppcoro::detail::linux::message_queue* mq, + cancellation_token&& ct) noexcept + : cppcoro::detail::linux_async_operation_cancellable(mq, std::move(ct)) + , m_impl(s, destination, buffer, byteCount) + {} + + private: + friend class cppcoro::detail::linux_async_operation_cancellable; + + bool try_start() noexcept { return m_impl.try_start(*this); } + void cancel() noexcept { return m_impl.cancel(*this); } + + socket_send_to_operation_impl m_impl; + + }; +#endif } -#endif // CPPCORO_OS_WINNT #endif diff --git a/include/cppcoro/resume_on.hpp b/include/cppcoro/resume_on.hpp index 90710d97..47bc32ae 100644 --- a/include/cppcoro/resume_on.hpp +++ b/include/cppcoro/resume_on.hpp @@ -117,15 +117,12 @@ namespace cppcoro template async_generator resume_on(SCHEDULER& scheduler, async_generator source) { - auto iter = co_await source.begin(); - auto endIter = source.end(); - - while (iter != endIter) + for (detail::async_generator_iterator iter = co_await source.begin(); iter != source.end();) { auto& value = *iter; co_await scheduler.schedule(); co_yield value; - co_await ++iter; + co_await ++iter; // moved due to error: insufficient contextual information to determine type on old compilers } } diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e20f51ab..0f73a3f3 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -164,6 +164,19 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Linux") list(TRANSFORM linuxDetailIncludes PREPEND "${PROJECT_SOURCE_DIR}/include/cppcoro/detail/") list(APPEND detailIncludes ${linuxDetailIncludes}) + set(linuxNetIncludes + socket.hpp + socket_accept_operation.hpp + socket_connect_operation.hpp + socket_disconnect_operation.hpp + socket_recv_operation.hpp + socket_recv_from_operation.hpp + socket_send_operation.hpp + socket_send_to_operation.hpp + ) + list(TRANSFORM linuxNetIncludes PREPEND "${PROJECT_SOURCE_DIR}/include/cppcoro/net/") + list(APPEND netIncludes ${linuxNetIncludes}) + set(linuxSources linux.cpp io_service.cpp @@ -175,15 +188,15 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Linux") read_write_file.cpp file_read_operation.cpp file_write_operation.cpp - # socket_helpers.cpp - # socket.cpp - # socket_accept_operation.cpp - # socket_connect_operation.cpp - # socket_disconnect_operation.cpp - # socket_send_operation.cpp - # socket_send_to_operation.cpp - # socket_recv_operation.cpp - # socket_recv_from_operation.cpp + socket_helpers.cpp + socket.cpp + socket_accept_operation.cpp + socket_connect_operation.cpp + socket_disconnect_operation.cpp + socket_send_operation.cpp + socket_send_to_operation.cpp + socket_recv_operation.cpp + socket_recv_from_operation.cpp ) list(APPEND sources ${linuxSources}) endif() diff --git a/lib/socket.cpp b/lib/socket.cpp index e686f76f..d5844393 100644 --- a/lib/socket.cpp +++ b/lib/socket.cpp @@ -23,16 +23,26 @@ # include # include # include - +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include +# include +#define closesocket close +#define INVALID_SOCKET (-1) +#define SOCKET_ERROR (-1) +#endif namespace { namespace local { - std::tuple create_socket( +#if CPPCORO_OS_WINNT + cppcoro::net::socket create_socket( int addressFamily, int socketType, int protocol, - HANDLE ioCompletionPort) + cppcoro::io_service& ioSvc) { // Enumerate available protocol providers for the specified socket type. @@ -132,7 +142,7 @@ namespace { const HANDLE result = ::CreateIoCompletionPort( (HANDLE)socketHandle, - ioCompletionPort, + ioSvc.native_iocp_handle(), ULONG_PTR(0), DWORD(0)); if (result == nullptr) @@ -192,19 +202,70 @@ namespace } } - return std::make_tuple(socketHandle, skipCompletionPortOnSuccess); + return cppcoro::net::socket(socketHandle, skipCompletionPortOnSuccess); + } +#elif CPPCORO_OS_LINUX + cppcoro::net::socket create_socket( + int addressFamily, + int socketType, + int protocol, + cppcoro::io_service& ioSvc) + { + + const int socketHandle = ::socket(addressFamily, socketType | SOCK_NONBLOCK, protocol); + if (socketHandle == INVALID_SOCKET) + { + const int errorCode = errno; + throw std::system_error( + errorCode, + std::system_category(), + "Error creating socket"); + } + + auto closeSocketOnFailure = cppcoro::on_scope_failure([&] + { + ::closesocket(socketHandle); + }); + + if (socketType == SOCK_STREAM) + { + // Turn off linger so that the destructor doesn't block while closing + // the socket or silently continue to flush remaining data in the + // background after ::closesocket() is called, which could fail and + // we'd never know about it. + // We expect clients to call Disconnect() or use CloseSend() to cleanly + // shut-down connections instead. + struct linger value; + value.l_onoff = 0; + const int result = ::setsockopt(socketHandle, + SOL_SOCKET, + SO_LINGER, + reinterpret_cast(&value), + sizeof(value)); + if (result == SOCKET_ERROR) + { + const int errorCode = errno; + throw std::system_error( + errorCode, + std::system_category(), + "Error creating socket: setsockopt(SO_LINGER)"); + } + } + + return cppcoro::net::socket(socketHandle, ioSvc.get_mq()); } +#endif } } cppcoro::net::socket cppcoro::net::socket::create_tcpv4(io_service& ioSvc) { +#if CPPCORO_OS_WINNT ioSvc.ensure_winsock_initialised(); +#endif + auto result = local::create_socket( + AF_INET, SOCK_STREAM, IPPROTO_TCP, ioSvc); - auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( - AF_INET, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle()); - - socket result(socketHandle, skipCompletionPortOnSuccess); result.m_localEndPoint = ipv4_endpoint(); result.m_remoteEndPoint = ipv4_endpoint(); return result; @@ -212,12 +273,12 @@ cppcoro::net::socket cppcoro::net::socket::create_tcpv4(io_service& ioSvc) cppcoro::net::socket cppcoro::net::socket::create_tcpv6(io_service& ioSvc) { +#if CPPCORO_OS_WINNT ioSvc.ensure_winsock_initialised(); +#endif + auto result = local::create_socket( + AF_INET6, SOCK_STREAM, IPPROTO_TCP, ioSvc); - auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( - AF_INET6, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle()); - - socket result(socketHandle, skipCompletionPortOnSuccess); result.m_localEndPoint = ipv6_endpoint(); result.m_remoteEndPoint = ipv6_endpoint(); return result; @@ -225,12 +286,12 @@ cppcoro::net::socket cppcoro::net::socket::create_tcpv6(io_service& ioSvc) cppcoro::net::socket cppcoro::net::socket::create_udpv4(io_service& ioSvc) { +#if CPPCORO_OS_WINNT ioSvc.ensure_winsock_initialised(); +#endif + auto result = local::create_socket( + AF_INET, SOCK_DGRAM, IPPROTO_UDP, ioSvc); - auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( - AF_INET, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle()); - - socket result(socketHandle, skipCompletionPortOnSuccess); result.m_localEndPoint = ipv4_endpoint(); result.m_remoteEndPoint = ipv4_endpoint(); return result; @@ -238,12 +299,12 @@ cppcoro::net::socket cppcoro::net::socket::create_udpv4(io_service& ioSvc) cppcoro::net::socket cppcoro::net::socket::create_udpv6(io_service& ioSvc) { +#if CPPCORO_OS_WINNT ioSvc.ensure_winsock_initialised(); +#endif + auto result = local::create_socket( + AF_INET6, SOCK_DGRAM, IPPROTO_UDP, ioSvc); - auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket( - AF_INET6, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle()); - - socket result(socketHandle, skipCompletionPortOnSuccess); result.m_localEndPoint = ipv6_endpoint(); result.m_remoteEndPoint = ipv6_endpoint(); return result; @@ -251,19 +312,15 @@ cppcoro::net::socket cppcoro::net::socket::create_udpv6(io_service& ioSvc) cppcoro::net::socket::socket(socket&& other) noexcept : m_handle(std::exchange(other.m_handle, INVALID_SOCKET)) +#if CPPCORO_OS_WINNT , m_skipCompletionOnSuccess(other.m_skipCompletionOnSuccess) +#elif CPPCORO_OS_LINUX + , m_mq(other.m_mq) +#endif , m_localEndPoint(std::move(other.m_localEndPoint)) , m_remoteEndPoint(std::move(other.m_remoteEndPoint)) {} -cppcoro::net::socket::~socket() -{ - if (m_handle != INVALID_SOCKET) - { - ::closesocket(m_handle); - } -} - cppcoro::net::socket& cppcoro::net::socket::operator=(socket&& other) noexcept { @@ -274,33 +331,80 @@ cppcoro::net::socket::operator=(socket&& other) noexcept } m_handle = handle; +#if CPPCORO_OS_WINNT + m_skipCompletionOnSuccess = other.m_skipCompletionOnSuccess; +#elif CPPCORO_OS_LINUX + m_mq = other.m_mq; +#endif + m_localEndPoint = other.m_localEndPoint; + m_remoteEndPoint = other.m_remoteEndPoint; + + return *this; +} + +#if CPPCORO_OS_WINNT +cppcoro::detail::win32::socket_t duplicate_socket(const cppcoro::detail::win32::socket_t& handle) { + WSAPROTOCOL_INFO wsa_pi; + WSADuplicateSocket(handle, GetCurrentProcessId(), &wsa_pi); + return WSASocket(wsa_pi.iAddressFamily, wsa_pi.iSocketType, wsa_pi.iProtocol, &wsa_pi, 0, 0); +} +#elif CPPCORO_OS_LINUX +cppcoro::detail::linux::fd_t duplicate_socket(const cppcoro::detail::linux::fd_t& handle) { + return dup(handle); +} +#endif + +cppcoro::net::socket::socket(const socket& other) noexcept + : m_handle(duplicate_socket(other.m_handle)) +#if CPPCORO_OS_WINNT + , m_skipCompletionOnSuccess(other.m_skipCompletionOnSuccess) +#elif CPPCORO_OS_LINUX + , m_mq(other.m_mq) +#endif + , m_localEndPoint(other.m_localEndPoint) + , m_remoteEndPoint(other.m_remoteEndPoint) +{} + +cppcoro::net::socket& +cppcoro::net::socket::operator=(const socket& other) noexcept +{ + m_handle = duplicate_socket(other.m_handle); +#if CPPCORO_OS_WINNT m_skipCompletionOnSuccess = other.m_skipCompletionOnSuccess; +#elif CPPCORO_OS_LINUX + m_mq = other.m_mq; +#endif m_localEndPoint = other.m_localEndPoint; m_remoteEndPoint = other.m_remoteEndPoint; return *this; } +cppcoro::net::socket::~socket() +{ + close(); +} + +int cppcoro::net::socket::close() +{ + if (m_handle != INVALID_SOCKET) + { + int res = ::closesocket(m_handle); + m_handle = INVALID_SOCKET; + return res; + } + return 0; +} + +#if CPPCORO_OS_WINNT void cppcoro::net::socket::bind(const ip_endpoint& localEndPoint) { SOCKADDR_STORAGE sockaddrStorage = { 0 }; SOCKADDR* sockaddr = reinterpret_cast(&sockaddrStorage); - if (localEndPoint.is_ipv4()) - { - SOCKADDR_IN& ipv4Sockaddr = *reinterpret_cast(sockaddr); - ipv4Sockaddr.sin_family = AF_INET; - std::memcpy(&ipv4Sockaddr.sin_addr, localEndPoint.to_ipv4().address().bytes(), 4); - ipv4Sockaddr.sin_port = localEndPoint.to_ipv4().port(); - } - else - { - SOCKADDR_IN6& ipv6Sockaddr = *reinterpret_cast(sockaddr); - ipv6Sockaddr.sin6_family = AF_INET6; - std::memcpy(&ipv6Sockaddr.sin6_addr, localEndPoint.to_ipv6().address().bytes(), 16); - ipv6Sockaddr.sin6_port = localEndPoint.to_ipv6().port(); - } + const int addrLength = + detail::ip_endpoint_to_sockaddr(localEndPoint, std::ref(sockaddrStorage)); - int result = ::bind(m_handle, sockaddr, sizeof(sockaddrStorage)); + int result = ::bind(m_handle, sockaddr, addrLength); if (result != 0) { // WSANOTINITIALISED: WSAStartup not called @@ -373,91 +477,210 @@ void cppcoro::net::socket::listen(std::uint32_t backlog) "Failed to start listening on bound endpoint: listen"); } } +#elif CPPCORO_OS_LINUX + +void cppcoro::net::socket::bind(const ip_endpoint& localEndPoint) +{ + sockaddr_storage sockaddrStorage = { 0 }; + sockaddr* sa = reinterpret_cast(&sockaddrStorage); + const int addrLength = + detail::ip_endpoint_to_sockaddr(localEndPoint, std::ref(sockaddrStorage)); + + int result = ::bind(m_handle, sa, addrLength); + if (result != 0) + { + int errorCode = errno; + throw std::system_error( + errorCode, + std::system_category(), + "Error binding to endpoint: bind()"); + } + + socklen_t sockaddrLen = sizeof(sockaddrStorage); + result = ::getsockname(m_handle, sa, &sockaddrLen); + if (result == 0) + { + m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(*sa); + } + else + { + m_localEndPoint = localEndPoint; + } +} + +void cppcoro::net::socket::listen() +{ + int result = ::listen(m_handle, SOMAXCONN); + if (result != 0) + { + int errorCode = errno; + throw std::system_error( + errorCode, + std::system_category(), + "Failed to start listening on bound endpoint: listen"); + } +} + +void cppcoro::net::socket::listen(std::uint32_t backlog) +{ + if (backlog > 0x7FFFFFFF) + { + backlog = 0x7FFFFFFF; + } + + int result = ::listen(m_handle, (int)backlog); + if (result != 0) + { + int errorCode = errno; + throw std::system_error( + errorCode, + std::system_category(), + "Failed to start listening on bound endpoint: listen"); + } +} +#endif cppcoro::net::socket_accept_operation cppcoro::net::socket::accept(socket& acceptingSocket) noexcept { - return socket_accept_operation{ *this, acceptingSocket }; +#if CPPCORO_OS_WINNT + return socket_accept_operation{ *this, acceptingSocket}; +#elif CPPCORO_OS_LINUX + return socket_accept_operation{ *this, acceptingSocket, m_mq}; +#endif } cppcoro::net::socket_accept_operation_cancellable cppcoro::net::socket::accept(socket& acceptingSocket, cancellation_token ct) noexcept { +#if CPPCORO_OS_WINNT return socket_accept_operation_cancellable{ *this, acceptingSocket, std::move(ct) }; +#elif CPPCORO_OS_LINUX + return socket_accept_operation_cancellable{ *this, acceptingSocket, m_mq, std::move(ct) }; +#endif } cppcoro::net::socket_connect_operation cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint) noexcept { +#if CPPCORO_OS_WINNT return socket_connect_operation{ *this, remoteEndPoint }; +#elif CPPCORO_OS_LINUX + return socket_connect_operation{ *this, remoteEndPoint, m_mq }; +#endif } cppcoro::net::socket_connect_operation_cancellable cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint, cancellation_token ct) noexcept { +#if CPPCORO_OS_WINNT return socket_connect_operation_cancellable{ *this, remoteEndPoint, std::move(ct) }; +#elif CPPCORO_OS_LINUX + return socket_connect_operation_cancellable{ *this, remoteEndPoint, m_mq, std::move(ct) }; +#endif } cppcoro::net::socket_disconnect_operation cppcoro::net::socket::disconnect() noexcept { +#if CPPCORO_OS_WINNT return socket_disconnect_operation(*this); +#elif CPPCORO_OS_LINUX + return socket_disconnect_operation(*this, m_mq); +#endif } cppcoro::net::socket_disconnect_operation_cancellable cppcoro::net::socket::disconnect(cancellation_token ct) noexcept { +#if CPPCORO_OS_WINNT return socket_disconnect_operation_cancellable{ *this, std::move(ct) }; +#elif CPPCORO_OS_LINUX + return socket_disconnect_operation_cancellable{ *this, m_mq, std::move(ct) }; +#endif } cppcoro::net::socket_send_operation cppcoro::net::socket::send(const void* buffer, std::size_t byteCount) noexcept { +#if CPPCORO_OS_WINNT return socket_send_operation{ *this, buffer, byteCount }; +#elif CPPCORO_OS_LINUX + return socket_send_operation{ *this, buffer, byteCount, m_mq }; +#endif } cppcoro::net::socket_send_operation_cancellable cppcoro::net::socket::send(const void* buffer, std::size_t byteCount, cancellation_token ct) noexcept { +#if CPPCORO_OS_WINNT return socket_send_operation_cancellable{ *this, buffer, byteCount, std::move(ct) }; +#elif CPPCORO_OS_LINUX + return socket_send_operation_cancellable{ *this, buffer, byteCount, m_mq, std::move(ct) }; +#endif } cppcoro::net::socket_recv_operation cppcoro::net::socket::recv(void* buffer, std::size_t byteCount) noexcept { +#if CPPCORO_OS_WINNT return socket_recv_operation{ *this, buffer, byteCount }; +#elif CPPCORO_OS_LINUX + return socket_recv_operation{ *this, buffer, byteCount, m_mq }; +#endif } cppcoro::net::socket_recv_operation_cancellable cppcoro::net::socket::recv(void* buffer, std::size_t byteCount, cancellation_token ct) noexcept { +#if CPPCORO_OS_WINNT return socket_recv_operation_cancellable{ *this, buffer, byteCount, std::move(ct) }; +#elif CPPCORO_OS_LINUX + return socket_recv_operation_cancellable{ *this, buffer, byteCount, m_mq, std::move(ct) }; +#endif } cppcoro::net::socket_recv_from_operation cppcoro::net::socket::recv_from(void* buffer, std::size_t byteCount) noexcept { +#if CPPCORO_OS_WINNT return socket_recv_from_operation{ *this, buffer, byteCount }; +#elif CPPCORO_OS_LINUX + return socket_recv_from_operation{ *this, buffer, byteCount, m_mq }; +#endif } cppcoro::net::socket_recv_from_operation_cancellable cppcoro::net::socket::recv_from(void* buffer, std::size_t byteCount, cancellation_token ct) noexcept { +#if CPPCORO_OS_WINNT return socket_recv_from_operation_cancellable{ *this, buffer, byteCount, std::move(ct) }; +#elif CPPCORO_OS_LINUX + return socket_recv_from_operation_cancellable{ *this, buffer, byteCount, m_mq, std::move(ct) }; +#endif } cppcoro::net::socket_send_to_operation cppcoro::net::socket::send_to(const ip_endpoint& destination, const void* buffer, std::size_t byteCount) noexcept { +#if CPPCORO_OS_WINNT return socket_send_to_operation{ *this, destination, buffer, byteCount }; +#elif CPPCORO_OS_LINUX + return socket_send_to_operation{ *this, destination, buffer, byteCount, m_mq }; +#endif } cppcoro::net::socket_send_to_operation_cancellable cppcoro::net::socket::send_to(const ip_endpoint& destination, const void* buffer, std::size_t byteCount, cancellation_token ct) noexcept { +#if CPPCORO_OS_WINNT return socket_send_to_operation_cancellable{ *this, destination, buffer, byteCount, std::move(ct) }; +#elif CPPCORO_OS_LINUX + return socket_send_to_operation_cancellable{ *this, destination, buffer, byteCount, m_mq, std::move(ct) }; +#endif } +#if CPPCORO_OS_WINNT void cppcoro::net::socket::close_send() { int result = ::shutdown(m_handle, SD_SEND); @@ -491,5 +714,39 @@ cppcoro::net::socket::socket( , m_skipCompletionOnSuccess(skipCompletionOnSuccess) { } +#elif CPPCORO_OS_LINUX +void cppcoro::net::socket::close_send() +{ + int result = ::shutdown(m_handle, SHUT_WR); + if (result == SOCKET_ERROR) + { + int errorCode = errno; + throw std::system_error( + errorCode, + std::system_category(), + "failed to close socket send stream: shutdown(SD_SEND)"); + } +} + +void cppcoro::net::socket::close_recv() +{ + int result = ::shutdown(m_handle, SHUT_RD); + if (result == SOCKET_ERROR) + { + int errorCode = errno; + throw std::system_error( + errorCode, + std::system_category(), + "failed to close socket receive stream: shutdown(SD_RECEIVE)"); + } +} + +cppcoro::net::socket::socket( + cppcoro::detail::linux::fd_t handle, + cppcoro::detail::linux::message_queue* mq) noexcept + : m_handle(handle) + , m_mq(mq) +{ +} #endif diff --git a/lib/socket_accept_operation.cpp b/lib/socket_accept_operation.cpp index 6fb01747..4d54c69e 100644 --- a/lib/socket_accept_operation.cpp +++ b/lib/socket_accept_operation.cpp @@ -125,5 +125,59 @@ void cppcoro::net::socket_accept_operation_impl::get_result( } } } +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include +bool cppcoro::net::socket_accept_operation_impl::try_start( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + static_assert( + (sizeof(m_addressBuffer) / 2) >= (16 + sizeof(sockaddr_in)) && + (sizeof(m_addressBuffer) / 2) >= (16 + sizeof(sockaddr_in6)), + "AcceptEx requires address buffer to be at least 16 bytes more than largest address."); + + operation.m_completeFunc = [=]() { + socklen_t len = sizeof(m_addressBuffer) / 2; + int res = accept(m_listeningSocket.native_handle(), reinterpret_cast(m_addressBuffer), &len); + operation.m_mq->remove_fd_watch(m_listeningSocket.native_handle()); + return res; + }; + operation.m_mq->add_fd_watch(m_listeningSocket.native_handle(), reinterpret_cast(&operation), EPOLLIN); + return true; +} + +void cppcoro::net::socket_accept_operation_impl::cancel( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_mq->remove_fd_watch(m_listeningSocket.native_handle()); +} + +void cppcoro::net::socket_accept_operation_impl::get_result( + cppcoro::detail::linux_async_operation_base& operation) +{ + if (operation.m_res < 0) + { + throw std::system_error{ + static_cast(-operation.m_res), + std::system_category(), + "Accepting a connection failed: accept" + }; + } + + m_acceptingSocket = socket(operation.m_res, m_acceptingSocket.m_mq); + sockaddr* remoteSockaddr = reinterpret_cast(m_addressBuffer); + sockaddr* localSockaddr = reinterpret_cast(m_addressBuffer + sizeof(m_addressBuffer)/2); + + socklen_t len = sizeof(m_addressBuffer) / 2; + getsockname(m_acceptingSocket.native_handle(), localSockaddr, &len); + + m_acceptingSocket.m_localEndPoint = + detail::sockaddr_to_ip_endpoint(*localSockaddr); + + m_acceptingSocket.m_remoteEndPoint = + detail::sockaddr_to_ip_endpoint(*remoteSockaddr); +} #endif diff --git a/lib/socket_connect_operation.cpp b/lib/socket_connect_operation.cpp index 6e9efcf2..ae8858db 100644 --- a/lib/socket_connect_operation.cpp +++ b/lib/socket_connect_operation.cpp @@ -174,5 +174,99 @@ void cppcoro::net::socket_connect_operation_impl::get_result( } } } +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include +bool cppcoro::net::socket_connect_operation_impl::try_start( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + + sockaddr_storage remoteSockaddrStorage {0}; + const socklen_t sockaddrNameLength = cppcoro::net::detail::ip_endpoint_to_sockaddr( + m_remoteEndPoint, + std::ref(remoteSockaddrStorage)); + + int res = connect(m_socket.native_handle(), reinterpret_cast(&remoteSockaddrStorage), sockaddrNameLength); + if (res < 0 && errno != EINPROGRESS){ + operation.m_res = -errno; + return false; + } + operation.m_completeFunc = [=]() { + int res = connect(m_socket.native_handle(), reinterpret_cast(&remoteSockaddrStorage), sockaddrNameLength); + operation.m_mq->remove_fd_watch(m_socket.native_handle()); + return res; + }; + operation.m_mq->add_fd_watch(m_socket.native_handle(), reinterpret_cast(&operation), EPOLLOUT); + return true; +} + +void cppcoro::net::socket_connect_operation_impl::cancel( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_mq->remove_fd_watch(m_socket.native_handle()); +} + +void cppcoro::net::socket_connect_operation_impl::get_result( + cppcoro::detail::linux_async_operation_base& operation) +{ + if (operation.m_res < 0) + { + if (operation.m_res == -cppcoro::detail::error_operation_cancelled) + { + throw operation_cancelled{}; + } + + throw std::system_error{ + static_cast(-operation.m_res), + std::system_category(), + "Connect operation failed: connect" + }; + } + + { + sockaddr_storage localSockaddr; + socklen_t nameLength = sizeof(localSockaddr); + const int result = ::getsockname( + m_socket.native_handle(), + reinterpret_cast(&localSockaddr), + &nameLength); + if (result == 0) + { + m_socket.m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint( + *reinterpret_cast(&localSockaddr)); + } + else + { + // Failed to get the updated local-end-point + // Just leave m_localEndPoint set to whatever bind() left it as. + // + // TODO: Should we be throwing an exception here instead? + } + } + + { + sockaddr_storage remoteSockaddr; + socklen_t nameLength = sizeof(remoteSockaddr); + const int result = ::getpeername( + m_socket.native_handle(), + reinterpret_cast(&remoteSockaddr), + &nameLength); + if (result == 0) + { + m_socket.m_remoteEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint( + *reinterpret_cast(&remoteSockaddr)); + } + else + { + // Failed to get the actual remote end-point so just fall back to + // remembering the actual end-point that was passed to connect(). + // + // TODO: Should we be throwing an exception here instead? + m_socket.m_remoteEndPoint = m_remoteEndPoint; + } + } +} #endif diff --git a/lib/socket_disconnect_operation.cpp b/lib/socket_disconnect_operation.cpp index 26d12fa6..06bc6458 100644 --- a/lib/socket_disconnect_operation.cpp +++ b/lib/socket_disconnect_operation.cpp @@ -103,5 +103,45 @@ void cppcoro::net::socket_disconnect_operation_impl::get_result( }; } } +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include +bool cppcoro::net::socket_disconnect_operation_impl::try_start( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_completeFunc = [=]() { + operation.m_mq->remove_fd_watch(m_socket.native_handle()); + int res = m_socket.close(); + return res; + }; + operation.m_mq->add_fd_watch(m_socket.native_handle(), reinterpret_cast(&operation), EPOLLOUT); + return true; +} + +void cppcoro::net::socket_disconnect_operation_impl::cancel( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_mq->remove_fd_watch(m_socket.native_handle()); +} + +void cppcoro::net::socket_disconnect_operation_impl::get_result( + cppcoro::detail::linux_async_operation_base& operation) +{ + if (operation.m_res < 0) + { + if (operation.m_res == -cppcoro::detail::error_operation_cancelled) + { + throw operation_cancelled{}; + } + + throw std::system_error{ + static_cast(-operation.m_res), + std::system_category(), + "Disconnect operation failed: disconnect" + }; + } +} #endif diff --git a/lib/socket_helpers.cpp b/lib/socket_helpers.cpp index d46e7764..7bc890e3 100644 --- a/lib/socket_helpers.cpp +++ b/lib/socket_helpers.cpp @@ -7,14 +7,21 @@ #include #include -#if CPPCORO_OS_WINNT #include #include - -#include -#include -#include -#include +#if CPPCORO_OS_WINNT +# include +# include +# include +# include +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include +#define SOCKADDR_IN sockaddr_in +#define SOCKADDR_IN6 sockaddr_in6 +#endif cppcoro::net::ip_endpoint @@ -41,7 +48,7 @@ cppcoro::net::detail::sockaddr_to_ip_endpoint(const sockaddr& address) noexcept std::memcpy(&ipv6Address, &address, sizeof(ipv6Address)); return ipv6_endpoint{ - ipv6_address{ ipv6Address.sin6_addr.u.Byte }, + ipv6_address{ ipv6Address.sin6_addr.s6_addr }, ntohs(ipv6Address.sin6_port) }; } @@ -69,17 +76,14 @@ int cppcoro::net::detail::ip_endpoint_to_sockaddr( { const auto& ipv6EndPoint = endPoint.to_ipv6(); - SOCKADDR_IN6 ipv6Address; + SOCKADDR_IN6 ipv6Address {0}; ipv6Address.sin6_family = AF_INET6; std::memcpy(&ipv6Address.sin6_addr, ipv6EndPoint.address().bytes(), 16); ipv6Address.sin6_port = htons(ipv6EndPoint.port()); ipv6Address.sin6_flowinfo = 0; - ipv6Address.sin6_scope_struct = SCOPEID_UNSPECIFIED_INIT; std::memcpy(&address.get(), &ipv6Address, sizeof(ipv6Address)); return sizeof(SOCKADDR_IN6); } } - -#endif // CPPCORO_OS_WINNT diff --git a/lib/socket_helpers.hpp b/lib/socket_helpers.hpp index 2083f3a0..2decfd33 100644 --- a/lib/socket_helpers.hpp +++ b/lib/socket_helpers.hpp @@ -5,13 +5,16 @@ #ifndef CPPCORO_PRIVATE_SOCKET_HELPERS_HPP_INCLUDED #define CPPCORO_PRIVATE_SOCKET_HELPERS_HPP_INCLUDED +#include #include #if CPPCORO_OS_WINNT # include +#elif CPPCORO_OS_LINUX +# include +#endif struct sockaddr; struct sockaddr_storage; -#endif namespace cppcoro { @@ -21,7 +24,6 @@ namespace cppcoro namespace detail { -#if CPPCORO_OS_WINNT /// Convert a sockaddr to an IP endpoint. ip_endpoint sockaddr_to_ip_endpoint(const sockaddr& address) noexcept; @@ -38,8 +40,6 @@ namespace cppcoro int ip_endpoint_to_sockaddr( const ip_endpoint& endPoint, std::reference_wrapper address) noexcept; - -#endif } } } diff --git a/lib/socket_recv_from_operation.cpp b/lib/socket_recv_from_operation.cpp index 7865f6cb..f2f798f8 100644 --- a/lib/socket_recv_from_operation.cpp +++ b/lib/socket_recv_from_operation.cpp @@ -6,9 +6,9 @@ #include #include -#if CPPCORO_OS_WINNT # include "socket_helpers.hpp" +#if CPPCORO_OS_WINNT # include # include # include @@ -93,4 +93,63 @@ cppcoro::net::socket_recv_from_operation_impl::get_result( *reinterpret_cast(&m_sourceSockaddrStorage))); } +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include +bool cppcoro::net::socket_recv_from_operation_impl::try_start( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + static_assert( + sizeof(m_sourceSockaddrStorage) >= sizeof(sockaddr_in) && + sizeof(m_sourceSockaddrStorage) >= sizeof(sockaddr_in6)); + static_assert( + sockaddrStorageAlignment >= alignof(sockaddr_in) && + sockaddrStorageAlignment >= alignof(sockaddr_in6)); + m_sourceSockaddrLength = sizeof(m_sourceSockaddrStorage); + + operation.m_completeFunc = [=]() { + int res = recvfrom( + m_socket.native_handle(), m_buffer, m_byteCount, MSG_TRUNC, + reinterpret_cast(&m_sourceSockaddrStorage), + reinterpret_cast(&m_sourceSockaddrLength) + ); + operation.m_mq->remove_fd_watch(m_socket.native_handle()); + return res; + }; + operation.m_mq->add_fd_watch(m_socket.native_handle(), reinterpret_cast(&operation), EPOLLIN); + return true; +} + +void cppcoro::net::socket_recv_from_operation_impl::cancel( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_mq->remove_fd_watch(m_socket.native_handle()); +} + +std::tuple +cppcoro::net::socket_recv_from_operation_impl::get_result( + cppcoro::detail::linux_async_operation_base& operation) +{ + if (operation.m_res < 0) + { + throw std::system_error( + static_cast(-operation.m_res), + std::system_category(), + "Error receiving message on socket: recvfrom"); + } + if (operation.m_res > m_byteCount) { + throw std::system_error( + ENOMEM, + std::system_category(), + "Error receiving message on socket: recvfrom - receiving buffer was too small"); + + } + + return std::make_tuple( + static_cast(operation.m_res), + detail::sockaddr_to_ip_endpoint( + *reinterpret_cast(&m_sourceSockaddrStorage))); +} #endif diff --git a/lib/socket_recv_operation.cpp b/lib/socket_recv_operation.cpp index 932d1923..f1127379 100644 --- a/lib/socket_recv_operation.cpp +++ b/lib/socket_recv_operation.cpp @@ -63,4 +63,29 @@ void cppcoro::net::socket_recv_operation_impl::cancel( operation.get_overlapped()); } +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include + +bool cppcoro::net::socket_recv_operation_impl::try_start( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_completeFunc = [=]() { + int res = recv(m_socket.native_handle(), m_buffer, m_byteCount, 0); + operation.m_mq->remove_fd_watch(m_socket.native_handle()); + return res; + }; + operation.m_mq->add_fd_watch(m_socket.native_handle(), reinterpret_cast(&operation), EPOLLIN); + return true; +} + + +void cppcoro::net::socket_recv_operation_impl::cancel( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_mq->remove_fd_watch(m_socket.native_handle()); +} + #endif diff --git a/lib/socket_send_operation.cpp b/lib/socket_send_operation.cpp index ef6a9b7c..f642ff64 100644 --- a/lib/socket_send_operation.cpp +++ b/lib/socket_send_operation.cpp @@ -60,5 +60,28 @@ void cppcoro::net::socket_send_operation_impl::cancel( reinterpret_cast(m_socket.native_handle()), operation.get_overlapped()); } +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include + +bool cppcoro::net::socket_send_operation_impl::try_start( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_completeFunc = [=]() { + int res = send(m_socket.native_handle(), m_buffer, m_byteCount, 0); + operation.m_mq->remove_fd_watch(m_socket.native_handle()); + return res; + }; + operation.m_mq->add_fd_watch(m_socket.native_handle(), reinterpret_cast(&operation), EPOLLOUT); + return true; +} + +void cppcoro::net::socket_send_operation_impl::cancel( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_mq->remove_fd_watch(m_socket.native_handle()); +} #endif diff --git a/lib/socket_send_to_operation.cpp b/lib/socket_send_to_operation.cpp index 07a875fe..b343b064 100644 --- a/lib/socket_send_to_operation.cpp +++ b/lib/socket_send_to_operation.cpp @@ -6,9 +6,9 @@ #include #include -#if CPPCORO_OS_WINNT -# include "socket_helpers.hpp" +#include "socket_helpers.hpp" +#if CPPCORO_OS_WINNT # include # include # include @@ -69,4 +69,34 @@ void cppcoro::net::socket_send_to_operation_impl::cancel( operation.get_overlapped()); } +#elif CPPCORO_OS_LINUX +# include +# include +# include +# include + +bool cppcoro::net::socket_send_to_operation_impl::try_start( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + sockaddr_storage destinationAddress = {0}; + const socklen_t destinationLength = detail::ip_endpoint_to_sockaddr( + m_destination, std::ref(destinationAddress)); + operation.m_completeFunc = [=]() { + int res = sendto( + m_socket.native_handle(), m_buffer, m_byteCount, 0, + reinterpret_cast(&destinationAddress), + destinationLength + ); + operation.m_mq->remove_fd_watch(m_socket.native_handle()); + return res; + }; + operation.m_mq->add_fd_watch(m_socket.native_handle(), reinterpret_cast(&operation), EPOLLOUT); + return true; +} + +void cppcoro::net::socket_send_to_operation_impl::cancel( + cppcoro::detail::linux_async_operation_base& operation) noexcept +{ + operation.m_mq->remove_fd_watch(m_socket.native_handle()); +} #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 16593c2e..528292bf 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -51,7 +51,7 @@ else() scheduling_operator_tests.cpp io_service_tests.cpp file_tests.cpp - # socket_tests.cpp + socket_tests.cpp ) endif() # let more time for some tests diff --git a/test/socket_tests.cpp b/test/socket_tests.cpp index 9d17ac24..961a7f93 100644 --- a/test/socket_tests.cpp +++ b/test/socket_tests.cpp @@ -140,14 +140,14 @@ TEST_CASE("send/recv TCP/IPv4") co_await connectingSocket.connect(listeningSocket.local_endpoint()); - auto receive = [&]() -> task + auto receive = [](socket sock) -> task { std::uint8_t buffer[100]; std::uint64_t totalBytesReceived = 0; std::size_t bytesReceived; do { - bytesReceived = co_await connectingSocket.recv(buffer, sizeof(buffer)); + bytesReceived = co_await sock.recv(buffer, sizeof(buffer)); for (std::size_t i = 0; i < bytesReceived; ++i) { std::uint64_t byteIndex = totalBytesReceived + i; @@ -163,7 +163,7 @@ TEST_CASE("send/recv TCP/IPv4") co_return 0; }; - auto send = [&]() -> task + auto send = [](socket sock) -> task { std::uint8_t buffer[100]; for (std::uint64_t i = 0; i < 1000; i += sizeof(buffer)) @@ -176,16 +176,16 @@ TEST_CASE("send/recv TCP/IPv4") std::size_t bytesSent = 0; do { - bytesSent += co_await connectingSocket.send(buffer + bytesSent, sizeof(buffer) - bytesSent); + bytesSent += co_await sock.send(buffer + bytesSent, sizeof(buffer) - bytesSent); } while (bytesSent < sizeof(buffer)); } - connectingSocket.close_send(); + sock.close_send(); co_return 0; }; - co_await when_all(send(), receive()); + co_await when_all(send(connectingSocket), receive(connectingSocket)); co_await connectingSocket.disconnect(); @@ -283,14 +283,14 @@ TEST_CASE("send/recv TCP/IPv4 many connections") co_await connectingSocket.connect(listeningSocket.local_endpoint()); - auto receive = [&]() -> task<> + auto receive = [](socket sock) -> task<> { std::uint8_t buffer[100]; std::uint64_t totalBytesReceived = 0; std::size_t bytesReceived; do { - bytesReceived = co_await connectingSocket.recv(buffer, sizeof(buffer)); + bytesReceived = co_await sock.recv(buffer, sizeof(buffer)); for (std::size_t i = 0; i < bytesReceived; ++i) { std::uint64_t byteIndex = totalBytesReceived + i; @@ -304,7 +304,7 @@ TEST_CASE("send/recv TCP/IPv4 many connections") CHECK(totalBytesReceived == 1000); }; - auto send = [&]() -> task<> + auto send = [](socket sock) -> task<> { std::uint8_t buffer[100]; for (std::uint64_t i = 0; i < 1000; i += sizeof(buffer)) @@ -317,14 +317,14 @@ TEST_CASE("send/recv TCP/IPv4 many connections") std::size_t bytesSent = 0; do { - bytesSent += co_await connectingSocket.send(buffer + bytesSent, sizeof(buffer) - bytesSent); + bytesSent += co_await sock.send(buffer + bytesSent, sizeof(buffer) - bytesSent); } while (bytesSent < sizeof(buffer)); } - connectingSocket.close_send(); + sock.close_send(); }; - co_await when_all(send(), receive()); + co_await when_all(send(connectingSocket), receive(connectingSocket)); co_await connectingSocket.disconnect(); };