diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..2a60a0a --- /dev/null +++ b/.clang-format @@ -0,0 +1,22 @@ +BasedOnStyle: LLVM +BreakBeforeBraces: Attach + +ColumnLimit: 120 # Match GitHub UI + +UseTab: Always +TabWidth: 4 +IndentWidth: 4 +AccessModifierOffset: -4 +ContinuationIndentWidth: 4 +NamespaceIndentation: All +IndentCaseLabels: false + +PointerAlignment: Left +AlwaysBreakTemplateDeclarations: Yes +SpaceAfterTemplateKeyword: false +AllowShortCaseLabelsOnASingleLine: true +AllowShortIfStatementsOnASingleLine: WithoutElse +AllowShortBlocksOnASingleLine: Always + +FixNamespaceComments: true +ReflowComments: false diff --git a/.github/actions/badge/action.yml b/.github/actions/badge/action.yml new file mode 100644 index 0000000..5bb03d4 --- /dev/null +++ b/.github/actions/badge/action.yml @@ -0,0 +1,27 @@ +name: Regular badging sequence +description: Publishes a badge based on the job status +inputs: + category: + description: The subfolder where to group the badges + required: true + badges: + description: A json object of label => status for all badges + required: true + github_token: + description: The token to use to publish the changes + required: false + default: ${{ github.token }} +runs: + using: composite + steps: + - run: | + node ./.github/actions/badge/write-json-object.js ${{ inputs.category }} '${{ inputs.badges }}' + shell: bash + - uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ inputs.github_token }} + publish_branch: badges + publish_dir: ./badges + keep_files: true + user_name: "github-actions[bot]" + user_email: "github-actions[bot]@users.noreply.github.com" \ No newline at end of file diff --git a/.github/actions/badge/write-json-object.js b/.github/actions/badge/write-json-object.js new file mode 100644 index 0000000..2303fee --- /dev/null +++ b/.github/actions/badge/write-json-object.js @@ -0,0 +1,17 @@ +const fs = require('fs'); +const category = process.argv[2]; +const status = JSON.parse(process.argv[3]); + +if (!fs.existsSync("./badges")) fs.mkdirSync("./badges"); +if (!fs.existsSync("./badges/" + category)) fs.mkdirSync("./badges/" + category); +for (let e in status) { + const path = "./badges/" + category + "/" + e; + if (!fs.existsSync(path)) fs.mkdirSync(path); + const ok = status[e] == "success"; + fs.writeFileSync(path + "/shields.json", JSON.stringify({ + "schemaVersion": 1, + "label": e, + "message": ok ? "Passing" : "Failing", + "color": ok ? "brightgreen" : "red" + })); +} diff --git a/.github/actions/process-linting-results/action.yml b/.github/actions/process-linting-results/action.yml new file mode 100644 index 0000000..a90bbab --- /dev/null +++ b/.github/actions/process-linting-results/action.yml @@ -0,0 +1,26 @@ +name: Process Linting Results +description: Add a comment to a pull request with when `git diff` present and save the changes as an artifact so they can be applied manually +inputs: + linter_name: + description: The name of the tool to credit in the comment + required: true +runs: + using: "composite" + steps: + - run: git add --update + shell: bash + - id: stage + #continue-on-error: true + uses: Thalhammer/patch-generator-action@v2 + + # Unfortunately the previous action reports a failure so nothing else can run + # partially a limitation on composite actions since `continue-on-error` is not + # yet supported + - if: steps.stage.outputs.result == 'dirty' + uses: actions-ecosystem/action-create-comment@v1 + with: + github_token: ${{ github.token }} + body: | + Hello, @${{ github.actor }}! `${{ inputs.linter_name }}` had some concerns :scream: + - run: exit $(git status -uno -s | wc -l) + shell: bash \ No newline at end of file diff --git a/.github/logo.svg b/.github/logo.svg new file mode 100644 index 0000000..be78f4e --- /dev/null +++ b/.github/logo.svg @@ -0,0 +1,80 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + diff --git a/.github/scripts/add-newline-if-missing.sh b/.github/scripts/add-newline-if-missing.sh new file mode 100755 index 0000000..9f751ef --- /dev/null +++ b/.github/scripts/add-newline-if-missing.sh @@ -0,0 +1,7 @@ +#!/bin/bash +if [[ -f "$1" && -s "$1" ]]; then + if [[ -n "$(tail -c 1 "$1")" ]]; then + echo "Fixed missing newline in file $1" + sed -i -e '$a\' $1 + fi +fi \ No newline at end of file diff --git a/.github/workflows/compiler-support.yml b/.github/workflows/compiler-support.yml new file mode 100644 index 0000000..50dafd3 --- /dev/null +++ b/.github/workflows/compiler-support.yml @@ -0,0 +1,128 @@ +name: Compiler Compatibility CI + +on: + push: + branches: [master] + pull_request: + +jobs: + build: + strategy: + fail-fast: false + matrix: + compiler: + # GCC 13 on MacOS seems to be generally broken (https://github.com/actions/runner-images/issues/9997) and therefore disabled + - { tag: "ubuntu-2204_clang-13", name: "Ubuntu 22.04 Clang 13", cxx: "/usr/bin/clang++-13", cc: "/usr/bin/clang-13", runs-on: "ubuntu-22.04" } + - { tag: "ubuntu-2204_clang-14", name: "Ubuntu 22.04 Clang 14", cxx: "/usr/bin/clang++-14", cc: "/usr/bin/clang-14", runs-on: "ubuntu-22.04" } + - { tag: "ubuntu-2204_clang-15", name: "Ubuntu 22.04 Clang 15", cxx: "/usr/bin/clang++-15", cc: "/usr/bin/clang-15", runs-on: "ubuntu-22.04" } + - { tag: "ubuntu-2204_gcc-10", name: "Ubuntu 22.04 G++ 10", cxx: "/usr/bin/g++-10", cc: "/usr/bin/gcc-10", runs-on: "ubuntu-22.04" } + - { tag: "ubuntu-2204_gcc-11", name: "Ubuntu 22.04 G++ 11", cxx: "/usr/bin/g++-11", cc: "/usr/bin/gcc-11", runs-on: "ubuntu-22.04" } + - { tag: "ubuntu-2004_clang-12", name: "Ubuntu 20.04 Clang 12", cxx: "/usr/bin/clang++-12", cc: "/usr/bin/clang-12", runs-on: "ubuntu-20.04" } + - { tag: "ubuntu-2004_clang-11", name: "Ubuntu 20.04 Clang 11", cxx: "/usr/bin/clang++-11", cc: "/usr/bin/clang-11", runs-on: "ubuntu-20.04" } + - { tag: "ubuntu-2004_clang-10", name: "Ubuntu 20.04 Clang 10", cxx: "/usr/bin/clang++-10", cc: "/usr/bin/clang-10", runs-on: "ubuntu-20.04" } + - { tag: "ubuntu-2004_gcc-10", name: "Ubuntu 20.04 G++ 10", cxx: "/usr/bin/g++-10", cc: "/usr/bin/gcc-10", runs-on: "ubuntu-20.04" } + #- { tag: "windows-2022_msvc17", name: "Windows Server 2022 MSVC 17", cxx: "", cc: "", runs-on: "windows-2022" } + #- { tag: "windows-2019_msvc16", name: "Windows Server 2019 MSVC 16", cxx: "", cc: "", runs-on: "windows-2019" } + - { tag: "macos-12_gcc-12", name: "MacOS 12 G++ 12", cxx: "g++-12", cc: "gcc-12", runs-on: "macos-12" } + #- { tag: "macos-12_gcc-13", name: "MacOS 12 G++ 13", cxx: "g++-13", cc: "gcc-13", runs-on: "macos-12" } + - { tag: "macos-12_gcc-14", name: "MacOS 12 G++ 14", cxx: "g++-14", cc: "gcc-14", runs-on: "macos-12" } + - { tag: "macos-12_clang-15", name: "MacOS 12 Clang 15", cxx: "/usr/local/opt/llvm@15/bin/clang++", cc: "/usr/local/opt/llvm@15/bin/clang", runs-on: "macos-12" } + - { tag: "macos-13_gcc-12", name: "MacOS 13 G++ 12", cxx: "g++-12", cc: "gcc-12", runs-on: "macos-13" } + #- { tag: "macos-13_gcc-13", name: "MacOS 13 G++ 13", cxx: "g++-13", cc: "gcc-13", runs-on: "macos-13" } + - { tag: "macos-13_gcc-14", name: "MacOS 13 G++ 14", cxx: "g++-14", cc: "gcc-14", runs-on: "macos-13" } + - { tag: "macos-13_clang-15", name: "MacOS 13 Clang 15", cxx: "/usr/local/opt/llvm@15/bin/clang++", cc: "/usr/local/opt/llvm@15/bin/clang", runs-on: "macos-13" } + - { tag: "macos-14_gcc-12", name: "MacOS 14 G++ 12", cxx: "g++-12", cc: "gcc-12", runs-on: "macos-14" } + #- { tag: "macos-14_gcc-13", name: "MacOS 14 G++ 13", cxx: "g++-13", cc: "gcc-13", runs-on: "macos-14" } + - { tag: "macos-14_gcc-14", name: "MacOS 14 G++ 14", cxx: "g++-14", cc: "gcc-14", runs-on: "macos-14" } + - { tag: "macos-14_clang-15", name: "MacOS 14 Clang 15", cxx: "/opt/homebrew/Cellar/llvm@15/15.0.7/bin/clang++", cc: "/opt/homebrew/Cellar/llvm@15/15.0.7/bin/clang", runs-on: "macos-14" } + runs-on: ${{ matrix.compiler.runs-on }} + name: Compiler ${{ matrix.compiler.name }} + env: + CXX: ${{ matrix.compiler.cxx }} + CC: ${{ matrix.compiler.cc }} + outputs: + # Because github wants us to suffer we need to list out every output instead of using a matrix statement or some kind of dynamic setting + ubuntu-2204_clang-13: ${{ steps.status.outputs.ubuntu-2204_clang-13 }} + ubuntu-2204_clang-14: ${{ steps.status.outputs.ubuntu-2204_clang-14 }} + ubuntu-2204_clang-15: ${{ steps.status.outputs.ubuntu-2204_clang-15 }} + ubuntu-2204_gcc-10: ${{ steps.status.outputs.ubuntu-2204_gcc-10 }} + ubuntu-2204_gcc-11: ${{ steps.status.outputs.ubuntu-2204_gcc-11 }} + ubuntu-2004_clang-12: ${{ steps.status.outputs.ubuntu-2004_clang-12 }} + ubuntu-2004_clang-11: ${{ steps.status.outputs.ubuntu-2004_clang-11 }} + ubuntu-2004_clang-10: ${{ steps.status.outputs.ubuntu-2004_clang-10 }} + ubuntu-2004_gcc-10: ${{ steps.status.outputs.ubuntu-2004_gcc-10 }} + windows-2022_msvc17: ${{ steps.status.outputs.windows-2022_msvc17 }} + windows-2019_msvc16: ${{ steps.status.outputs.windows-2019_msvc16 }} + macos-12_gcc-12: ${{ steps.status.outputs.macos-12_gcc-12 }} + macos-12_gcc-13: ${{ steps.status.outputs.macos-12_gcc-13 }} + macos-12_gcc-14: ${{ steps.status.outputs.macos-12_gcc-14 }} + macos-12_clang-15: ${{ steps.status.outputs.macos-12_clang-15 }} + macos-13_gcc-12: ${{ steps.status.outputs.macos-13_gcc-12 }} + macos-13_gcc-13: ${{ steps.status.outputs.macos-13_gcc-13 }} + macos-13_gcc-14: ${{ steps.status.outputs.macos-13_gcc-14 }} + macos-13_clang-15: ${{ steps.status.outputs.macos-13_clang-15 }} + macos-14_gcc-12: ${{ steps.status.outputs.macos-14_gcc-12 }} + macos-14_gcc-13: ${{ steps.status.outputs.macos-14_gcc-13 }} + macos-14_gcc-14: ${{ steps.status.outputs.macos-14_gcc-14 }} + macos-14_clang-15: ${{ steps.status.outputs.macos-14_clang-15 }} + defaults: + run: + shell: bash -l {0} + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Set LDFLAGS=-Wl,-ld_classic + if: contains(matrix.compiler.tag, 'macos-13') || contains(matrix.compiler.tag, 'macos-14') + run: echo "LDFLAGS=-Wl,-ld_classic" >> $GITHUB_ENV + # Ubuntu 22.04 container has libstdc++13 installed which is incompatible with clang < 15 in C++20 + - name: Uninstall libstdc++-13-dev + if: (matrix.compiler.tag == 'ubuntu-2204_clang-14') || (matrix.compiler.tag == 'ubuntu-2204_clang-13') + run: | + sudo apt autoremove libstdc++-13-dev gcc-13 libgcc-13-dev + sudo apt install libstdc++-12-dev gcc-12 libgcc-12-dev + - name: Install liburing + # Ubuntu 22.04 can just pull liburing from apt + if: contains(matrix.compiler.tag, 'ubuntu-2204') + run: sudo apt install liburing-dev + - name: Install liburing + # Ubuntu 20.04 does not have liburing in apt, pull in deb files from 22.04 instead + if: contains(matrix.compiler.tag, 'ubuntu-2004') + run: | + wget -O /tmp/liburing2_2.1-2build1_amd64.deb http://mirrors.kernel.org/ubuntu/pool/main/libu/liburing/liburing2_2.1-2build1_amd64.deb + wget -O /tmp/liburing-dev_2.1-2build1_amd64.deb http://mirrors.kernel.org/ubuntu/pool/main/libu/liburing/liburing-dev_2.1-2build1_amd64.deb + sudo dpkg -i /tmp/liburing-dev_2.1-2build1_amd64.deb /tmp/liburing2_2.1-2build1_amd64.deb + - name: Configure + if: contains(matrix.compiler.tag, 'ubuntu') + run: cmake -S. -Bbuild -DASYNCPP_BUILD_TEST=ON -DASYNCPP_WITH_ASAN=ON -DASYNCPP_WITH_TSAN=OFF + - name: Configure + if: contains(matrix.compiler.tag, 'ubuntu') != true + run: cmake -S. -Bbuild -DASYNCPP_BUILD_TEST=ON -DASYNCPP_WITH_ASAN=OFF -DASYNCPP_WITH_TSAN=OFF + - name: Build + run: cmake --build build --config Debug + - name: Test + working-directory: ${{ github.workspace }}/build + if: contains(matrix.compiler.tag, 'windows') != true + run: ./asyncpp_io-test + - name: Test + if: contains(matrix.compiler.tag, 'windows') + working-directory: ${{ github.workspace }}/build + run: Debug/asyncpp_io-test.exe + - name: Update Result + id: status + if: ${{ always() }} + run: echo "${{ matrix.compiler.tag }}=${{ job.status }}" >> $GITHUB_OUTPUT + + badge-upload: + if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' && always() }} + needs: [build] + runs-on: ubuntu-20.04 + name: Publish badges + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Publish Badges + uses: ./.github/actions/badge + with: + category: compiler + badges: ${{ toJson(needs.build.outputs) }} + \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..ec85483 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,47 @@ +name: Lint CI + +on: + push: + branches: [master] + pull_request: + +jobs: + clang-format: + runs-on: ubuntu-22.04 + steps: + - run: | + echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-17 main" | sudo tee /etc/apt/sources.list.d/llvm.list + echo "deb-src http://apt.llvm.org/jammy/ llvm-toolchain-jammy-17 main" | sudo tee -a /etc/apt/sources.list.d/llvm.list + wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc + sudo apt update && sudo apt-get install clang-format-17 + shopt -s globstar + - uses: actions/checkout@v3 + - run: find \( -name "*.h" -or -name "*.cpp" \) -exec clang-format-17 -i {} \; + - run: find \( -name "*.h" -or -name "*.cpp" \) -exec ./.github/scripts/add-newline-if-missing.sh {} \; + - uses: ./.github/actions/process-linting-results + with: + linter_name: clang-format + + cmake-format: + runs-on: ubuntu-20.04 + steps: + - uses: actions/setup-python@v4.3.0 + with: + python-version: "3.x" + - run: | + pip install cmakelang + shopt -s globstar + - uses: actions/checkout@v3 + - run: find \( -name "CMakeLists.txt" -or -name "*.cmake" \) -exec cmake-format -i {} \; + - uses: ./.github/actions/process-linting-results + with: + linter_name: cmake-format + + line-ending: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v3 + - run: git add --renormalize . + - uses: ./.github/actions/process-linting-results + with: + linter_name: line-ending \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3457d1b --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/build +/.vscode +# MSVC +/.vs +/out/Build \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..7b5a12d --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,107 @@ +cmake_minimum_required(VERSION 3.15) + +project(AsyncppIO) +find_package(Threads REQUIRED) + +if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + set(ASYNCPP_ENABLE_URING_DEFAULT ON) +else() + set(ASYNCPP_ENABLE_URING_DEFAULT OFF) +endif() + +option(ASYNCPP_BUILD_TEST "Enable test builds" ON) +option(ASYNCPP_WITH_ASAN "Enable asan for test builds" ON) +option(ASYNCPP_ENABLE_URING "Enable support for linux uring" + ${ASYNCPP_ENABLE_URING_DEFAULT}) +include(cmake/import_openssl.cmake) +include(cmake/import_asyncpp.cmake) + +if(ASYNCPP_ENABLE_URING) + find_package(PkgConfig REQUIRED) + if(HUNTER_ENABLED) + # Workaround hunter hideing system libs + set(HUNTER_LIBPATH $ENV{PKG_CONFIG_LIBDIR}) + unset(ENV{PKG_CONFIG_LIBDIR}) + pkg_search_module(URING REQUIRED NO_CMAKE_PATH liburing uring) + set(ENV{PKG_CONFIG_LIBDIR} ${HUNTER_LIBPATH}) + else() + pkg_search_module(URING REQUIRED NO_CMAKE_PATH liburing uring) + endif() +endif() + +add_library( + asyncpp_io + ${CMAKE_CURRENT_SOURCE_DIR}/src/address.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/dns.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/file.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine_select.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine_uring.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/io_service.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/socket.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tls.cpp) +target_link_libraries(asyncpp_io PUBLIC asyncpp OpenSSL::SSL Threads::Threads) +target_include_directories(asyncpp_io + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_compile_features(asyncpp_io PUBLIC cxx_std_20) + +if(ASYNCPP_ENABLE_URING) + target_link_libraries(asyncpp_io PRIVATE ${URING_LINK_LIBRARIES}) + target_include_directories(asyncpp_io PRIVATE ${URING_INCLUDE_DIRS}) + target_compile_definitions(asyncpp_io PRIVATE ASYNCPP_ENABLE_URING=1) +endif() + +if(ASYNCPP_WITH_ASAN) + if(MSVC) + target_compile_options(asyncpp_io PRIVATE -fsanitize=address /Zi) + target_compile_definitions(asyncpp_io PRIVATE _DISABLE_VECTOR_ANNOTATION) + target_compile_definitions(asyncpp_io PRIVATE _DISABLE_STRING_ANNOTATION) + target_link_libraries(asyncpp_io PRIVATE libsancov.lib) + else() + target_compile_options(asyncpp_io PRIVATE -fsanitize=address) + target_link_libraries(asyncpp_io PRIVATE asan) + endif() +endif() + +# G++ below 11 needs a flag +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.0") + target_compile_options(asyncpp_io PUBLIC -fcoroutines) + endif() +endif() + +if(ASYNCPP_BUILD_TEST) + include(cmake/import_gtest.cmake) + + add_executable( + asyncpp_io-test + ${CMAKE_CURRENT_SOURCE_DIR}/test/address.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/dns.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/endpoint.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/file.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/network.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/so_compat.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/socket.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/tls.cpp) + target_link_libraries( + asyncpp_io-test PRIVATE asyncpp_io GTest::gtest GTest::gtest_main + Threads::Threads) + + if(ASYNCPP_WITH_ASAN) + message(STATUS "Building with asan enabled") + + if(MSVC) + target_compile_options(asyncpp_io-test PRIVATE -fsanitize=address /Zi) + target_compile_definitions(asyncpp_io-test + PRIVATE _DISABLE_VECTOR_ANNOTATION) + target_compile_definitions(asyncpp_io-test + PRIVATE _DISABLE_STRING_ANNOTATION) + target_link_libraries(asyncpp_io-test PRIVATE libsancov.lib) + else() + target_compile_options(asyncpp_io-test PRIVATE -fsanitize=address) + target_link_libraries(asyncpp_io-test PRIVATE asan) + endif() + endif() + + gtest_discover_tests(asyncpp_io-test) +endif() diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..00c1652 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Dominik Thalhammer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..367fea8 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +# Async++ Network library +This library provides a c++20 coroutine wrapper for networking and IO related functionality. +It is an addition to [async++](https://github.com/asyncpp/asyncpp) which provides general coroutine tasks and support classes. + +This library is developed and tested on Ubuntu linux. There is also experimental support for MacOS. +Windows is currently unsupported. \ No newline at end of file diff --git a/cmake/import_asyncpp.cmake b/cmake/import_asyncpp.cmake new file mode 100644 index 0000000..5588a2b --- /dev/null +++ b/cmake/import_asyncpp.cmake @@ -0,0 +1,10 @@ +if(TARGET asyncpp) + message(STATUS "Using existing asyncpp target.") +else() + message(STATUS "Missing asyncpp, using Fetch to import it.") + + include(FetchContent) + FetchContent_Declare(asyncpp + GIT_REPOSITORY "https://github.com/asyncpp/asyncpp.git") + FetchContent_MakeAvailable(asyncpp) +endif() diff --git a/cmake/import_gtest.cmake b/cmake/import_gtest.cmake new file mode 100644 index 0000000..72ef41a --- /dev/null +++ b/cmake/import_gtest.cmake @@ -0,0 +1,26 @@ +enable_testing() +include(GoogleTest) + +if(TARGET GTest::gtest) + message(STATUS "Using existing GTest::gtest target.") +else() + if(HUNTER_ENABLED) + hunter_add_package(GTest) + find_package(GTest CONFIG REQUIRED) + else() + include(FetchContent) + FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.12.1) + if(WIN32) + set(gtest_force_shared_crt + ON + CACHE BOOL "" FORCE) + set(BUILD_GMOCK + OFF + CACHE BOOL "" FORCE) + endif() + FetchContent_MakeAvailable(googletest) + endif() +endif() diff --git a/cmake/import_openssl.cmake b/cmake/import_openssl.cmake new file mode 100644 index 0000000..6f7b966 --- /dev/null +++ b/cmake/import_openssl.cmake @@ -0,0 +1,13 @@ +if(TARGET OpenSSL::SSL) + message(STATUS "Using existing OpenSSL::SSL target.") +else() + if(HUNTER_ENABLED) + hunter_add_package(OpenSSL) + find_package(OpenSSL REQUIRED) + else() + find_package(OpenSSL) + if(NOT OPENSSL_FOUND) + message(FATAL_ERROR "Could not find OpenSSL and Hunter is disabled") + endif() + endif() +endif() diff --git a/include/asyncpp/io/address.h b/include/asyncpp/io/address.h new file mode 100644 index 0000000..91af18f --- /dev/null +++ b/include/asyncpp/io/address.h @@ -0,0 +1,532 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +struct sockaddr_storage; +struct sockaddr_in; +struct sockaddr_in6; +struct sockaddr_un; +namespace asyncpp::io { + enum class address_type { + ipv4, + ipv6, +#ifndef _WIN32 + uds +#endif + }; + + class ipv4_address { + alignas(uint32_t) std::array m_data{}; + + public: + constexpr ipv4_address() noexcept {} + explicit constexpr ipv4_address(uint32_t nbo_addr, std::endian order = std::endian::big) noexcept + : m_data{static_cast(nbo_addr >> 24), static_cast(nbo_addr >> 16), + static_cast(nbo_addr >> 8), static_cast(nbo_addr >> 0)} { + if (order != std::endian::big) std::reverse(m_data.begin(), m_data.end()); + } + explicit constexpr ipv4_address(std::span data, std::endian order = std::endian::big) noexcept + : m_data{data[0], data[1], data[2], data[3]} { + if (order != std::endian::big) std::reverse(m_data.begin(), m_data.end()); + } + constexpr ipv4_address(uint8_t a, uint8_t b, uint8_t c, uint8_t d, + std::endian order = std::endian::big) noexcept + : m_data{a, b, c, d} { + if (order != std::endian::big) std::reverse(m_data.begin(), m_data.end()); + } + explicit ipv4_address(const sockaddr_storage& addr); + explicit ipv4_address(const sockaddr_in& addr) noexcept; + + constexpr std::span data() const noexcept { return m_data; } + constexpr uint32_t integer(std::endian order = std::endian::big) const noexcept { + if (order == std::endian::big) { + return (std::uint32_t(m_data[0]) << 24) | (std::uint32_t(m_data[1]) << 16) | + (std::uint32_t(m_data[2]) << 8) | (std::uint32_t(m_data[3]) << 0); + } else { + return (std::uint32_t(m_data[0]) << 0) | (std::uint32_t(m_data[1]) << 8) | + (std::uint32_t(m_data[2]) << 16) | (std::uint32_t(m_data[3]) << 24); + } + } + + constexpr bool is_any() const noexcept { return *this == any(); } + constexpr bool is_multicast() const noexcept { return (m_data[0] & 0xf0) == 0xe0; } + constexpr bool is_loopback() const noexcept { return m_data[0] == 127; } + constexpr bool is_private() const noexcept { + return m_data[0] == 10 || (m_data[0] == 172 && (m_data[1] & 0xf0) == 16) || + (m_data[0] == 192 && m_data[1] == 168); + // FIXME: technically 0.0.0.0/8, 100.64.0.0/10, 198.18.0.0/15 and 169.254.0.0/16 are considered private as well, + // however those are usually not meant when talking about private/public ips. In particular 100.64.0.0/10 (carrier grade nat) + // can appear as a users "public" ip from the perspective of the user. + } + + constexpr std::strong_ordering operator<=>(const ipv4_address& rhs) const noexcept = default; + + std::string to_string() const { + char buf[16]{}; + auto ptr = std::begin(buf); + for (auto& e : m_data) { + if (ptr != std::begin(buf)) *ptr++ = '.'; + if (e >= 100) { + *ptr++ = '0' + (e / 100); + *ptr++ = '0' + ((e % 100) / 10); + *ptr++ = '0' + (e % 10); + } else if (e >= 10) { + *ptr++ = '0' + (e / 10); + *ptr++ = '0' + (e % 10); + } else { + *ptr++ = '0' + e; + } + } + return std::string(buf, ptr); + } + std::pair to_sockaddr() const noexcept; + std::pair to_sockaddr_in() const noexcept; + + static constexpr ipv4_address loopback() noexcept { return ipv4_address(127, 0, 0, 1); } + static constexpr ipv4_address any() noexcept { return ipv4_address(0, 0, 0, 0); } + static constexpr std::optional parse(std::string_view str) noexcept { + constexpr auto parse_part = [](std::string_view::const_iterator& it, std::string_view::const_iterator end) { + if (it == end || (*it < '0' && *it > '9')) return -1; + int32_t result = 0; + while (*it >= '0' && *it <= '9') { + result = (result * 10) + (*it - '0'); + it++; + } + return result; + }; + auto it = str.begin(); + auto p1 = parse_part(it, str.end()); + if (p1 < 0 || p1 > 255 || it == str.end() || *it++ != '.') return std::nullopt; + auto p2 = parse_part(it, str.end()); + if (p2 < 0 || p2 > 255 || it == str.end() || *it++ != '.') return std::nullopt; + auto p3 = parse_part(it, str.end()); + if (p3 < 0 || p3 > 255 || it == str.end() || *it++ != '.') return std::nullopt; + auto p4 = parse_part(it, str.end()); + if (p4 < 0 || p4 > 255 || it != str.end()) return std::nullopt; + return ipv4_address(p1, p2, p3, p4); + } + }; + + class ipv6_address { + alignas(uint64_t) std::array m_data{}; + + public: + constexpr ipv6_address() noexcept {} + explicit constexpr ipv6_address(std::span data, + std::endian order = std::endian::big) noexcept + : m_data{data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], + data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15]} { + if (order != std::endian::big) std::reverse(m_data.begin(), m_data.end()); + } + constexpr ipv6_address(uint64_t a, uint64_t b) noexcept + : m_data{static_cast(a >> 56), static_cast(a >> 48), + static_cast(a >> 40), static_cast(a >> 32), + static_cast(a >> 24), static_cast(a >> 16), + static_cast(a >> 8), static_cast(a >> 0), + static_cast(b >> 56), static_cast(b >> 48), + static_cast(b >> 40), static_cast(b >> 32), + static_cast(b >> 24), static_cast(b >> 16), + static_cast(b >> 8), static_cast(b >> 0)} {} + constexpr ipv6_address(uint16_t a, uint16_t b, uint16_t c, uint16_t d, uint16_t e, uint16_t f, uint16_t g, + uint16_t h) noexcept + : m_data{static_cast(a >> 8), static_cast(a >> 0), + static_cast(b >> 8), static_cast(b >> 0), + static_cast(c >> 8), static_cast(c >> 0), + static_cast(d >> 8), static_cast(d >> 0), + static_cast(e >> 8), static_cast(e >> 0), + static_cast(f >> 8), static_cast(f >> 0), + static_cast(g >> 8), static_cast(g >> 0), + static_cast(h >> 8), static_cast(h >> 0)} {} + constexpr ipv6_address(uint8_t a, uint8_t b, uint8_t c, uint8_t d, uint8_t e, uint8_t f, uint8_t g, uint8_t h, + uint8_t i, uint8_t j, uint8_t k, uint8_t l, uint8_t m, uint8_t n, uint8_t o, + uint8_t p) noexcept + : m_data{a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p} {} + explicit constexpr ipv6_address(std::span data) noexcept + : ipv6_address{data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7]} {} + explicit constexpr ipv6_address(ipv4_address addr) noexcept + : ipv6_address(std::array{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, addr.data()[0], + addr.data()[1], addr.data()[2], addr.data()[3]}) {} + explicit ipv6_address(const sockaddr_storage& addr); + explicit ipv6_address(const sockaddr_in6& addr) noexcept; + + constexpr std::span data() const noexcept { return m_data; } + constexpr std::span ipv4_data() const noexcept { + return std::span{&m_data[12], &m_data[16]}; + } + + constexpr uint64_t subnet_prefix() const noexcept { + return static_cast(m_data[0]) << 56 | static_cast(m_data[1]) << 48 | + static_cast(m_data[2]) << 40 | static_cast(m_data[3]) << 32 | + static_cast(m_data[4]) << 24 | static_cast(m_data[5]) << 16 | + static_cast(m_data[6]) << 8 | static_cast(m_data[7]); + } + constexpr uint64_t interface_identifier() const noexcept { + return static_cast(m_data[8]) << 56 | static_cast(m_data[9]) << 48 | + static_cast(m_data[10]) << 40 | static_cast(m_data[11]) << 32 | + static_cast(m_data[12]) << 24 | static_cast(m_data[13]) << 16 | + static_cast(m_data[14]) << 8 | static_cast(m_data[15]); + } + + constexpr std::strong_ordering operator<=>(const ipv6_address& rhs) const noexcept = default; + + constexpr bool is_any() const noexcept { return *this == any(); } + constexpr bool is_loopback() const noexcept { return *this == loopback(); } + constexpr bool is_multicast() const noexcept { return m_data[0] == 0xff; } + constexpr bool is_link_local() const noexcept { return m_data[0] == 0xfe && (m_data[1] & 0xc0) == 0x80; } + constexpr bool is_global() const noexcept { + return !(is_any() || is_loopback() || is_multicast() || is_link_local()); + } + constexpr bool is_ipv4_mapped() const noexcept { + for (size_t i = 0; i < 10; i++) + if (m_data[i]) return false; + return m_data[10] == 0xff && m_data[11] == 0xff; + } + constexpr ipv4_address mapped_ipv4() const noexcept { + if (!is_ipv4_mapped()) return ipv4_address(); + return ipv4_address(std::span(&m_data[12], &m_data[16])); + } + + std::string to_string(bool full = false) const { + static constexpr const char* table = "0123456789abcdef"; + // A ipv6 address is represented by 8 16bit blocks separated by a colon. + // Leading zeros in a block are suppressed, but a block may not be empty. + // If more than one zero blocks follow each other the longest one may be replaced + // by :: + + // Search for the longest run of zero blocks + int zerorun_start = -1; + int zerorun_length = 0; + if (!full) { + for (int i = 0; i < m_data.size(); i += 2) { + if (m_data[i] == 0 && m_data[i + 1] == 0) { + int run_start = i; + for (; i < m_data.size() && m_data[i] == 0 && m_data[i + 1] == 0; i += 2) + ; + if (i - run_start > zerorun_length && i - run_start >= 4) { + zerorun_start = run_start; + zerorun_length = i - run_start; + } + } + } + } + + std::string res; + for (int i = 0; i < m_data.size();) { + // This is the start of the zero run + if (i == zerorun_start) { + i += zerorun_length; + res += ':'; + // if this is the end, append an extra colon, otherwise it is added by the next block + if (i >= m_data.size()) res += ':'; + continue; + } + if (i != 0) res += ':'; + if (full || m_data[i] & 0xf0) res += table[(m_data[i] >> 4) & 0xf]; + if (full || m_data[i]) res += table[m_data[i] & 0xf]; + if (full || m_data[i] || m_data[i + 1] & 0xf0) res += table[(m_data[i + 1] >> 4) & 0xf]; + res += table[m_data[i + 1] & 0xf]; + i += 2; + } + return res; + } + std::pair to_sockaddr() const noexcept; + std::pair to_sockaddr_in6() const noexcept; + + static constexpr ipv6_address any() noexcept { return ipv6_address(); } + static constexpr ipv6_address loopback() noexcept { return ipv6_address(0, 0, 0, 0, 0, 0, 0, 1); } + static constexpr std::optional parse(std::string_view str) noexcept { + if (str.starts_with("[")) str.remove_prefix(1); + if (str.starts_with("]")) str.remove_suffix(1); + std::array buf{}; + int idx = 0; + int dcidx = -1; + auto it = str.begin(); + auto part_start = it; + bool is_v4_interop = false; + if (*it == ':') { + dcidx = idx++; + it++; + if (it == str.end() || *it != ':') return std::nullopt; + it++; + } + while (it != str.end()) { + part_start = it; + if (*it == ':') { + if (dcidx != -1) return std::nullopt; + dcidx = idx++; + it++; + } else { + while (it != str.end()) { + if (idx == 8) return std::nullopt; + if (*it != ':') { + if (*it >= '0' && *it <= '9') + buf[idx] = buf[idx] * 16 + (*it - '0'); + else if (*it >= 'a' && *it <= 'f') + buf[idx] = buf[idx] * 16 + (*it - 'a') + 10; + else if (*it >= 'A' && *it <= 'F') + buf[idx] = buf[idx] * 16 + (*it - 'A') + 10; + else if (*it == '.') { + auto ip4 = ipv4_address::parse(std::string_view(part_start, str.end())); + if (!ip4) return std::nullopt; + auto data = ip4->data(); + buf[idx++] = (static_cast(data[0]) << 8) | data[1]; + if (idx >= 8) return std::nullopt; + buf[idx] = (static_cast(data[2]) << 8) | data[3]; + it = str.end(); + is_v4_interop = true; + continue; + } else + return std::nullopt; + it++; + } else { + if (std::distance(part_start, it) > 4) return std::nullopt; + it++; + if (it == str.end()) return std::nullopt; + break; + } + } + idx++; + } + } + if (dcidx != -1) { + const auto ncopy = idx - dcidx; + const auto dest = dcidx + ncopy - 1; + for (auto i = 0; i < ncopy; i++) { + buf[7 - i] = buf[dest - i]; + buf[dest - i] = 0; + } + } else if (idx != 8) + return std::nullopt; + ipv6_address res{buf}; + if (is_v4_interop && !res.is_ipv4_mapped()) return std::nullopt; + return res; + } + }; + +#ifndef _WIN32 + class uds_address { + std::array m_data{}; + uint8_t m_len{}; + + public: + constexpr uds_address() noexcept {} + explicit constexpr uds_address(std::string_view path) noexcept { + m_len = (std::min)(path.size(), m_data.size() - 1); + for (size_t i = 0; i < m_len; i++) + m_data[i] = path[i]; + for (size_t i = m_len; i < m_data.size(); i++) + m_data[i] = '\0'; + if (m_len != 0 && m_data[0] == '@') m_data[0] = '\0'; + } + explicit uds_address(const sockaddr_storage& addr, size_t len); + explicit uds_address(const sockaddr_un& addr, size_t len) noexcept; + + constexpr std::span data() const noexcept { return {m_data.data(), m_len}; } + + constexpr std::strong_ordering operator<=>(const uds_address& rhs) const noexcept = default; + + constexpr bool is_unnamed() const noexcept { return m_len == 0; } + constexpr bool is_abstract() const noexcept { return m_len != 0 && m_data[0] == '\0'; } + + std::string to_string() const { + std::string res{reinterpret_cast(m_data.data()), m_len}; + if (!res.empty() && res[0] == '\0') res[0] = '@'; + return res; + } + std::pair to_sockaddr() const noexcept; + std::pair to_sockaddr_un() const noexcept; + + static constexpr std::optional parse(std::string_view str) noexcept { + if (str.size() > 108) return std::nullopt; + if (!str.empty() && std::accumulate(str.begin(), str.end(), 0ull) == 0) return std::nullopt; + if (!str.empty() && str[0] == '@' && std::accumulate(str.begin() + 1, str.end(), 0ull) == 0) + return std::nullopt; + if (!str.empty() && (str.front() == ' ' || str.back() == ' ')) return std::nullopt; + return uds_address(str); + } + }; +#endif + + class address { + union { + ipv4_address m_ipv4{}; + ipv6_address m_ipv6; +#ifndef _WIN32 + uds_address m_uds; +#endif + }; + address_type m_type{address_type::ipv4}; + + public: + constexpr address() noexcept {} + explicit constexpr address(ipv4_address addr) noexcept : m_ipv4(addr), m_type{address_type::ipv4} {} + explicit constexpr address(ipv6_address addr) noexcept { + if (addr.is_ipv4_mapped()) { + m_ipv4 = addr.mapped_ipv4(); + m_type = address_type::ipv4; + } else { + m_ipv6 = addr; + m_type = address_type::ipv6; + } + } +#ifndef _WIN32 + explicit constexpr address(uds_address addr) noexcept : m_uds(addr), m_type{address_type::uds} {} +#endif + explicit address(const sockaddr_storage& addr, size_t len); + + constexpr address_type type() const noexcept { return m_type; } + constexpr bool is_ipv4() const noexcept { return m_type == address_type::ipv4; } + constexpr bool is_ipv6() const noexcept { return m_type == address_type::ipv6; } +#ifndef _WIN32 + constexpr bool is_uds() const noexcept { return m_type == address_type::uds; } +#endif + + constexpr ipv4_address ipv4() const noexcept { + switch (m_type) { + case address_type::ipv4: return m_ipv4; + case address_type::ipv6: return {}; +#ifndef _WIN32 + case address_type::uds: return {}; +#endif + } + return {}; + } + constexpr ipv6_address ipv6() const noexcept { + switch (m_type) { + case address_type::ipv4: return ipv6_address(m_ipv4); +#ifndef _WIN32 + case address_type::uds: return {}; +#endif + case address_type::ipv6: return m_ipv6; + } + return {}; + } + +#ifndef _WIN32 + constexpr uds_address uds() const noexcept { + switch (m_type) { + case address_type::ipv4: + case address_type::ipv6: return {}; + case address_type::uds: return m_uds; + } + return {}; + } +#endif + + constexpr bool is_any() const noexcept { + switch (m_type) { + case address_type::ipv4: return m_ipv4.is_any(); + case address_type::ipv6: return m_ipv6.is_any(); +#ifndef _WIN32 + case address_type::uds: return false; +#endif + } + } + constexpr bool is_loopback() const noexcept { + switch (m_type) { + case address_type::ipv4: return m_ipv4.is_loopback(); + case address_type::ipv6: return m_ipv6.is_loopback(); +#ifndef _WIN32 + case address_type::uds: return false; +#endif + } + } + + constexpr std::span bytes() const noexcept { + switch (m_type) { + case address_type::ipv4: return m_ipv4.data(); + case address_type::ipv6: return m_ipv6.data(); +#ifndef _WIN32 + case address_type::uds: return m_uds.data(); +#endif + } + } + + constexpr std::strong_ordering operator<=>(const address& rhs) const noexcept { + auto order = m_type <=> rhs.m_type; + if (order != std::strong_ordering::equal) return order; + switch (m_type) { + case address_type::ipv4: return m_ipv4 <=> rhs.m_ipv4; + case address_type::ipv6: return m_ipv6 <=> rhs.m_ipv6; +#ifndef _WIN32 + case address_type::uds: return m_uds <=> rhs.m_uds; +#endif + } + } + constexpr bool operator==(const address& rhs) const noexcept { + return (*this <=> rhs) == std::strong_ordering::equal; + } + constexpr bool operator!=(const address& rhs) const noexcept { + return (*this <=> rhs) != std::strong_ordering::equal; + } + + std::string to_string(bool full = false) const { + switch (m_type) { + case address_type::ipv4: return m_ipv4.to_string(); + case address_type::ipv6: return m_ipv6.to_string(full); +#ifndef _WIN32 + case address_type::uds: return m_uds.to_string(); +#endif + } + } + std::pair to_sockaddr() const noexcept; + + static constexpr address any() noexcept { return address(ipv6_address::any()); } + static constexpr address loopback() noexcept { return address(ipv6_address::loopback()); } + static constexpr std::optional
parse(std::string_view str) noexcept { + auto ip4 = ipv4_address::parse(str); + if (ip4) return address(*ip4); + auto ip6 = ipv6_address::parse(str); + if (ip6) return address(*ip6); + return std::nullopt; + } + }; +} // namespace asyncpp::io + +namespace std { + template<> + struct hash { + size_t operator()(const asyncpp::io::ipv4_address& x) const noexcept { + return std::hash{}(x.integer()); + } + }; + template<> + struct hash { + size_t operator()(const asyncpp::io::ipv6_address& x) const noexcept { + std::hash h{}; + auto res = h(x.subnet_prefix()); + return res ^ (h(x.interface_identifier()) + 0x9e3779b99e3779b9ull + (res << 6) + (res >> 2)); + } + }; +#ifndef _WIN32 + template<> + struct hash { + size_t operator()(const asyncpp::io::uds_address& x) const noexcept { + size_t res = 0; + for (auto e : x.data()) + res = res ^ (e + 0x9e3779b99e3779b9ull + (res << 6) + (res >> 2)); + return res; + } + }; +#endif + template<> + struct hash { + size_t operator()(const asyncpp::io::address& x) const noexcept { + size_t res; + switch (x.type()) { + case asyncpp::io::address_type::ipv4: res = std::hash{}(x.ipv4()); break; + case asyncpp::io::address_type::ipv6: res = std::hash{}(x.ipv6()); break; +#ifndef _WIN32 + case asyncpp::io::address_type::uds: res = std::hash{}(x.uds()); break; +#endif + } + return res ^ (static_cast(x.type()) + 0x9e3779b99e3779b9ull + (res << 6) + (res >> 2)); + } + }; +} // namespace std diff --git a/include/asyncpp/io/buffer.h b/include/asyncpp/io/buffer.h new file mode 100644 index 0000000..a18b978 --- /dev/null +++ b/include/asyncpp/io/buffer.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace asyncpp::io { + + template + requires(std::is_trivial_v) + inline void raw_set(void* ptr, std::type_identity_t val) noexcept { + memcpy(ptr, &val, sizeof(T)); + if constexpr (std::endian::native != Endian) + std::reverse(static_cast(ptr), static_cast(ptr) + sizeof(T)); + } + + template + requires(std::is_trivial_v) + inline T raw_get(const void* ptr) noexcept { + T res; + memcpy(&res, ptr, sizeof(T)); + if constexpr (std::endian::native != Endian) { + std::reverse(reinterpret_cast(&res), reinterpret_cast(&res) + sizeof(T)); + } + return res; + } + + using buffer = std::span; + using const_buffer = std::span; +} // namespace asyncpp::io diff --git a/include/asyncpp/io/detail/cancel_awaitable.h b/include/asyncpp/io/detail/cancel_awaitable.h new file mode 100644 index 0000000..048380d --- /dev/null +++ b/include/asyncpp/io/detail/cancel_awaitable.h @@ -0,0 +1,41 @@ +#pragma once +#include +#include +#include + +#include + +namespace asyncpp::io::detail { + struct cancel_io_stop_callback { + io_engine::completion_data* m_data; + io_engine* m_engine; + void operator()() noexcept { + if (m_engine && m_data) m_engine->cancel(m_data); + } + }; + + template + class cancellable_awaitable { + T m_child; + asyncpp::stop_token m_stop_token; + std::optional> m_cancel_callback; + + public: + template + cancellable_awaitable(asyncpp::stop_token st, Args&&... args) noexcept + : m_child(std::forward(args)...), m_stop_token{std::move(st)} {} + bool await_ready() const noexcept { return m_child.await_ready(); } + bool await_suspend(coroutine_handle<> hdl) { + if (m_stop_token.stop_requested()) { + m_child.m_completion.result = -ECANCELED; + return false; + } + auto res = m_child.await_suspend(hdl); + if (res) + m_cancel_callback.emplace( + m_stop_token, cancel_io_stop_callback{&m_child.m_completion, m_child.m_socket.service().engine()}); + return res; + } + auto await_resume() { return m_child.await_resume(); } + }; +} // namespace asyncpp::io::detail diff --git a/include/asyncpp/io/detail/io_engine.h b/include/asyncpp/io/detail/io_engine.h new file mode 100644 index 0000000..5c7e9a3 --- /dev/null +++ b/include/asyncpp/io/detail/io_engine.h @@ -0,0 +1,56 @@ +#pragma once +#include + +#include + +namespace asyncpp::io::detail { + class io_engine { + public: + using file_handle_t = int; + constexpr static file_handle_t invalid_file_handle = -1; + using socket_handle_t = int; + constexpr static socket_handle_t invalid_socket_handle = -1; + enum class fsync_flags { none, datasync }; + + struct completion_data { + // Info provided by caller + void (*callback)(void*); + void* userdata; + + // Filled by io_engine + int result; + + // Private data the engine can use to associate state + void* engine_state{}; + }; + + public: + virtual ~io_engine() = default; + + virtual std::string_view name() const noexcept = 0; + + virtual size_t run(bool nowait = false) = 0; + virtual void wake() = 0; + + // Networking api + virtual bool enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) = 0; + virtual bool enqueue_accept(socket_handle_t socket, completion_data* cd) = 0; + virtual bool enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) = 0; + virtual bool enqueue_send(socket_handle_t socket, const void* buf, size_t len, completion_data* cd) = 0; + virtual bool enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, + completion_data* cd) = 0; + virtual bool enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, + completion_data* cd) = 0; + + // Filesystem IO + virtual bool enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) = 0; + virtual bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, + completion_data* cd) = 0; + virtual bool enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) = 0; + + // Cancelation + virtual bool cancel(completion_data* cd) = 0; + }; + + std::unique_ptr create_io_engine(); +} // namespace asyncpp::io::detail diff --git a/include/asyncpp/io/dns.h b/include/asyncpp/io/dns.h new file mode 100644 index 0000000..664cd1b --- /dev/null +++ b/include/asyncpp/io/dns.h @@ -0,0 +1,919 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace asyncpp::io::dns { + enum class api_error : int { + ok = 0, + not_enough_space, + label_invalid, + label_too_long, + incomplete_message, + recursion_limit_exceeded, + extra_data, + duplicate_id, + no_id, + + cancelled, + timeout, + internal, + }; + const std::error_category& error_category() noexcept; + inline std::error_code make_error_code(api_error e) noexcept { + return std::error_code(static_cast(e), error_category()); + } +} // namespace asyncpp::io::dns +namespace std { + template<> + struct is_error_code_enum : std::true_type {}; +} // namespace std +namespace asyncpp::io::dns { + constexpr size_t max_message_size = std::numeric_limits::max(); + constexpr size_t max_label_size = 63; + constexpr size_t max_name_size = 255; + + enum class rcode : uint8_t { + no_error = 0, // No error condition. + form_error = 1, // The name server was unable to interpret the request due to a format error. + server_failure = 2, // The name server encountered an internal failure while processing this request, + // for example an operating system error or a forwarding timeout. + nx_domain = 3, // Some name that ought to exist, does not exist. + not_implemented = 4, // The name server does not support the specified Opcode. + refused = 5, // The name server refuses to perform the specified operation for policy or security reasons. + domain_exists = 6, // Some name that ought not to exist, does exist. + rrset_exists = 7, // Some RRset that ought not to exist, does exist. + nx_rrset = 8, // Some RRset that ought to exist, does not exist. + not_authoritative = 9, // The server is not authoritative for the zone named in the Zone Section. + not_zone = 10, // A name used in the Prerequisite or Update Section is + // not within the zone denoted by the Zone Section. + bad_signature = 16, // tsig signature was invalid (likely invalid key). + bad_key = 17, // TSIG Key is not known by server. + bad_time = 18, // TSIG Timestamp was wrong (are your clocks in sync ?). + }; + + enum class opcode : uint8_t { + query = 0, + iquery = 1, + status = 2, + update = 5, + }; + + enum class qtype : uint16_t { + a = 1, + ns = 2, + md = 3, + mf = 4, + cname = 5, + soa = 6, + mb = 7, + mg = 8, + mr = 9, + null = 10, + wks = 11, + ptr = 12, + hinfo = 13, + minfo = 14, + mx = 15, + txt = 16, + rp = 17, + afsdb = 18, + x25 = 19, + isdn = 20, + rt = 21, + nsap = 22, + nsap_ptr = 23, + sig = 24, + key = 25, + px = 26, + gpos = 27, + aaaa = 28, + loc = 29, + nxt = 30, + eid = 31, + nimloc = 32, + srv = 33, + atma = 34, + naptr = 35, + kx = 36, + cert = 37, + a6 = 38, + dname = 39, + sink = 40, + opt = 41, + apl = 42, + ds = 43, + sshfp = 44, + ipseckey = 45, + rrsig = 46, + nsec = 47, + dnskey = 48, + dhcid = 49, + nsec3 = 50, + nsec3param = 51, + tlsa = 52, + smimea = 53, + hip = 55, + ninfo = 56, + rkey = 57, + talink = 58, + cds = 59, + cdnskey = 60, + openpgpkey = 61, + csync = 62, + spf = 99, + uinfo = 100, + uid = 101, + gid = 102, + unspec = 103, + nid = 104, + l32 = 105, + l64 = 106, + lp = 107, + eui48 = 108, + eui64 = 109, + tkey = 249, + tsig = 250, + ixfr = 251, + axfr = 252, + mailb = 253, + maila = 254, + any = 255, + uri = 256, + caa = 257, + avc = 258, + ta = 32768, + dlv = 32769, + }; + + enum class qclass : uint16_t { + in = 1, /*%< Internet. */ + csnet = 2, + chaos = 3, /*%< MIT Chaos-net. */ + hs = 4, /*%< MIT Hesiod. */ + any = 0xff, + }; + + class binary_writer { + uint8_t* const m_start{}; + uint8_t* m_end{}; + uint8_t* const m_cap{}; + bool m_truncated{}; + bool m_throwing{}; + + uint8_t* alloc(size_t n) { + if (n > remaining_space() || m_truncated) { + m_truncated = true; + if (m_throwing) throw std::out_of_range("not enough remaining space"); + return nullptr; + } + auto ptr = m_end; + m_end += n; + return ptr; + } + + public: + binary_writer(uint8_t* ptr, size_t size) : m_start(ptr), m_end(ptr), m_cap(ptr + size) {} + + size_t remaining_space() const noexcept { return m_cap - m_end; } + size_t used_space() const noexcept { return m_end - m_start; } + size_t total_space() const noexcept { return m_cap - m_start; } + bool is_truncated() const noexcept { return m_truncated; } + binary_writer& set_throwing(bool throwing) noexcept { + m_throwing = throwing; + return *this; + } + + binary_writer& u8(uint8_t val) { + if (auto p = alloc(1); p) *p = val; + return *this; + } + binary_writer& u16(uint16_t val, std::endian e = std::endian::little) { + if (auto p = alloc(2); p) { + p[0] = val & 0xff; + p[1] = (val >> 8) & 0xff; + if (e != std::endian::little) std::reverse(p, p + 2); + } + return *this; + } + binary_writer& u24(uint32_t val, std::endian e = std::endian::little) { + if (auto p = alloc(3); p) { + p[0] = val & 0xff; + p[1] = (val >> 8) & 0xff; + p[2] = (val >> 16) & 0xff; + if (e != std::endian::little) std::reverse(p, p + 3); + } + return *this; + } + binary_writer& u32(uint32_t val, std::endian e = std::endian::little) { + if (auto p = alloc(4); p) { + p[0] = val & 0xff; + p[1] = (val >> 8) & 0xff; + p[2] = (val >> 16) & 0xff; + p[3] = (val >> 24) & 0xff; + if (e != std::endian::little) std::reverse(p, p + 4); + } + return *this; + } + binary_writer& u40(uint64_t val, std::endian e = std::endian::little) { + if (auto p = alloc(5); p) { + p[0] = val & 0xff; + p[1] = (val >> 8) & 0xff; + p[2] = (val >> 16) & 0xff; + p[3] = (val >> 24) & 0xff; + p[4] = (val >> 32) & 0xff; + if (e != std::endian::little) std::reverse(p, p + 5); + } + return *this; + } + binary_writer& u48(uint64_t val, std::endian e = std::endian::little) { + if (auto p = alloc(6); p) { + p[0] = val & 0xff; + p[1] = (val >> 8) & 0xff; + p[2] = (val >> 16) & 0xff; + p[3] = (val >> 24) & 0xff; + p[4] = (val >> 32) & 0xff; + p[5] = (val >> 40) & 0xff; + if (e != std::endian::little) std::reverse(p, p + 6); + } + return *this; + } + binary_writer& u56(uint64_t val, std::endian e = std::endian::little) { + if (auto p = alloc(7); p) { + p[0] = val & 0xff; + p[1] = (val >> 8) & 0xff; + p[2] = (val >> 16) & 0xff; + p[3] = (val >> 24) & 0xff; + p[4] = (val >> 32) & 0xff; + p[5] = (val >> 40) & 0xff; + p[6] = (val >> 48) & 0xff; + if (e != std::endian::little) std::reverse(p, p + 7); + } + return *this; + } + binary_writer& u64(uint64_t val, std::endian e = std::endian::little) { + if (auto p = alloc(8); p) { + p[0] = val & 0xff; + p[1] = (val >> 8) & 0xff; + p[2] = (val >> 16) & 0xff; + p[3] = (val >> 24) & 0xff; + p[4] = (val >> 32) & 0xff; + p[5] = (val >> 40) & 0xff; + p[6] = (val >> 48) & 0xff; + p[7] = (val >> 56) & 0xff; + if (e != std::endian::little) std::reverse(p, p + 8); + } + return *this; + } + binary_writer& raw(const void* ptr, size_t len, std::endian e = std::endian::native) { + if (auto p = alloc(len); p) { + memcpy(p, ptr, len); + if (e != std::endian::little) std::reverse(p, p + len); + } + return *this; + } + binary_writer& dns_name(std::string_view name, std::map>* compress = nullptr) { + if (compress != nullptr) { + if (auto it = compress->find(name); it != compress->end()) { + u16(0xc000 | it->second, std::endian::big); + return *this; + } + compress->emplace(name, used_space()); + } + + std::string_view part = name.substr(0, name.find('.')); + name.remove_prefix(part.size() + 1); + while (!part.empty()) { + u8(part.size()); + raw(part.data(), part.size()); + if (compress != nullptr) { + if (auto it = compress->find(name); it != compress->end()) { + u16(0xc000 | it->second, std::endian::big); + return *this; + } + if (!name.empty()) compress->emplace(name, used_space()); + } + part = name.substr(0, name.find('.')); + name.remove_prefix(part.size() + 1); + } + u8(0); + return *this; + } + }; + + size_t convert_name(void* out, size_t outlen, std::string_view name, std::error_code& ec) noexcept; + /** + * \brief Parses a label and resolves all pointers. Returns the end of the initial label. + * + * \param msg The entire dns message for resolving pointers + * \param label Start of the label + * \param res Out param for the parsed string + * \return End of the first label/piece or nullptr if an error occurred + */ + const std::byte* parse_label(const_buffer msg, const std::byte* label, std::string& res, + std::error_code& ec) noexcept; + + struct mx_record { + uint16_t preference; + std::string name; + }; + struct soa_record { + std::string name; + std::string rname; + uint32_t serial; + uint32_t refresh; + uint32_t retry; + uint32_t expire; + uint32_t minimum; + }; + struct srv_record { + uint16_t priority; + uint16_t weight; + uint16_t port; + std::string target; + }; + struct tsig_record { + std::string algorithm; + std::chrono::system_clock::time_point timestamp; + uint16_t fudge; + std::vector mac; + uint16_t original_id; + rcode error; + std::vector other; + }; + + ipv4_address parse_a(const_buffer data, const_buffer msg, std::error_code& ec) noexcept; + ipv6_address parse_aaaa(const_buffer data, const_buffer msg, std::error_code& ec) noexcept; + std::string parse_txt(const_buffer data, const_buffer msg, std::error_code& ec) noexcept; + std::string parse_cname(const_buffer data, const_buffer msg, std::error_code& ec) noexcept; + std::string parse_ns(const_buffer data, const_buffer msg, std::error_code& ec) noexcept; + mx_record parse_mx(const_buffer data, const_buffer msg, std::error_code& ec); + std::string parse_ptr(const_buffer data, const_buffer msg, std::error_code& ec) noexcept; + soa_record parse_soa(const_buffer data, const_buffer msg, std::error_code& ec); + srv_record parse_srv(const_buffer data, const_buffer msg, std::error_code& ec); + tsig_record parse_tsig(const_buffer data, const_buffer msg, std::error_code& ec); + + std::vector build_a(ipv4_address addr); + std::vector build_aaaa(ipv6_address addr); + std::vector build_txt(std::string_view str); + std::vector build_cname(std::string_view str); + std::vector build_ns(std::string_view str); + std::vector build_mx(const mx_record& rr); + std::vector build_ptr(std::string_view str); + std::vector build_soa(const soa_record& rr); + std::vector build_srv(const srv_record& rr); + std::vector build_tsig(const tsig_record& rr); + + class message_header { + uint8_t m_data[12]{}; + + public: + message_header(const_buffer msg) { memcpy(m_data, msg.data(), (std::min)(msg.size(), 12)); } + + uint16_t id() const noexcept { return raw_get(m_data); } + constexpr bool qr() const noexcept { return (m_data[2] & 0x80) != 0; } + constexpr dns::opcode opcode() const noexcept { return static_cast((m_data[2] >> 3) & 0x0f); } + constexpr bool authoritative() const noexcept { return (m_data[2] & 0x04) != 0; } + constexpr bool truncated() const noexcept { return (m_data[2] & 0x02) != 0; } + constexpr bool recursion_desired() const noexcept { return (m_data[2] & 0x01) != 0; } + constexpr bool recursion_available() const noexcept { return (m_data[3] & 0x80) != 0; } + constexpr dns::rcode rcode() const noexcept { return static_cast(m_data[3] & 0x0f); } + + uint16_t query_count() const noexcept { return raw_get(m_data + 4); } + uint16_t answer_count() const noexcept { return raw_get(m_data + 6); } + uint16_t authoritative_count() const noexcept { return raw_get(m_data + 8); } + uint16_t additional_count() const noexcept { return raw_get(m_data + 10); } + }; + + template + concept MessageVisitor = // + requires(T& v) { + { + v.on_header(std::declval(), std::declval()) + } -> std::convertible_to; + { + v.on_question(std::declval(), std::declval(), std::declval()) + } -> std::convertible_to; + { + v.on_answer(std::declval(), std::declval(), std::declval(), + std::declval(), std::declval()) + } -> std::convertible_to; + { + v.on_authority(std::declval(), std::declval(), std::declval(), + std::declval(), std::declval()) + } -> std::convertible_to; + { + v.on_additional(std::declval(), std::declval(), std::declval(), + std::declval(), std::declval()) + } -> std::convertible_to; + }; + + inline bool visit_message(const_buffer msg, MessageVisitor auto& visitor, std::error_code& ec) { + const auto pmsg = msg.data(); + const auto pend = pmsg + msg.size(); + if (msg.size() < 12) { + ec = api_error::incomplete_message; + return false; + } + message_header hdr(msg); + if (!visitor.on_header(hdr, msg)) return false; + const std::byte* ptr = pmsg + 12; + std::string name; + for (size_t i = hdr.query_count(); i > 0; i--) { + name.clear(); + const auto pfixed = parse_label(msg, ptr, name, ec); + if (ec) return false; + if (pfixed == nullptr || (pfixed + 4) > (pend)) { + ec = api_error::incomplete_message; + return false; + } + auto qtype = raw_get(pfixed); + auto qclass = raw_get(pfixed + 2); + if (!visitor.on_question(name, qtype, qclass)) return false; + ptr = pfixed + 4; + } + const auto answers = hdr.answer_count(); + const auto authorities = hdr.authoritative_count(); + const auto additionals = hdr.additional_count(); + for (size_t i = answers + authorities + additionals; i > 0; i--) { + name.clear(); + const auto pfixed = parse_label(msg, ptr, name, ec); + if (ec) return false; + if (pfixed == nullptr || (pfixed + 10) > (pend)) { + ec = api_error::incomplete_message; + return false; + } + const auto rtype = raw_get(pfixed); + const auto rclass = raw_get(pfixed + 2); + const auto ttl = raw_get(pfixed + 4); + const auto rdata_len = raw_get(pfixed + 8); + if (pfixed + 10 + rdata_len > pend) { + ec = api_error::incomplete_message; + return false; + } + if (i > authorities + additionals) { + if (!visitor.on_answer(name, rtype, rclass, ttl, const_buffer(pfixed + 10, rdata_len))) return false; + } else if (i > additionals) { + if (!visitor.on_authority(name, rtype, rclass, ttl, const_buffer(pfixed + 10, rdata_len))) return false; + } else { + if (!visitor.on_additional(name, rtype, rclass, ttl, const_buffer(pfixed + 10, rdata_len))) + return false; + } + ptr = pfixed + 10 + rdata_len; + } + if (ptr != pend) { + ec = api_error::extra_data; + return false; + } + return true; + } + + inline bool visit_message(const_buffer msg, MessageVisitor auto& visitor) { + std::error_code ec; + auto res = visit_message(msg, visitor, ec); + if (res && ec) throw std::system_error(ec); + return res; + } + + struct question { + std::string name; + dns::qtype qtype; + dns::qclass qclass; + + std::vector serialize() const; + const std::byte* parse(const_buffer msg, const std::byte* const rr); + }; + + struct resource_record { + std::string name; + qtype rtype; + qclass rclass; + uint32_t ttl; + std::vector rdata; + + std::vector serialize() const; + const std::byte* parse(const_buffer msg, const std::byte* const rr); + }; + + class message_builder { + uint8_t* const m_start; + uint8_t* m_end; + uint8_t* const m_cap; + uint8_t* m_question_end; + uint8_t* m_answer_end; + uint8_t* m_authority_end; + bool m_truncated{}; + + template + uint8_t get_flag() { + return (m_start[pos] & mask) >> shift; + } + + template + void set_flag(uint8_t value) { + auto v = m_start[pos] & ~mask; + v |= ((value << shift) & mask); + m_start[pos] = v; + } + + public: + message_builder(void* buf, size_t buf_size) noexcept + : m_start(static_cast(buf)), m_end(m_start + 12), m_cap(m_start + buf_size) { + m_question_end = m_answer_end = m_authority_end = m_end; + if (m_cap < m_end) m_truncated = true; + memset(m_start, 0, (std::min)(m_cap - m_start, 12)); + } + + message_builder& set_id(uint16_t id) noexcept { + raw_set(m_start, id); + return *this; + } + message_builder& set_qr(bool is_response) noexcept { + set_flag<2, 0x80, 7>(is_response); + return *this; + } + message_builder& set_opcode(opcode op) noexcept { + set_flag<2, 0x78, 3>(static_cast(op)); + return *this; + } + message_builder& set_authoritative(bool aa) noexcept { + set_flag<2, 0x04, 2>(aa); + return *this; + } + message_builder& set_truncated(bool tc) noexcept { + set_flag<2, 0x02, 1>(tc); + return *this; + } + message_builder& set_recursion_desired(bool rd) noexcept { + set_flag<2, 0x01, 0>(rd); + return *this; + } + message_builder& set_recursion_available(bool ra) noexcept { + set_flag<3, 0x80, 7>(ra); + return *this; + } + message_builder& set_rcode(rcode code) noexcept { + set_flag<3, 0x0f, 0>(static_cast(code)); + return *this; + } + template + message_builder& add_question(const T& record) noexcept { + if (m_truncated) return *this; + if (m_cap - m_end < record.size()) { + m_truncated = true; + set_truncated(true); + return *this; + } + memmove(m_question_end + record.size(), m_question_end, m_end - m_question_end); + memcpy(m_question_end, record.data(), record.size()); + raw_set(m_start + 4, raw_get(m_start + 4) + 1); + m_question_end += record.size(); + m_answer_end += record.size(); + m_authority_end += record.size(); + m_end += record.size(); + return *this; + } + message_builder& add_question(std::string_view label, qtype qt, qclass qc) { + auto record_size = label.size() + (label.empty() ? 1 : 2) + 4; + if (m_truncated) return *this; + if (m_cap - m_end < record_size) { + m_truncated = true; + set_truncated(true); + return *this; + } + memmove(m_question_end + record_size, m_question_end, m_end - m_question_end); + + std::error_code ec; + convert_name(m_question_end, record_size - 4, label, ec); + if (ec) throw std::system_error(ec); + raw_set(m_question_end + record_size - 4, qt); + raw_set(m_question_end + record_size - 2, qc); + + raw_set(m_start + 4, raw_get(m_start + 4) + 1); + m_question_end += record_size; + m_answer_end += record_size; + m_authority_end += record_size; + m_end += record_size; + return *this; + } + message_builder& add_question(const question& q) { return add_question(q.serialize()); } + template + message_builder& add_answer(const T& record) { + if (m_truncated) return *this; + if (m_cap - m_end < record.size()) { + m_truncated = true; + set_truncated(true); + return *this; + } + memmove(m_answer_end + record.size(), m_answer_end, m_end - m_answer_end); + memcpy(m_answer_end, record.data(), record.size()); + raw_set(m_start + 6, raw_get(m_start + 6) + 1); + m_answer_end += record.size(); + m_authority_end += record.size(); + m_end += record.size(); + return *this; + } + message_builder& add_answer(const resource_record& rr) { return add_answer(rr.serialize()); } + message_builder& add_answer(std::string name, qtype rtype, qclass rclass, uint32_t ttl, + std::vector rdata) { + resource_record rr; + rr.name = std::move(name); + rr.rtype = rtype; + rr.rclass = rclass; + rr.ttl = ttl; + rr.rdata = std::move(rdata); + return add_answer(std::move(rr)); + } + template + message_builder& add_authority(const T& record) { + if (m_truncated) return *this; + if (m_cap - m_end < record.size()) { + m_truncated = true; + set_truncated(true); + return *this; + } + memmove(m_authority_end + record.size(), m_authority_end, m_end - m_authority_end); + memcpy(m_authority_end, record.data(), record.size()); + raw_set(m_start + 8, raw_get(m_start + 8) + 1); + m_authority_end += record.size(); + m_end += record.size(); + return *this; + } + message_builder& add_authority(const resource_record& rr) { return add_authority(rr.serialize()); } + message_builder& add_authority(std::string name, qtype rtype, qclass rclass, uint32_t ttl, + std::vector rdata) { + resource_record rr; + rr.name = std::move(name); + rr.rtype = rtype; + rr.rclass = rclass; + rr.ttl = ttl; + rr.rdata = std::move(rdata); + return add_authority(std::move(rr)); + } + + template + message_builder& add_additional(const T& record) { + if (m_truncated) return *this; + if (m_cap - m_end < record.size()) { + m_truncated = true; + set_truncated(true); + return *this; + } + memcpy(m_end, record.data(), record.size()); + raw_set(m_start + 10, raw_get(m_start + 10) + 1); + m_end += record.size(); + return *this; + } + message_builder& add_additional(const resource_record& rr) { return add_additional(rr.serialize()); } + message_builder& add_additional(std::string name, qtype rtype, qclass rclass, uint32_t ttl, + std::vector rdata) { + resource_record rr; + rr.name = std::move(name); + rr.rtype = rtype; + rr.rclass = rclass; + rr.ttl = ttl; + rr.rdata = std::move(rdata); + return add_additional(std::move(rr)); + } + + size_t bytes_used() const noexcept { return m_end - m_start; } + + /** Aliases for RFC2136 */ + template + message_builder& add_zone(const T& record) noexcept { + return add_question(record); + } + message_builder& add_zone(std::string_view label, qtype qt, qclass qc) { return add_question(label, qt, qc); } + template + message_builder& add_prerequisite(const T& record) { + return add_answer(record); + } + message_builder& add_prerequisite(std::string name, qtype rtype, qclass rclass, uint32_t ttl, + std::vector rdata) { + return add_answer(std::move(name), rtype, rclass, ttl, std::move(rdata)); + } + template + message_builder& add_update(const T& record) { + return add_authority(record); + } + message_builder& add_update(std::string name, qtype rtype, qclass rclass, uint32_t ttl, + std::vector rdata) { + return add_authority(std::move(name), rtype, rclass, ttl, std::move(rdata)); + } + + message_builder& add_tsig_signature(std::string_view keyname, std::span key); + }; + + struct message { + uint16_t id; + bool is_response; + bool is_authoritative; + bool is_truncated; + bool is_recursion_desired; + bool is_recursion_available; + dns::opcode opcode; + dns::rcode rcode; + + std::vector questions; + std::vector answers; + std::vector authorities; + std::vector additional; + + std::vector serialize() const; + void serialize(void* buf, size_t bufsize) const; + void parse(const_buffer msg); + }; + + std::ostream& operator<<(std::ostream& s, rcode r); + std::ostream& operator<<(std::ostream& s, opcode o); + std::ostream& operator<<(std::ostream& s, qtype t); + std::ostream& operator<<(std::ostream& s, qclass c); + + class print_message_visitor { + std::ostream* m_out; + const_buffer m_message; + bool m_question_header_done{}; + bool m_answer_header_done{}; + bool m_authority_header_done{}; + bool m_additional_header_done{}; + bool m_is_update{}; + + void print_rr(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, const_buffer rdata) noexcept; + + public: + print_message_visitor(std::ostream& str) : m_out(&str) {} + bool on_header(const message_header& hdr, const_buffer message) noexcept; + bool on_question(std::string_view name, qtype qtype, qclass qclass) noexcept; + bool on_answer(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, const_buffer rdata) noexcept; + bool on_authority(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, const_buffer rdata) noexcept; + bool on_additional(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + const_buffer rdata) noexcept; + }; + + template + void visit_answer(const_buffer msg, FN&& fn) { + struct visitor { + FN m_fn; + bool on_header(const message_header& hdr, const_buffer message) noexcept { + return hdr.answer_count() != 0 && hdr.rcode() == rcode::no_error; + } + bool on_question(std::string_view name, qtype qtype, qclass qclass) noexcept { return true; } + bool on_answer(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + const_buffer rdata) noexcept { + return m_fn(name, rtype, rclass, ttl, rdata); + } + bool on_authority(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + const_buffer rdata) noexcept { + return false; + } + bool on_additional(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + const_buffer rdata) noexcept { + return false; + } + } visit{std::move(fn)}; + visit_message(msg, visit); + } + + class client { + class query_awaiter; + + public: + client(io_service& service); + client(const client&) = delete; + client& operator=(const client&) = delete; + client(client&&) = delete; + client& operator=(client&&) = delete; + ~client(); + + enum class protocol { udp }; + + void add_nameserver(endpoint ep, protocol proto = protocol::udp) { + std::unique_lock lck{m_mutex}; + m_nameservers.push_back(ep); + } + void add_nameserver(address addr, protocol proto = protocol::udp) { add_nameserver(endpoint(addr, 53), proto); } + void set_timeout(std::chrono::milliseconds timeout); + void set_retries(size_t retries) { + std::unique_lock lck{m_mutex}; + m_retries = retries; + } + + std::optional get_free_id() const noexcept; + + void query(const_buffer query, std::function callback); + query_awaiter query(const_buffer query) noexcept; + + void query(std::string_view name, dns::qtype qtype, dns::qclass qclass, + std::function callback); + auto query(std::string_view name, dns::qtype qtype, dns::qclass qclass) noexcept; + + void resolve(std::string name, dns::qtype type, std::function res)> callback); + auto resolve(std::string name, dns::qtype type) noexcept; + + void stop(); + + private: + struct request; + mutable std::recursive_mutex m_mutex; + asyncpp::stop_source m_stop; + asyncpp::stop_source m_stop_timer; + socket m_socket_ipv4; + socket m_socket_ipv6; + std::vector m_nameservers; + std::map m_inflight; + std::chrono::milliseconds m_timeout{250}; + size_t m_retries{5}; + + void send_requests(request& req); + }; + + class client::query_awaiter { + protected: + client* m_parent; + asyncpp::coroutine_handle<> m_handle{}; + const_buffer m_query{}; + api_error m_result{}; + std::vector m_response{}; + + public: + query_awaiter(client* that, const_buffer query) noexcept : m_parent(that), m_query{query} {} + + constexpr bool await_ready() const noexcept { return false; } + void await_suspend(asyncpp::coroutine_handle<> h) { + m_parent->query(m_query, [h, this](api_error error, const_buffer response) { + m_result = error; + try { + m_response.resize(response.size()); + memcpy(m_response.data(), response.data(), response.size()); + } catch (...) { m_result = api_error::not_enough_space; } + h.resume(); + }); + } + std::vector await_resume() { + if (m_result != api_error::ok) throw make_error_code(m_result); + return std::move(m_response); + } + }; + + inline client::query_awaiter client::query(const_buffer query) noexcept { return query_awaiter(this, query); } + inline auto client::query(std::string_view name, dns::qtype qtype, dns::qclass qclass) noexcept { + class awaiter : public client::query_awaiter { + std::vector m_buffer; + + public: + awaiter(client* that, std::string_view name, dns::qtype qtype, dns::qclass qclass) + : client::query_awaiter(that, {}) { + m_buffer.resize(max_message_size); + m_buffer.resize(message_builder(m_buffer.data(), m_buffer.size()) + .set_opcode(opcode::query) + .set_recursion_desired(true) + .add_question(name, qtype, qclass) + .bytes_used()); + m_query = {m_buffer.data(), m_buffer.size()}; + } + }; + return awaiter(this, name, qtype, qclass); + } + + inline auto client::resolve(std::string name, qtype type) noexcept { + struct awaiter { + client* m_parent; + std::string m_name; + qtype m_type; + asyncpp::coroutine_handle<> m_handle{}; + std::vector
m_response{}; + + constexpr bool await_ready() const noexcept { return false; } + void await_suspend(asyncpp::coroutine_handle<> h) { + m_parent->resolve(std::move(m_name), m_type, [h, this](std::vector
res) { + m_response = std::move(res); + h.resume(); + }); + } + std::vector
await_resume() { return std::move(m_response); } + }; + return awaiter{this, std::move(name), type}; + } + +} // namespace asyncpp::io::dns diff --git a/include/asyncpp/io/endpoint.h b/include/asyncpp/io/endpoint.h new file mode 100644 index 0000000..70dea3c --- /dev/null +++ b/include/asyncpp/io/endpoint.h @@ -0,0 +1,227 @@ +#pragma once +#include + +namespace asyncpp::io { + class ipv4_endpoint { + ipv4_address m_ip{}; + uint16_t m_port{}; + + public: + constexpr ipv4_endpoint() noexcept {} + constexpr ipv4_endpoint(ipv4_address addr, uint16_t port) noexcept : m_ip{addr}, m_port{port} {} + explicit ipv4_endpoint(const sockaddr_storage& addr); + explicit ipv4_endpoint(const sockaddr_in& addr) noexcept; + + constexpr ipv4_address address() const noexcept { return m_ip; } + constexpr uint16_t port() const noexcept { return m_port; } + + constexpr std::strong_ordering operator<=>(const ipv4_endpoint& rhs) const noexcept = default; + + std::string to_string() const { return m_ip.to_string() + ":" + std::to_string(m_port); } + std::pair to_sockaddr() const noexcept; + std::pair to_sockaddr_in() const noexcept; + + static constexpr std::optional parse(std::string_view str) noexcept { + auto pos = str.find(':'); + auto ip = ipv4_address::parse(str.substr(0, pos)); + if (!ip) return std::nullopt; + if (pos == std::string::npos) return ipv4_endpoint(*ip, 0); + if (pos + 1 == str.size()) return std::nullopt; + uint16_t port = 0; + for (auto it = str.begin() + pos + 1; it != str.end(); it++) { + if (*it < '0' || *it > '9') return std::nullopt; + port = port * 10 + (*it - '0'); + } + return ipv4_endpoint(*ip, port); + } + }; + + class ipv6_endpoint { + ipv6_address m_ip{}; + uint16_t m_port{}; + + public: + constexpr ipv6_endpoint() noexcept {} + constexpr ipv6_endpoint(ipv6_address addr, uint16_t port) noexcept : m_ip{addr}, m_port{port} {} + constexpr ipv6_endpoint(ipv4_address addr, uint16_t port) noexcept : m_ip{addr}, m_port{port} {} + explicit ipv6_endpoint(const sockaddr_storage& addr); + explicit ipv6_endpoint(const sockaddr_in6& addr) noexcept; + + constexpr ipv6_address address() const noexcept { return m_ip; } + constexpr uint16_t port() const noexcept { return m_port; } + + constexpr std::strong_ordering operator<=>(const ipv6_endpoint& rhs) const noexcept = default; + + std::string to_string(bool full = false) const { + return "[" + m_ip.to_string(full) + "]:" + std::to_string(m_port); + } + std::pair to_sockaddr() const noexcept; + std::pair to_sockaddr_in6() const noexcept; + + static constexpr std::optional parse(std::string_view str) noexcept { + auto pos = str.find(']'); + if (pos == std::string::npos || str[0] != '[') return std::nullopt; + auto ip = ipv6_address::parse(str.substr(1, pos - 1)); + if (!ip) return std::nullopt; + pos = str.find(':', pos); + if (pos == std::string::npos) return ipv6_endpoint(*ip, 0); + if (pos + 1 == str.size()) return std::nullopt; + uint16_t port = 0; + for (auto it = str.begin() + pos + 1; it != str.end(); it++) { + if (*it < '0' || *it > '9') return std::nullopt; + port = port * 10 + (*it - '0'); + } + return ipv6_endpoint(*ip, port); + } + }; + +#ifndef _WIN32 + using uds_endpoint = uds_address; +#endif + + class endpoint { + union { + ipv4_endpoint m_ipv4 = {}; + ipv6_endpoint m_ipv6; +#ifndef _WIN32 + uds_endpoint m_uds; +#endif + }; + address_type m_type{}; + + public: + constexpr endpoint() noexcept {} + constexpr endpoint(ipv4_address addr, uint16_t port) noexcept + : m_ipv4(addr, port), m_type(address_type::ipv4) {} + constexpr endpoint(ipv6_address addr, uint16_t port) noexcept + : m_ipv6(addr, port), m_type(address_type::ipv6) {} +#ifndef _WIN32 + constexpr endpoint(uds_address addr) noexcept : m_uds(addr), m_type(address_type::uds) {} +#endif + constexpr endpoint(address addr, uint16_t port) noexcept { + switch (addr.type()) { + case address_type::ipv4: + m_ipv4 = {addr.ipv4(), port}; + m_type = address_type::ipv4; + break; + case address_type::ipv6: + m_ipv6 = {addr.ipv6(), port}; + m_type = address_type::ipv6; + break; + case address_type::uds: + m_uds = addr.uds(); + m_type = address_type::uds; + break; + } + } + explicit constexpr endpoint(ipv4_endpoint ep) noexcept : m_ipv4(ep), m_type(address_type::ipv4) {} + explicit constexpr endpoint(ipv6_endpoint ep) noexcept : m_ipv6(ep), m_type(address_type::ipv6) {} + explicit endpoint(const sockaddr_storage& addr, size_t len); + + constexpr address_type type() const noexcept { return m_type; } + constexpr bool is_ipv4() const noexcept { return m_type == address_type::ipv4; } + constexpr bool is_ipv6() const noexcept { return m_type == address_type::ipv6; } +#ifndef _WIN32 + constexpr bool is_uds() const noexcept { return m_type == address_type::uds; } +#endif + + constexpr ipv4_endpoint ipv4() const noexcept { + switch (m_type) { + case address_type::ipv4: return m_ipv4; + case address_type::ipv6: + case address_type::uds: return {}; + } + } + constexpr ipv6_endpoint ipv6() const noexcept { + switch (m_type) { + case address_type::ipv4: return {}; + case address_type::ipv6: return m_ipv6; + case address_type::uds: return {}; + } + } +#ifndef _WIN32 + constexpr uds_endpoint uds() const noexcept { + switch (m_type) { + case address_type::ipv4: return {}; + case address_type::ipv6: return {}; + case address_type::uds: return m_uds; + } + } +#endif + + constexpr std::strong_ordering operator<=>(const endpoint& rhs) const noexcept { + auto order = m_type <=> rhs.m_type; + if (order != std::strong_ordering::equal) return order; + switch (m_type) { + case address_type::ipv4: return m_ipv4 <=> rhs.m_ipv4; + case address_type::ipv6: return m_ipv6 <=> rhs.m_ipv6; +#ifndef _WIN32 + case address_type::uds: return m_uds <=> rhs.m_uds; +#endif + } + return std::strong_ordering::equal; + } + constexpr bool operator==(const endpoint& rhs) const noexcept { + return (*this <=> rhs) == std::strong_ordering::equal; + } + constexpr bool operator!=(const endpoint& rhs) const noexcept { + return (*this <=> rhs) != std::strong_ordering::equal; + } + + std::string to_string(bool full = false) const { + switch (m_type) { + case address_type::ipv4: return m_ipv4.to_string(); + case address_type::ipv6: return m_ipv6.to_string(full); +#ifndef _WIN32 + case address_type::uds: return m_uds.to_string(); +#endif + } + } + std::pair to_sockaddr() const noexcept; + + static constexpr std::optional parse(std::string_view str, bool allow_uds = false) noexcept { + auto ep6 = ipv6_endpoint::parse(str); + if (ep6) return endpoint(*ep6); + auto ep4 = ipv4_endpoint::parse(str); + if (ep4) return endpoint(*ep4); +#ifndef _WIN32 + if (allow_uds) { + auto epuds = uds_endpoint::parse(str); + if (epuds) return endpoint(*epuds); + } +#endif + return std::nullopt; + } + }; +} // namespace asyncpp::io + +namespace std { + template<> + struct hash { + size_t operator()(const asyncpp::io::ipv4_endpoint& x) const noexcept { + return std::hash{}((static_cast(x.address().integer()) << 16) | x.port()); + } + }; + template<> + struct hash { + size_t operator()(const asyncpp::io::ipv6_endpoint& x) const noexcept { + std::hash h{}; + auto res = h(x.address()); + return res ^ (x.port() + 0x9e3779b99e3779b9ull + (res << 6) + (res >> 2)); + } + }; + template<> + struct hash { + size_t operator()(const asyncpp::io::endpoint& x) const noexcept { + size_t res; + switch (x.type()) { + case asyncpp::io::address_type::ipv4: res = std::hash{}(x.ipv4()); break; + case asyncpp::io::address_type::ipv6: res = std::hash{}(x.ipv6()); break; +#ifndef _WIN32 + case asyncpp::io::address_type::uds: res = std::hash{}(x.uds()); break; +#endif + } + return res ^ (static_cast(x.type()) + 0x9e3779b99e3779b9ull + (res << 6) + (res >> 2)); + } + }; +} // namespace std diff --git a/include/asyncpp/io/file.h b/include/asyncpp/io/file.h new file mode 100644 index 0000000..91117a8 --- /dev/null +++ b/include/asyncpp/io/file.h @@ -0,0 +1,339 @@ +#pragma once +#include +#include +#include + +#include +#include +#include +#include + +namespace asyncpp::io { + namespace detail { +// C++26 makes this trivial, but sofar this is only implemented on libstdc++ > 14 and libc++ > 18 +#if defined(__cpp_lib_fstream_native_handle) +#define ASYNCPP_IO_HANDLE_FROM_FILEBUF 1 + template + inline auto get_file_handle_from_filebuf(const std::basic_filebuf* buf) noexcept { + return buf->native_handle(); + } +#elif defined(__GLIBCXX__) // Lets do a little magic, shall we ? +#define ASYNCPP_IO_HANDLE_FROM_FILEBUF 1 + template + inline auto get_file_handle_from_filebuf(const std::basic_filebuf* buf) noexcept { + if (!buf->is_open()) return -1; + struct magic : std::basic_filebuf { + auto get_file() const noexcept { return &this->_M_file; } + }; + // fd() is actually const because it only calls fileno on the c file, but its not marked as such + auto f = const_cast*>(static_cast(buf)->get_file()); + return f->fd(); + } +#elif defined(_LIBCPP_VERSION) // Lets do a little more magic, shall we ? +#define ASYNCPP_IO_HANDLE_FROM_FILEBUF 1 + template + inline auto get_file_handle_from_filebuf(const std::basic_filebuf* buf) noexcept { + // std::filebuf in libc++ has had this layout at least since version 1.0 (release 2010) + // and we don't care if it changes in the future cause then __cpp_lib_fstream_native_handle above + // will take over. + struct magic : public std::basic_streambuf { + char* __extbuf_; + const char* __extbufnext_; + const char* __extbufend_; + char __extbuf_min_[8]; + size_t __ebs_; + CharT* __intbuf_; + size_t __ibs_; + FILE* __file_; + const std::codecvt* __cv_; + typename Traits::state_type __st_; + typename Traits::state_type __st_last_; + std::ios_base::openmode __om_; + std::ios_base::openmode __cm_; + bool __owns_eb_; + bool __owns_ib_; + bool __always_noconv_; + + auto fd() const noexcept { return __file_ == nullptr ? -1 : fileno(__file_); } + }; + static_assert(sizeof(magic) == sizeof(std::filebuf), "Implementation changed"); + return reinterpret_cast(buf)->fd(); + } +#endif + +#ifndef ASYNCPP_IO_HANDLE_FROM_FILEBUF +#define ASYNCPP_IO_HANDLE_FROM_FILEBUF 0 +#endif + +#if ASYNCPP_IO_HANDLE_FROM_FILEBUF + template + inline auto get_file_handle_from_filebuf(const std::basic_filebuf& buf) { + return get_file_handle_from_filebuf(&buf); + } + template + inline auto get_file_handle_from_filebuf(const std::basic_streambuf* buf) { + auto filebuf = dynamic_cast*>(buf); + if (filebuf == nullptr) throw std::logic_error("not a filebuf"); + return get_file_handle_from_filebuf(filebuf); + } + template + inline auto get_file_handle_from_filebuf(const std::basic_streambuf& buf) { + return get_file_handle_from_filebuf(&buf); + } +#endif + + class file_read_awaitable { + file_read_awaitable(const file_read_awaitable&) = delete; + file_read_awaitable(file_read_awaitable&&) = delete; + file_read_awaitable& operator=(const file_read_awaitable&) = delete; + file_read_awaitable& operator=(file_read_awaitable&&) = delete; + + template + friend class detail::cancellable_awaitable; + + io_engine* const m_engine; + io_engine::file_handle_t const m_fd; + void* const m_buf; + size_t const m_len; + uint64_t const m_offset; + std::error_code* const m_ec; + + protected: + detail::io_engine::completion_data m_completion; + + public: + constexpr file_read_awaitable(io_engine* engine, io_engine::file_handle_t fd, void* buf, size_t len, + uint64_t offset, std::error_code* ec) noexcept + : m_engine(engine), m_fd(fd), m_buf(buf), m_len(len), m_offset(offset), m_ec(ec), m_completion{} {} + bool await_ready() const noexcept { return false; } + bool await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_engine->enqueue_readv(m_fd, m_buf, m_len, m_offset, &m_completion); + } + size_t await_resume() { + if (m_completion.result >= 0) return static_cast(m_completion.result); + if (m_ec == nullptr) + throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + return 0; + } + }; + + class file_write_awaitable { + file_write_awaitable(const file_write_awaitable&) = delete; + file_write_awaitable(file_write_awaitable&&) = delete; + file_write_awaitable& operator=(const file_write_awaitable&) = delete; + file_write_awaitable& operator=(file_write_awaitable&&) = delete; + + template + friend class detail::cancellable_awaitable; + + io_engine* const m_engine; + io_engine::file_handle_t const m_fd; + const void* const m_buf; + size_t const m_len; + uint64_t const m_offset; + std::error_code* const m_ec; + + protected: + detail::io_engine::completion_data m_completion; + + public: + constexpr file_write_awaitable(io_engine* engine, io_engine::file_handle_t fd, const void* buf, size_t len, + uint64_t offset, std::error_code* ec) noexcept + : m_engine(engine), m_fd(fd), m_buf(buf), m_len(len), m_offset(offset), m_ec(ec), m_completion{} {} + bool await_ready() const noexcept { return false; } + bool await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_engine->enqueue_writev(m_fd, m_buf, m_len, m_offset, &m_completion); + } + size_t await_resume() { + if (m_completion.result >= 0) return static_cast(m_completion.result); + if (m_ec == nullptr) + throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + return 0; + } + }; + + class file_fsync_awaitable { + file_fsync_awaitable(const file_fsync_awaitable&) = delete; + file_fsync_awaitable(file_fsync_awaitable&&) = delete; + file_fsync_awaitable& operator=(const file_fsync_awaitable&) = delete; + file_fsync_awaitable& operator=(file_fsync_awaitable&&) = delete; + + template + friend class detail::cancellable_awaitable; + + io_engine* const m_engine; + io_engine::file_handle_t const m_fd; + std::error_code* const m_ec; + + protected: + detail::io_engine::completion_data m_completion; + + public: + constexpr file_fsync_awaitable(io_engine* engine, io_engine::file_handle_t fd, std::error_code* ec) noexcept + : m_engine(engine), m_fd(fd), m_ec(ec), m_completion{} {} + bool await_ready() const noexcept { return false; } + bool await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_engine->enqueue_fsync(m_fd, io_engine::fsync_flags::none, &m_completion); + } + void await_resume() { + if (m_completion.result >= 0) return; + if (m_ec == nullptr) + throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + } + }; + } // namespace detail + + inline auto read(detail::io_engine& engine, detail::io_engine::file_handle_t fd, void* buf, size_t len, + uint64_t offset) { + return detail::file_read_awaitable(&engine, fd, buf, len, offset, nullptr); + } + + inline auto read(detail::io_engine& engine, detail::io_engine::file_handle_t fd, void* buf, size_t len, + uint64_t offset, std::error_code& ec) { + return detail::file_read_awaitable(&engine, fd, buf, len, offset, &ec); + } + + inline auto read(detail::io_engine& engine, detail::io_engine::file_handle_t fd, void* buf, size_t len, + uint64_t offset, asyncpp::stop_token st) { + return detail::cancellable_awaitable(std::move(st), &engine, fd, buf, len, offset, + nullptr); + } + + inline auto read(detail::io_engine& engine, detail::io_engine::file_handle_t fd, void* buf, size_t len, + uint64_t offset, asyncpp::stop_token st, std::error_code& ec) { + return detail::cancellable_awaitable(std::move(st), &engine, fd, buf, len, offset, + &ec); + } + + inline auto write(detail::io_engine& engine, detail::io_engine::file_handle_t fd, const void* buf, size_t len, + uint64_t offset) { + return detail::file_write_awaitable(&engine, fd, buf, len, offset, nullptr); + } + + inline auto write(detail::io_engine& engine, detail::io_engine::file_handle_t fd, const void* buf, size_t len, + uint64_t offset, std::error_code& ec) { + return detail::file_write_awaitable(&engine, fd, buf, len, offset, &ec); + } + + inline auto write(detail::io_engine& engine, detail::io_engine::file_handle_t fd, const void* buf, size_t len, + uint64_t offset, asyncpp::stop_token st) { + return detail::cancellable_awaitable(std::move(st), &engine, fd, buf, len, offset, + nullptr); + } + + inline auto write(detail::io_engine& engine, detail::io_engine::file_handle_t fd, const void* buf, size_t len, + uint64_t offset, asyncpp::stop_token st, std::error_code& ec) { + return detail::cancellable_awaitable(std::move(st), &engine, fd, buf, len, offset, + &ec); + } + + inline auto fsync(detail::io_engine& engine, detail::io_engine::file_handle_t fd) { + return detail::file_fsync_awaitable(&engine, fd, nullptr); + } + + inline auto fsync(detail::io_engine& engine, detail::io_engine::file_handle_t fd, std::error_code& ec) { + return detail::file_fsync_awaitable(&engine, fd, &ec); + } + + inline auto fsync(detail::io_engine& engine, detail::io_engine::file_handle_t fd, asyncpp::stop_token st) { + return detail::cancellable_awaitable(std::move(st), &engine, fd, nullptr); + } + + inline auto fsync(detail::io_engine& engine, detail::io_engine::file_handle_t fd, asyncpp::stop_token st, + std::error_code& ec) { + return detail::cancellable_awaitable(std::move(st), &engine, fd, &ec); + } + + class file { + io_service* m_io; + detail::io_engine::file_handle_t m_fd; + + public: + explicit file(io_service& io); + file(io_service& io, detail::io_engine::file_handle_t fd); + explicit file(io_service& io, const char* filename, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out); + explicit file(io_service& io, const std::string& filename, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out); + explicit file(io_service& io, const std::filesystem::path& filename, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out); + file(const file&) = delete; + file(file&&) noexcept; + file& operator=(const file&) = delete; + file& operator=(file&&); + ~file(); + + [[nodiscard]] io_service& service() const noexcept { return *m_io; } + [[nodiscard]] detail::io_engine::file_handle_t native_handle() const noexcept { return m_fd; } + [[nodiscard]] detail::io_engine::file_handle_t release() noexcept { + return std::exchange(m_fd, detail::io_engine::invalid_file_handle); + } + + void open(const char* filename, std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out); + void open(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out); + void open(const std::filesystem::path& filename, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out); + + [[nodiscard]] bool is_open() const noexcept; + [[nodiscard]] bool operator!() const noexcept { return !is_open(); } + [[nodiscard]] operator bool() const noexcept { return is_open(); } + + void close(); + + void swap(file& other); + + [[nodiscard]] uint64_t size(); + + auto read(void* buf, size_t len, uint64_t offset) { + return detail::file_read_awaitable(m_io->engine(), m_fd, buf, len, offset, nullptr); + } + auto read(void* buf, size_t len, uint64_t offset, std::error_code& ec) { + return detail::file_read_awaitable(m_io->engine(), m_fd, buf, len, offset, &ec); + } + auto read(void* buf, size_t len, uint64_t offset, asyncpp::stop_token st) { + return detail::cancellable_awaitable(std::move(st), m_io->engine(), m_fd, buf, + len, offset, nullptr); + } + auto read(void* buf, size_t len, uint64_t offset, asyncpp::stop_token st, std::error_code& ec) { + return detail::cancellable_awaitable(std::move(st), m_io->engine(), m_fd, buf, + len, offset, &ec); + } + + auto write(const void* buf, size_t len, uint64_t offset) { + return detail::file_write_awaitable(m_io->engine(), m_fd, buf, len, offset, nullptr); + } + auto write(const void* buf, size_t len, uint64_t offset, std::error_code& ec) { + return detail::file_write_awaitable(m_io->engine(), m_fd, buf, len, offset, &ec); + } + auto write(const void* buf, size_t len, uint64_t offset, asyncpp::stop_token st) { + return detail::cancellable_awaitable(std::move(st), m_io->engine(), m_fd, buf, + len, offset, nullptr); + } + auto write(const void* buf, size_t len, uint64_t offset, asyncpp::stop_token st, std::error_code& ec) { + return detail::cancellable_awaitable(std::move(st), m_io->engine(), m_fd, buf, + len, offset, &ec); + } + + auto fsync() { return detail::file_fsync_awaitable(m_io->engine(), m_fd, nullptr); } + auto fsync(std::error_code& ec) { return detail::file_fsync_awaitable(m_io->engine(), m_fd, &ec); } + auto fsync(asyncpp::stop_token st) { + return detail::cancellable_awaitable(std::move(st), m_io->engine(), m_fd, + nullptr); + } + auto fsync(asyncpp::stop_token st, std::error_code& ec) { + return detail::cancellable_awaitable(std::move(st), m_io->engine(), m_fd, + &ec); + } + }; + + inline void swap(file& lhs, file& rhs) { lhs.swap(rhs); } +} // namespace asyncpp::io diff --git a/include/asyncpp/io/io_service.h b/include/asyncpp/io/io_service.h new file mode 100644 index 0000000..304ef7d --- /dev/null +++ b/include/asyncpp/io/io_service.h @@ -0,0 +1,44 @@ +#pragma once +#include +#include +#include +#include +#include + +#include + +namespace asyncpp::io { + namespace detail { + class io_engine; + } + class io_service : public dispatcher { + std::unique_ptr m_engine; + threadsafe_queue> m_dispatched; + bool m_stopped{false}; + + public: + enum class run_mode { + until_stopped, + while_active, + once, + nowait, + }; + + io_service(); + io_service(const io_service&) = delete; + io_service(io_service&&) = delete; + io_service& operator=(const io_service&) = delete; + io_service& operator=(io_service&&) = delete; + ~io_service() noexcept(false); + + bool run(run_mode mode = run_mode::while_active); + void stop(); + bool stopped() const noexcept { return m_stopped; } + + void push(std::function fn) override; + + detail::io_engine* engine() noexcept { return m_engine.get(); } + + static std::shared_ptr get_default(); + }; +} // namespace asyncpp::io diff --git a/include/asyncpp/io/network.h b/include/asyncpp/io/network.h new file mode 100644 index 0000000..d9ae096 --- /dev/null +++ b/include/asyncpp/io/network.h @@ -0,0 +1,166 @@ +#pragma once +#include + +#include +#include +#include + +namespace asyncpp::io { + class ipv4_network { + ipv4_address m_ip{}; + uint8_t m_prefix{}; + + constexpr static ipv4_address make_canonical(ipv4_address addr, uint8_t prefix) noexcept { + if (prefix == 0) return ipv4_address(); + if (prefix >= 32) return addr; + const uint32_t mask = ((uint32_t(1) << prefix) - 1) << (32 - prefix); + return ipv4_address(addr.integer() & mask); + } + + public: + constexpr ipv4_network() noexcept = default; + constexpr ipv4_network(ipv4_address addr, uint8_t prefix) noexcept + : m_ip{make_canonical(addr, prefix)}, m_prefix{prefix} {} + constexpr ipv4_network(ipv4_address addr, ipv4_address mask) noexcept + : m_ip{}, m_prefix{static_cast(std::countl_one(mask.integer()))} { + m_ip = make_canonical(m_ip, m_prefix); + } + + constexpr uint8_t prefix_length() const noexcept { return m_prefix; } + constexpr ipv4_address canonical() const noexcept { return m_ip; } + constexpr ipv4_address broadcast() const noexcept { + if (m_prefix == 0) return ipv4_address(255, 255, 255, 255); + if (m_prefix >= 32) return m_ip; + const uint32_t mask = ((uint32_t(1) << m_prefix) - 1) << (32 - m_prefix); + return ipv4_address((m_ip.integer() & mask) | (std::numeric_limits::max() & ~mask)); + } + + constexpr bool is_subnet(const ipv4_network& subnet) const noexcept { + const auto base = subnet.canonical(); + return (subnet.m_prefix > m_prefix && base >= canonical() && base < broadcast()); + } + constexpr bool is_subnet_of(const ipv4_network& parent) const noexcept { + const ipv4_network base(m_ip, parent.m_prefix); + return parent.m_prefix < m_prefix && base.canonical() == parent.canonical(); + } + constexpr bool contains(const ipv4_address& host) const noexcept { + constexpr uint32_t max = std::numeric_limits::max(); + if (m_prefix == 0) return true; + const uint32_t mask = m_prefix >= 32 ? max : (((uint32_t(1) << m_prefix) - 1) << (32 - m_prefix)); + return (host.integer() & mask) == (m_ip.integer() & mask); + } + + constexpr std::strong_ordering operator<=>(const ipv4_network& rhs) const noexcept = default; + + std::string to_string() const { return m_ip.to_string() + "/" + std::to_string(m_prefix); } + + static constexpr std::optional parse(std::string_view str) noexcept { + auto pos = str.find('/'); + auto ip = ipv4_address::parse(str.substr(0, pos)); + if (!ip) return std::nullopt; + if (pos == std::string::npos) return ipv4_network(*ip, 32); + if (pos + 1 == str.size()) return std::nullopt; + uint8_t prefix = 0; + for (auto it = str.begin() + pos + 1; it != str.end(); it++) { + if (*it < '0' || *it > '9') return std::nullopt; + prefix = prefix * 10 + (*it - '0'); + } + return ipv4_network(*ip, prefix); + } + }; + + class ipv6_network { + ipv6_address m_ip{}; + uint8_t m_prefix{}; + + constexpr static std::pair make_mask(uint8_t prefix) noexcept { + constexpr uint64_t max = std::numeric_limits::max(); + if (prefix >= 128) return {max, max}; + if (prefix == 0) return {0, 0}; + const uint64_t mask1 = prefix >= 64 ? max : (((uint64_t(1) << prefix) - 1) << (64 - prefix)); + const uint64_t mask2 = prefix < 64 ? 0 : (((uint64_t(1) << (prefix - 64)) - 1) << (64 - (prefix - 64))); + return {mask1, mask2}; + } + + constexpr static ipv6_address make_canonical(ipv6_address addr, uint8_t prefix) noexcept { + if (prefix == 0) return ipv6_address(); + if (prefix >= 128) return addr; + auto [mask1, mask2] = make_mask(prefix); + return ipv6_address(addr.subnet_prefix() & mask1, addr.interface_identifier() & mask2); + } + + public: + constexpr ipv6_network() noexcept = default; + constexpr ipv6_network(ipv6_address addr, uint8_t prefix) noexcept + : m_ip{make_canonical(addr, prefix)}, m_prefix{prefix} {} + constexpr ipv6_network(ipv6_address addr, ipv6_address mask) noexcept : m_ip{}, m_prefix{0} { + uint8_t prefix = std::countl_one(mask.subnet_prefix()); + if (prefix == 64) prefix = 64 + std::countl_one(mask.interface_identifier()); + m_prefix = prefix; + m_ip = make_canonical(addr, prefix); + } + + constexpr uint8_t prefix_length() const noexcept { return m_prefix; } + + constexpr ipv6_address canonical() const noexcept { return m_ip; } + constexpr ipv6_address broadcast() const noexcept { + if (m_prefix == 0) return ipv6_address(); + if (m_prefix >= 128) return m_ip; + constexpr uint64_t max = std::numeric_limits::max(); + auto [mask1, mask2] = make_mask(m_prefix); + return ipv6_address((m_ip.subnet_prefix() & mask1) | (max & ~mask1), + (m_ip.interface_identifier() & mask2) | (max & ~mask2)); + } + + constexpr bool is_subnet(const ipv6_network& subnet) const noexcept { + const auto base = subnet.canonical(); + return (subnet.m_prefix > m_prefix && base >= canonical() && base < broadcast()); + } + constexpr bool is_subnet_of(const ipv6_network& parent) const noexcept { + const ipv6_network base(m_ip, parent.m_prefix); + return parent.m_prefix < m_prefix && base.canonical() == parent.canonical(); + } + constexpr bool contains(const ipv6_address& host) const noexcept { + if (m_prefix == 0) return true; + if (m_prefix >= 128) return canonical() == host; + auto [mask1, mask2] = make_mask(m_prefix); + return (host.subnet_prefix() & mask1) == (m_ip.subnet_prefix() & mask1) && + (host.interface_identifier() & mask2) == (m_ip.interface_identifier() & mask2); + } + + constexpr std::strong_ordering operator<=>(const ipv6_network& rhs) const noexcept = default; + + std::string to_string(bool full = false) const { return m_ip.to_string(full) + "/" + std::to_string(m_prefix); } + + static constexpr std::optional parse(std::string_view str) noexcept { + auto pos = str.find('/'); + auto ip = ipv6_address::parse(str.substr(0, pos)); + if (!ip) return std::nullopt; + if (pos == std::string::npos) return ipv6_network(*ip, 128); + if (pos + 1 == str.size()) return std::nullopt; + uint16_t prefix = 0; + for (auto it = str.begin() + pos + 1; it != str.end(); it++) { + if (*it < '0' || *it > '9') return std::nullopt; + prefix = prefix * 10 + (*it - '0'); + } + return ipv6_network(*ip, prefix); + } + }; +} // namespace asyncpp::io + +namespace std { + template<> + struct hash { + size_t operator()(const asyncpp::io::ipv4_network& x) const noexcept { + return std::hash{}((static_cast(x.canonical().integer()) << 16) | x.prefix_length()); + } + }; + template<> + struct hash { + size_t operator()(const asyncpp::io::ipv6_network& x) const noexcept { + std::hash h{}; + auto res = h(x.canonical()); + return res ^ (x.prefix_length() + 0x9e3779b99e3779b9ull + (res << 6) + (res >> 2)); + } + }; +} // namespace std diff --git a/include/asyncpp/io/socket.h b/include/asyncpp/io/socket.h new file mode 100644 index 0000000..45079e8 --- /dev/null +++ b/include/asyncpp/io/socket.h @@ -0,0 +1,744 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace asyncpp::io { + class socket; + + namespace detail { + class socket_awaitable_base { + socket_awaitable_base(const socket_awaitable_base&) = delete; + socket_awaitable_base(socket_awaitable_base&&) = delete; + socket_awaitable_base& operator=(const socket_awaitable_base&) = delete; + socket_awaitable_base& operator=(socket_awaitable_base&&) = delete; + + template + friend class detail::cancellable_awaitable; + + protected: + socket& m_socket; + detail::io_engine::completion_data m_completion; + + public: + constexpr socket_awaitable_base(socket& sock) noexcept : m_socket{sock}, m_completion{} {} + bool await_ready() const noexcept { return false; } + }; + } // namespace detail + + class socket_connect_awaitable; + class socket_create_and_connect_awaitable; + class socket_accept_awaitable; + class socket_accept_error_code_awaitable; + class socket_send_awaitable; + class socket_recv_awaitable; + class socket_recv_exact_awaitable; + class socket_recv_from_awaitable; + class socket_send_to_awaitable; + + using socket_connect_cancellable_awaitable = detail::cancellable_awaitable; + class socket_create_and_connect_cancellable_awaitable; + using socket_accept_cancellable_awaitable = detail::cancellable_awaitable; + using socket_accept_error_code_cancellable_awaitable = + detail::cancellable_awaitable; + using socket_send_cancellable_awaitable = detail::cancellable_awaitable; + using socket_recv_cancellable_awaitable = detail::cancellable_awaitable; + using socket_recv_exact_cancellable_awaitable = detail::cancellable_awaitable; + using socket_recv_from_cancellable_awaitable = detail::cancellable_awaitable; + using socket_send_to_cancellable_awaitable = detail::cancellable_awaitable; + + class socket { + public: + [[deprecated("use create_tcp instead")]] [[nodiscard]] static socket create_tcpv4(io_service& io) { + return create_tcp(io, address_type::ipv4); + } + [[deprecated("use create_tcp instead")]] [[nodiscard]] static socket create_tcpv6(io_service& io) { + return create_tcp(io, address_type::ipv6); + } + [[nodiscard]] static socket create_tcp(io_service& io, address_type addr_type); + [[nodiscard]] static socket_create_and_connect_awaitable create_connected_tcp(io_service& io, endpoint ep); + [[nodiscard]] static socket_create_and_connect_cancellable_awaitable + create_connected_tcp(io_service& io, endpoint ep, asyncpp::stop_token token); + [[deprecated("use create_udp instead")]] [[nodiscard]] static socket create_udpv4(io_service& io) { + return create_udp(io, address_type::ipv4); + } + [[deprecated("use create_udp instead")]] [[nodiscard]] static socket create_udpv6(io_service& io) { + return create_udp(io, address_type::ipv6); + } + [[nodiscard]] static socket create_udp(io_service& io, address_type addr_type); + [[nodiscard]] static socket create_and_bind_tcp(io_service& io, const endpoint& ep); + [[nodiscard]] static socket create_and_bind_udp(io_service& io, const endpoint& ep); + [[nodiscard]] static socket from_fd(io_service& io, detail::io_engine::socket_handle_t fd); +#ifndef __WIN32 + [[nodiscard]] static std::pair connected_pair_tcp(io_service& io, address_type addrtype); + [[nodiscard]] static std::pair connected_pair_udp(io_service& io, address_type addrtype); +#endif + + constexpr socket() noexcept = default; + socket(socket&& other) noexcept; + socket& operator=(socket&& other) noexcept; + ~socket(); + + [[nodiscard]] bool valid() const noexcept { return m_io != nullptr; } + [[nodiscard]] operator bool() const noexcept { return m_io != nullptr; } + [[nodiscard]] bool operator!() const noexcept { return m_io == nullptr; } + + [[nodiscard]] io_service& service() const noexcept { return *m_io; } + + [[nodiscard]] const endpoint& local_endpoint() const noexcept { return m_local_ep; } + [[nodiscard]] const endpoint& remote_endpoint() const noexcept { return m_remote_ep; } + + void bind(const endpoint& ep); + void listen(std::uint32_t backlog = 0); + + void allow_broadcast(bool enable); + + [[nodiscard]] detail::io_engine::socket_handle_t native_handle() const noexcept { return m_fd; } + [[nodiscard]] detail::io_engine::socket_handle_t release() noexcept { + m_io = nullptr; + m_remote_ep = {}; + m_local_ep = {}; + return std::exchange(m_fd, -1); + } + + [[nodiscard]] constexpr socket_connect_awaitable connect(const endpoint& ep) noexcept; + [[nodiscard]] constexpr socket_connect_awaitable connect(const endpoint& ep, std::error_code& ec) noexcept; + [[nodiscard]] constexpr socket_accept_awaitable accept() noexcept; + [[nodiscard]] constexpr socket_accept_error_code_awaitable accept(std::error_code& ec) noexcept; + [[nodiscard]] constexpr socket_send_awaitable send(const void* buffer, std::size_t size) noexcept; + [[nodiscard]] constexpr socket_send_awaitable send(const void* buffer, std::size_t size, + std::error_code& ec) noexcept; + [[nodiscard]] constexpr socket_recv_awaitable recv(void* buffer, std::size_t size) noexcept; + [[nodiscard]] constexpr socket_recv_awaitable recv(void* buffer, std::size_t size, + std::error_code& ec) noexcept; + [[nodiscard]] constexpr socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size) noexcept; + [[nodiscard]] constexpr socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size, + std::error_code& ec) noexcept; + [[nodiscard]] constexpr socket_send_to_awaitable send_to(const void* buffer, std::size_t size, + const endpoint& dst_ep) noexcept; + [[nodiscard]] constexpr socket_send_to_awaitable send_to(const void* buffer, std::size_t size, + const endpoint& dst_ep, std::error_code& ec) noexcept; + [[nodiscard]] constexpr socket_recv_from_awaitable recv_from(void* buffer, std::size_t size) noexcept; + [[nodiscard]] constexpr socket_recv_from_awaitable recv_from(void* buffer, std::size_t size, + std::error_code& ec) noexcept; + + [[nodiscard]] socket_connect_cancellable_awaitable connect(const endpoint& ep, asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_connect_cancellable_awaitable connect(const endpoint& ep, asyncpp::stop_token st, + std::error_code& ec) noexcept; + [[nodiscard]] socket_accept_cancellable_awaitable accept(asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_accept_error_code_cancellable_awaitable accept(asyncpp::stop_token st, + std::error_code& ec) noexcept; + [[nodiscard]] socket_send_cancellable_awaitable send(const void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_send_cancellable_awaitable send(const void* buffer, std::size_t size, + asyncpp::stop_token st, std::error_code& ec) noexcept; + [[nodiscard]] socket_recv_cancellable_awaitable recv(void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_recv_cancellable_awaitable recv(void* buffer, std::size_t size, asyncpp::stop_token st, + std::error_code& ec) noexcept; + [[nodiscard]] socket_recv_exact_cancellable_awaitable recv_exact(void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_recv_exact_cancellable_awaitable + recv_exact(void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept; + [[nodiscard]] socket_send_to_cancellable_awaitable + send_to(const void* buffer, std::size_t size, const endpoint& dst_ep, asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_send_to_cancellable_awaitable send_to(const void* buffer, std::size_t size, + const endpoint& dst_ep, asyncpp::stop_token st, + std::error_code& ec) noexcept; + [[nodiscard]] socket_recv_from_cancellable_awaitable recv_from(void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_recv_from_cancellable_awaitable + recv_from(void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept; + + template + requires(std::is_invocable_v) + void connect(const endpoint& ep, FN&& cb, asyncpp::stop_token st = {}); + template + requires(std::is_invocable_v>) + void accept(FN&& cb, asyncpp::stop_token st = {}); + template + requires(std::is_invocable_v) + void send(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {}); + template + requires(std::is_invocable_v) + void recv(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {}); + template + requires(std::is_invocable_v) + void send_to(const void* buffer, std::size_t size, const endpoint& dst_ep, FN&& cb, + asyncpp::stop_token st = {}); + template + requires(std::is_invocable_v) + void recv_from(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {}); + + void close_send(); + void close_recv(); + + friend void swap(socket& a, socket& b) noexcept; + + private: + socket(io_service* io, detail::io_engine::socket_handle_t fd) noexcept; + void update_endpoint_info(); + + io_service* m_io{}; + detail::io_engine::socket_handle_t m_fd{detail::io_engine::invalid_socket_handle}; + endpoint m_remote_ep{}; + endpoint m_local_ep{}; + }; + + inline void swap(socket& a, socket& b) noexcept { + std::swap(a.m_io, b.m_io); + std::swap(a.m_fd, b.m_fd); + std::swap(a.m_local_ep, b.m_local_ep); + std::swap(a.m_remote_ep, b.m_remote_ep); + } + + class socket_connect_awaitable : public detail::socket_awaitable_base { + const endpoint m_ep; + std::error_code* const m_ec; + + public: + constexpr socket_connect_awaitable(socket& sock, endpoint ep, std::error_code* ec = nullptr) noexcept + : socket_awaitable_base{sock}, m_ep{ep}, m_ec{ec} {} + bool await_suspend(coroutine_handle<> hdl); + void await_resume(); + }; + + class socket_create_and_connect_awaitable { + socket m_sock; + socket_connect_awaitable m_child; + + public: + socket_create_and_connect_awaitable(io_service& io, endpoint ep) noexcept + : m_sock{socket::create_tcp(io, ep.type())}, m_child{m_sock, ep} {} + bool await_suspend(coroutine_handle<> hdl) { return m_child.await_suspend(hdl); } + bool await_ready() const noexcept { return false; } + socket await_resume() { + m_child.await_resume(); + return std::move(m_sock); + } + }; + + class socket_create_and_connect_cancellable_awaitable { + socket m_sock; + socket_connect_cancellable_awaitable m_child; + + public: + socket_create_and_connect_cancellable_awaitable(asyncpp::stop_token token, io_service& io, endpoint ep) noexcept + : m_sock{socket::create_tcp(io, ep.type())}, m_child{std::move(token), m_sock, ep} {} + bool await_suspend(coroutine_handle<> hdl) { return m_child.await_suspend(hdl); } + socket await_resume() { + m_child.await_resume(); + return std::move(m_sock); + } + bool await_ready() const noexcept { return false; } + }; + + class socket_send_awaitable : public detail::socket_awaitable_base { + const void* const m_buffer; + std::size_t const m_size; + std::error_code* const m_ec; + + public: + constexpr socket_send_awaitable(socket& sock, const void* buffer, std::size_t size, + std::error_code* ec = nullptr) noexcept + : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_ec{ec} {} + bool await_suspend(coroutine_handle<> hdl); + void await_resume(); + }; + + class socket_recv_awaitable : public detail::socket_awaitable_base { + void* const m_buffer; + std::size_t const m_size; + std::error_code* const m_ec; + + public: + constexpr socket_recv_awaitable(socket& sock, void* buffer, std::size_t size, + std::error_code* ec = nullptr) noexcept + : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_ec{ec} {} + bool await_suspend(coroutine_handle<> hdl); + size_t await_resume(); + }; + + class socket_recv_exact_awaitable : public asyncpp::io::detail::socket_awaitable_base { + unsigned char* m_buffer; + std::size_t const m_size; + std::size_t m_remaining; + asyncpp::coroutine_handle<> m_handle; + std::error_code* const m_ec; + + public: + constexpr socket_recv_exact_awaitable(asyncpp::io::socket& sock, void* buffer, std::size_t size, + std::error_code* ec = nullptr) noexcept + : socket_awaitable_base{sock}, m_buffer{static_cast(buffer)}, m_size{size}, + m_remaining{size}, m_ec{ec} {} + bool await_suspend(asyncpp::coroutine_handle<> hdl); + size_t await_resume(); + }; + + class socket_accept_awaitable : public detail::socket_awaitable_base { + public: + constexpr socket_accept_awaitable(socket& sock) noexcept : socket_awaitable_base{sock} {} + bool await_suspend(coroutine_handle<> hdl); + socket await_resume(); + }; + + class socket_accept_error_code_awaitable : public detail::socket_awaitable_base { + std::error_code& m_ec; + + public: + constexpr socket_accept_error_code_awaitable(socket& sock, std::error_code& ec) noexcept + : socket_awaitable_base{sock}, m_ec{ec} {} + bool await_suspend(coroutine_handle<> hdl); + std::optional await_resume(); + }; + + class socket_send_to_awaitable : public detail::socket_awaitable_base { + const void* const m_buffer; + std::size_t const m_size; + endpoint const m_destination; + std::error_code* const m_ec; + + public: + constexpr socket_send_to_awaitable(socket& sock, const void* buffer, std::size_t size, endpoint dst, + std::error_code* ec = nullptr) noexcept + : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_destination{dst}, m_ec{ec} {} + bool await_suspend(coroutine_handle<> hdl); + size_t await_resume(); + }; + + class socket_recv_from_awaitable : public detail::socket_awaitable_base { + void* const m_buffer; + std::size_t const m_size; + endpoint m_source; + std::error_code* const m_ec; + + public: + constexpr socket_recv_from_awaitable(socket& sock, void* buffer, std::size_t size, + std::error_code* ec = nullptr) noexcept + : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_ec{ec} {} + bool await_suspend(coroutine_handle<> hdl); + std::pair await_resume(); + }; + + [[nodiscard]] inline constexpr socket_connect_awaitable socket::connect(const endpoint& ep) noexcept { + return socket_connect_awaitable(*this, ep); + } + + [[nodiscard]] inline constexpr socket_connect_awaitable socket::connect(const endpoint& ep, + std::error_code& ec) noexcept { + return socket_connect_awaitable(*this, ep, &ec); + } + + [[nodiscard]] inline constexpr socket_accept_awaitable socket::accept() noexcept { + return socket_accept_awaitable(*this); + } + + [[nodiscard]] inline constexpr socket_accept_error_code_awaitable socket::accept(std::error_code& ec) noexcept { + return socket_accept_error_code_awaitable(*this, ec); + } + + [[nodiscard]] inline constexpr socket_send_awaitable socket::send(const void* buffer, std::size_t size) noexcept { + return socket_send_awaitable(*this, buffer, size); + } + + [[nodiscard]] inline constexpr socket_send_awaitable socket::send(const void* buffer, std::size_t size, + std::error_code& ec) noexcept { + return socket_send_awaitable(*this, buffer, size, &ec); + } + + [[nodiscard]] inline constexpr socket_recv_awaitable socket::recv(void* buffer, std::size_t size) noexcept { + return socket_recv_awaitable(*this, buffer, size); + } + + [[nodiscard]] inline constexpr socket_recv_awaitable socket::recv(void* buffer, std::size_t size, + std::error_code& ec) noexcept { + return socket_recv_awaitable(*this, buffer, size, &ec); + } + + [[nodiscard]] inline constexpr socket_recv_exact_awaitable socket::recv_exact(void* buffer, + std::size_t size) noexcept { + return socket_recv_exact_awaitable(*this, buffer, size); + } + + [[nodiscard]] inline constexpr socket_recv_exact_awaitable socket::recv_exact(void* buffer, std::size_t size, + std::error_code& ec) noexcept { + return socket_recv_exact_awaitable(*this, buffer, size, &ec); + } + + [[nodiscard]] inline constexpr socket_send_to_awaitable socket::send_to(const void* buffer, std::size_t size, + const endpoint& dst_ep) noexcept { + return socket_send_to_awaitable(*this, buffer, size, dst_ep); + } + + [[nodiscard]] inline constexpr socket_send_to_awaitable + socket::send_to(const void* buffer, std::size_t size, const endpoint& dst_ep, std::error_code& ec) noexcept { + return socket_send_to_awaitable(*this, buffer, size, dst_ep, &ec); + } + + [[nodiscard]] inline constexpr socket_recv_from_awaitable socket::recv_from(void* buffer, + std::size_t size) noexcept { + return socket_recv_from_awaitable(*this, buffer, size); + } + + [[nodiscard]] inline constexpr socket_recv_from_awaitable socket::recv_from(void* buffer, std::size_t size, + std::error_code& ec) noexcept { + return socket_recv_from_awaitable(*this, buffer, size, &ec); + } + + [[nodiscard]] inline socket_connect_cancellable_awaitable socket::connect(const endpoint& ep, + asyncpp::stop_token st) noexcept { + return socket_connect_cancellable_awaitable(std::move(st), *this, ep); + } + + [[nodiscard]] inline socket_connect_cancellable_awaitable + socket::connect(const endpoint& ep, asyncpp::stop_token st, std::error_code& ec) noexcept { + return socket_connect_cancellable_awaitable(std::move(st), *this, ep, &ec); + } + + [[nodiscard]] inline socket_accept_cancellable_awaitable socket::accept(asyncpp::stop_token st) noexcept { + return socket_accept_cancellable_awaitable(std::move(st), *this); + } + + [[nodiscard]] inline socket_accept_error_code_cancellable_awaitable socket::accept(asyncpp::stop_token st, + std::error_code& ec) noexcept { + return socket_accept_error_code_cancellable_awaitable(std::move(st), *this, ec); + } + + [[nodiscard]] inline socket_send_cancellable_awaitable socket::send(const void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept { + return socket_send_cancellable_awaitable(std::move(st), *this, buffer, size); + } + + [[nodiscard]] inline socket_send_cancellable_awaitable + socket::send(const void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept { + return socket_send_cancellable_awaitable(std::move(st), *this, buffer, size, &ec); + } + + [[nodiscard]] inline socket_recv_cancellable_awaitable socket::recv(void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept { + return socket_recv_cancellable_awaitable(std::move(st), *this, buffer, size); + } + + [[nodiscard]] inline socket_recv_cancellable_awaitable + socket::recv(void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept { + return socket_recv_cancellable_awaitable(std::move(st), *this, buffer, size, &ec); + } + + [[nodiscard]] inline socket_recv_exact_cancellable_awaitable socket::recv_exact(void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept { + return socket_recv_exact_cancellable_awaitable(std::move(st), *this, buffer, size); + } + + [[nodiscard]] inline socket_recv_exact_cancellable_awaitable + socket::recv_exact(void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept { + return socket_recv_exact_cancellable_awaitable(std::move(st), *this, buffer, size, &ec); + } + + [[nodiscard]] inline socket_send_to_cancellable_awaitable + socket::send_to(const void* buffer, std::size_t size, const endpoint& dst_ep, asyncpp::stop_token st) noexcept { + return socket_send_to_cancellable_awaitable(std::move(st), *this, buffer, size, dst_ep); + } + + [[nodiscard]] inline socket_send_to_cancellable_awaitable socket::send_to(const void* buffer, std::size_t size, + const endpoint& dst_ep, + asyncpp::stop_token st, + std::error_code& ec) noexcept { + return socket_send_to_cancellable_awaitable(std::move(st), *this, buffer, size, dst_ep, &ec); + } + + [[nodiscard]] inline socket_recv_from_cancellable_awaitable socket::recv_from(void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept { + return socket_recv_from_cancellable_awaitable(std::move(st), *this, buffer, size); + } + + [[nodiscard]] inline socket_recv_from_cancellable_awaitable + socket::recv_from(void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept { + return socket_recv_from_cancellable_awaitable(std::move(st), *this, buffer, size, &ec); + } + + inline bool socket_connect_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_socket.service().engine()->enqueue_connect(m_socket.native_handle(), m_ep, &m_completion); + } + + inline void socket_connect_awaitable::await_resume() { + if (m_completion.result >= 0) return; + if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + } + + inline bool socket_send_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_socket.service().engine()->enqueue_send(m_socket.native_handle(), m_buffer, m_size, &m_completion); + } + + inline void socket_send_awaitable::await_resume() { + if (m_completion.result >= 0) return; + if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + } + + inline bool socket_recv_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_socket.service().engine()->enqueue_recv(m_socket.native_handle(), m_buffer, m_size, &m_completion); + } + + inline size_t socket_recv_awaitable::await_resume() { + if (m_completion.result >= 0) return static_cast(m_completion.result); + if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + return 0; + } + + inline bool socket_recv_exact_awaitable::await_suspend(asyncpp::coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { + auto that = static_cast(ptr); + auto engine = that->m_socket.service().engine(); + do { + if (that->m_completion.result <= 0) { + that->m_handle.resume(); + break; + } + that->m_buffer += that->m_completion.result; + that->m_remaining -= that->m_completion.result; + if (that->m_remaining == 0) { + that->m_handle.resume(); + break; + } + } while (engine->enqueue_recv(that->m_socket.native_handle(), that->m_buffer, that->m_remaining, + &that->m_completion)); + }; + m_completion.userdata = this; + m_handle = hdl; + auto engine = m_socket.service().engine(); + while (engine->enqueue_recv(m_socket.native_handle(), m_buffer, m_remaining, &m_completion)) { + if (m_completion.result <= 0) return false; + m_buffer += m_completion.result; + m_remaining -= m_completion.result; + if (m_remaining == 0) return false; + } + return true; + } + + inline size_t socket_recv_exact_awaitable::await_resume() { + if (m_completion.result >= 0) return m_size - m_remaining; + if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + return m_size - m_remaining; + } + + inline bool socket_accept_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_socket.service().engine()->enqueue_accept(m_socket.native_handle(), &m_completion); + } + + inline socket socket_accept_awaitable::await_resume() { + if (m_completion.result < 0) + throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + return socket::from_fd(m_socket.service(), m_completion.result); + } + + inline bool socket_accept_error_code_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_socket.service().engine()->enqueue_accept(m_socket.native_handle(), &m_completion); + } + + inline std::optional socket_accept_error_code_awaitable::await_resume() { + if (m_completion.result >= 0) return socket::from_fd(m_socket.service(), m_completion.result); + m_ec = std::error_code(-m_completion.result, std::system_category()); + return std::nullopt; + } + + inline bool socket_send_to_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_socket.service().engine()->enqueue_send_to(m_socket.native_handle(), m_buffer, m_size, m_destination, + &m_completion); + } + + inline size_t socket_send_to_awaitable::await_resume() { + if (m_completion.result >= 0) return static_cast(m_completion.result); + if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + return 0; + } + + inline bool socket_recv_from_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; + m_completion.userdata = hdl.address(); + return !m_socket.service().engine()->enqueue_recv_from(m_socket.native_handle(), m_buffer, m_size, &m_source, + &m_completion); + } + + inline std::pair socket_recv_from_awaitable::await_resume() { + if (m_completion.result >= 0) return {static_cast(m_completion.result), m_source}; + if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); + *m_ec = std::error_code(-m_completion.result, std::system_category()); + return {}; + } + + template + requires(std::is_invocable_v) + inline void socket::connect(const endpoint& ep, FN&& cb, asyncpp::stop_token st) { + struct data : detail::io_engine::completion_data { + FN real_cb; + asyncpp::stop_callback stop_cb; + + data(FN&& cb, asyncpp::stop_token st, detail::io_engine* engine) + : completion_data{&handle, this}, real_cb(std::move(cb)), + stop_cb(std::move(st), detail::cancel_io_stop_callback{this, engine}) {} + + static void handle(void* ptr) { + auto that = static_cast(ptr); + that->real_cb(that->result < 0 ? std::error_code(-that->result, std::system_category()) + : std::error_code()); + delete that; + }; + }; + auto info = new data(std::move(cb), std::move(st), service().engine()); + if (service().engine()->enqueue_connect(native_handle(), ep, info)) { data::handle(info); } + } + + template + requires(std::is_invocable_v>) + inline void socket::accept(FN&& cb, asyncpp::stop_token st) { + struct data : detail::io_engine::completion_data { + FN real_cb; + io_service& service; + asyncpp::stop_callback stop_cb; + + data(FN&& cb, asyncpp::stop_token st, io_service& s) + : completion_data{&handle, this}, real_cb(std::move(cb)), service(s), + stop_cb(std::move(st), detail::cancel_io_stop_callback{this, s.engine()}) {} + + static void handle(void* ptr) { + auto that = static_cast(ptr); + if (that->result < 0) + that->real_cb(std::error_code(-that->result, std::system_category())); + else + that->real_cb(socket::from_fd(that->service(), that->result)); + + delete that; + }; + }; + auto info = new data(std::move(cb), std::move(st), service()); + if (service().engine()->enqueue_accept(native_handle(), info)) { data::handle(info); } + } + + template + requires(std::is_invocable_v) + inline void socket::send(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st) { + struct data : detail::io_engine::completion_data { + FN real_cb; + asyncpp::stop_callback stop_cb; + + data(FN&& cb, asyncpp::stop_token st, detail::io_engine* engine) + : completion_data{&handle, this}, real_cb(std::move(cb)), + stop_cb(std::move(st), detail::cancel_io_stop_callback{this, engine}) {} + + static void handle(void* ptr) { + auto that = static_cast(ptr); + if (that->result < 0) + that->real_cb(0, std::error_code(-that->result, std::system_category())); + else + that->real_cb(that->result, {}); + + delete that; + }; + }; + auto info = new data(std::move(cb), std::move(st), service().engine()); + if (service().engine()->enqueue_send(native_handle(), buffer, size, info)) { data::handle(info); } + } + + template + requires(std::is_invocable_v) + inline void socket::recv(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st) { + struct data : detail::io_engine::completion_data { + FN real_cb; + asyncpp::stop_callback stop_cb; + + data(FN&& cb, asyncpp::stop_token st, detail::io_engine* engine) + : completion_data{&handle, this}, real_cb(std::move(cb)), + stop_cb(std::move(st), detail::cancel_io_stop_callback{this, engine}) {} + + static void handle(void* ptr) { + auto that = static_cast(ptr); + if (that->result < 0) + that->real_cb(0, std::error_code(-that->result, std::system_category())); + else + that->real_cb(that->result, {}); + + delete that; + }; + }; + auto info = new data(std::move(cb), std::move(st), service().engine()); + if (service().engine()->enqueue_recv(native_handle(), buffer, size, info)) { data::handle(info); } + } + + template + requires(std::is_invocable_v) + inline void socket::send_to(const void* buffer, std::size_t size, const endpoint& dst_ep, FN&& cb, + asyncpp::stop_token st) { + struct data : detail::io_engine::completion_data { + FN real_cb; + asyncpp::stop_callback stop_cb; + + data(FN&& cb, asyncpp::stop_token st, detail::io_engine* engine) + : completion_data{&handle, this}, real_cb(std::move(cb)), + stop_cb(std::move(st), detail::cancel_io_stop_callback{this, engine}) {} + + static void handle(void* ptr) { + auto that = static_cast(ptr); + if (that->result < 0) + that->real_cb(0, std::error_code(-that->result, std::system_category())); + else + that->real_cb(that->result, {}); + + delete that; + }; + }; + auto info = new data(std::move(cb), std::move(st), service().engine()); + if (service().engine()->enqueue_send_to(native_handle(), buffer, size, dst_ep, info)) { data::handle(info); } + } + + template + requires(std::is_invocable_v) + inline void socket::recv_from(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st) { + struct data : detail::io_engine::completion_data { + FN real_cb; + endpoint source; + asyncpp::stop_callback stop_cb; + + data(FN&& cb, asyncpp::stop_token st, detail::io_engine* engine) + : completion_data{&handle, this}, real_cb(std::move(cb)), + stop_cb(std::move(st), detail::cancel_io_stop_callback{this, engine}) {} + + static void handle(void* ptr) { + auto that = static_cast(ptr); + if (that->result < 0) + that->real_cb(0, {}, std::error_code(-that->result, std::system_category())); + else + that->real_cb(that->result, that->source, {}); + + delete that; + }; + }; + auto info = new data(std::move(cb), std::move(st), service().engine()); + if (service().engine()->enqueue_recv_from(native_handle(), buffer, size, &info->source, info)) { + data::handle(info); + } + } + +} // namespace asyncpp::io diff --git a/include/asyncpp/io/tls.h b/include/asyncpp/io/tls.h new file mode 100644 index 0000000..8458440 --- /dev/null +++ b/include/asyncpp/io/tls.h @@ -0,0 +1,399 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace asyncpp::io::tls { + enum class method { + tls, + sslv3, + tlsv1, + tlsv1_1, + tlsv1_2, + dtls, + dtlsv1, + dtlsv1_2, + }; + + enum class mode { server, client }; + + enum class file_type { pem }; + + enum class verify_mode { none = 0, peer = 1, fail_if_no_cert = 2, verify_once = 4, verify_post_handshake = 8 }; + inline verify_mode operator|(verify_mode a, verify_mode b) noexcept { + using type = std::underlying_type_t; + return static_cast(static_cast(a) | static_cast(b)); + } + inline bool operator&(verify_mode a, verify_mode b) noexcept { + using type = std::underlying_type_t; + return (static_cast(a) & static_cast(b)) != 0; + } + + class session; + class cipher; + class x509; + class context { + public: + class client_hello { + friend class context; + void* m_ssl{}; + mutable std::vector m_ciphers; + mutable std::vector m_signalling_ciphers; + mutable std::set m_extensions_preset; + + client_hello(void* ssl) noexcept; + client_hello(const client_hello& other) = default; + client_hello& operator=(const client_hello& other) = default; + + public: + bool is_v2() const noexcept; + std::span random() const noexcept; + std::span session_id() const noexcept; + const std::vector& ciphers() const; + const std::vector& signalling_ciphers() const; + std::span compression_methods() const noexcept; + const std::set& extensions() const; + bool has_extension(unsigned int type) const; + std::span extension(unsigned int type) const noexcept; + void replace_context(context& new_context) const noexcept; + + session* get_session() const noexcept; + + // Helpers + std::string_view server_name_indication() const noexcept; + }; + + private: + const method m_method{}; + const mode m_mode{}; + void* m_ctx{}; + std::function m_passwd_cb; + std::function m_client_hello_cb; + std::function& protocols)> + m_alpn_select_cb; + std::function m_cert_cb; + + friend class session; + + public: + context(method meth = method::tls, mode m = mode::client); + context(const context&) = delete; + context(context&&) = delete; + context& operator=(const context&) = delete; + context& operator=(context&&) = delete; + ~context(); + + method get_method() const noexcept { return m_method; } + mode get_mode() const noexcept { return m_mode; } + + void use_certificate(const std::string& file, file_type type = file_type::pem); + void use_privatekey(const std::string& file, file_type type = file_type::pem); + void set_passwd_callback(std::function cb); + void set_passwd(std::string passwd); + void set_client_hello_callback(std::function cb); + void set_verify(verify_mode mode); + void set_default_verify_paths(); + void set_default_verify_dir(); + void set_default_verify_file(); + void load_verify_locations(const std::string& file, const std::string& path); + std::vector ciphers() const; + void set_alpn_protos(const std::vector& protos); + void set_alpn_select_callback(std::function& protocols)> + cb); + void set_certificate_callback(std::function cb); + + std::vector get_chain_certs() const; + void clear_chain_certs(); + + void debug(); + }; + + class cipher_write_awaitable; + class cipher_read_awaitable; + class plain_write_awaitable; + class plain_read_awaitable; + class handshake_awaitable; + class session { + void* m_ssl{}; + void* m_input_bio{}; + void* m_output_bio{}; + std::function m_cert_cb{}; + + cipher_read_awaitable* m_cipher_readers{}; + cipher_write_awaitable* m_cipher_writers{}; + plain_read_awaitable* m_plain_readers{}; + plain_write_awaitable* m_plain_writers{}; + handshake_awaitable* m_handshakers{}; + + friend class cipher_write_awaitable; + friend class cipher_read_awaitable; + friend class plain_write_awaitable; + friend class plain_read_awaitable; + friend class handshake_awaitable; + + bool try_resume_cipher_read(); + bool try_resume_cipher_write(); + void try_resume_plain(); + + public: + session(const context& ctx); + session(const session&) = delete; + session(session&&) = delete; + session& operator=(const session&) = delete; + session& operator=(session&&) = delete; + ~session(); + + int try_handshake(); + [[nodiscard]] handshake_awaitable handshake() noexcept; + + void shutdown(); + + [[nodiscard]] plain_write_awaitable write(const void* buffer, size_t len) noexcept; + [[nodiscard]] plain_read_awaitable read(void* buf, size_t len) noexcept; + [[nodiscard]] bool try_write(const void* buffer, size_t len, size_t& read); + [[nodiscard]] bool try_read(void* buf, size_t len, size_t& written); + + [[nodiscard]] cipher_write_awaitable cipher_write(const void* buffer, size_t len) noexcept; + [[nodiscard]] cipher_read_awaitable cipher_read(void* buffer, size_t len) noexcept; + + cipher current_cipher() const noexcept; + cipher pending_cipher() const noexcept; + std::vector ciphers() const; + std::vector supported_ciphers() const; + std::vector client_ciphers() const; + std::string_view get_servername() const noexcept; + void set_servername(const std::string& name); + void set_verify(verify_mode mode); + void set_alpn_protos(const std::vector& protos); + std::string_view alpn_selected() const noexcept; + void set_certificate_callback(std::function cb); + x509 get_peer_certificate() const noexcept; + }; + + class cipher { + const void* m_cipher{}; + mutable char* m_description{}; + + friend class context; + friend class session; + + constexpr cipher(const void* ptr) noexcept : m_cipher(ptr) {} + + public: + constexpr cipher() noexcept = default; + constexpr cipher(const cipher& other) noexcept : m_cipher(other.m_cipher) {} + constexpr cipher& operator=(const cipher& other) noexcept { + m_cipher = other.m_cipher; + return *this; + } + ~cipher(); + + bool operator!() const noexcept { return m_cipher == nullptr; } + operator bool() const noexcept { return m_cipher != nullptr; } + bool valid() const noexcept { return m_cipher != nullptr; } + + std::string_view name() const noexcept; + std::string_view standard_name() const noexcept; + std::string_view cipher_name() const noexcept; + size_t bit_count() const noexcept; + std::string_view version() const noexcept; + std::string_view description() const noexcept; + int cipher_nid() const noexcept; + int digest_nid() const noexcept; + int kx_nid() const noexcept; + int auth_nid() const noexcept; + bool is_aead() const noexcept; + uint32_t id() const noexcept; + uint32_t protocol_id() const noexcept; + }; + std::ostream& operator<<(std::ostream& str, const cipher& cipher); + + class x509 { + void* m_x509; + friend class context; + friend class session; + + constexpr x509(void* x509) noexcept : m_x509(x509) {} + + public: + constexpr x509(x509&& other) noexcept : m_x509(other.m_x509) { other.m_x509 = nullptr; } + constexpr x509& operator=(x509&& other) noexcept { + m_x509 = other.m_x509; + other.m_x509 = nullptr; + return *this; + } + ~x509(); + + bool operator!() const noexcept { return m_x509 == nullptr; } + operator bool() const noexcept { return m_x509 != nullptr; } + bool valid() const noexcept { return m_x509 != nullptr; } + + std::string to_der() const; + std::string to_pem() const; + + static x509 from_der(const void* ptr, size_t len); + static x509 from_pem(const void* ptr, size_t len); + + std::chrono::system_clock::time_point not_before() const noexcept; + std::chrono::system_clock::time_point not_after() const noexcept; + std::string subject() const; + std::string issuer() const; + + friend std::strong_ordering operator<=>(const x509& lhs, const x509& rhs) noexcept; + friend bool operator==(const x509& lhs, const x509& rhs) noexcept; + friend bool operator!=(const x509& lhs, const x509& rhs) noexcept; + }; + + class plain_read_awaitable { + plain_read_awaitable(const plain_read_awaitable&) = delete; + plain_read_awaitable(plain_read_awaitable&&) = delete; + plain_read_awaitable& operator=(const plain_read_awaitable&) = delete; + plain_read_awaitable& operator=(plain_read_awaitable&&) = delete; + friend class session; + + session& m_session; + void* const m_buffer; + size_t const m_len; + size_t m_result{}; + coroutine_handle<> m_handle{}; + plain_read_awaitable* m_next{}; + + bool try_resume(); + + public: + constexpr plain_read_awaitable(session& sess, void* buffer, size_t len) noexcept + : m_session{sess}, m_buffer{buffer}, m_len{len} {} + + bool await_ready() const noexcept; + bool await_suspend(coroutine_handle<> hdl); + size_t await_resume(); + }; + + class plain_write_awaitable { + plain_write_awaitable(const plain_write_awaitable&) = delete; + plain_write_awaitable(plain_write_awaitable&&) = delete; + plain_write_awaitable& operator=(const plain_write_awaitable&) = delete; + plain_write_awaitable& operator=(plain_write_awaitable&&) = delete; + friend class session; + + session& m_session; + const void* const m_buffer; + size_t const m_len; + size_t m_result{}; + coroutine_handle<> m_handle{}; + plain_write_awaitable* m_next{}; + + bool try_resume(); + + public: + constexpr plain_write_awaitable(session& sess, const void* buffer, size_t len) noexcept + : m_session{sess}, m_buffer{buffer}, m_len{len} {} + + bool await_ready() const noexcept; + bool await_suspend(coroutine_handle<> hdl); + size_t await_resume(); + }; + + class cipher_read_awaitable { + cipher_read_awaitable(const cipher_read_awaitable&) = delete; + cipher_read_awaitable(cipher_read_awaitable&&) = delete; + cipher_read_awaitable& operator=(const cipher_read_awaitable&) = delete; + cipher_read_awaitable& operator=(cipher_read_awaitable&&) = delete; + friend class session; + + session& m_session; + void* const m_buffer; + size_t const m_len; + size_t m_result{}; + coroutine_handle<> m_handle{}; + cipher_read_awaitable* m_next{}; + + bool try_resume(); + + public: + constexpr cipher_read_awaitable(session& sess, void* buffer, size_t len) noexcept + : m_session{sess}, m_buffer{buffer}, m_len{len} {} + + bool await_ready() const noexcept; + bool await_suspend(coroutine_handle<> hdl); + size_t await_resume(); + }; + + class cipher_write_awaitable { + cipher_write_awaitable(const cipher_write_awaitable&) = delete; + cipher_write_awaitable(cipher_write_awaitable&&) = delete; + cipher_write_awaitable& operator=(const cipher_write_awaitable&) = delete; + cipher_write_awaitable& operator=(cipher_write_awaitable&&) = delete; + friend class session; + + session& m_session; + const void* const m_buffer; + size_t const m_len; + size_t m_result{}; + coroutine_handle<> m_handle{}; + cipher_write_awaitable* m_next{}; + + bool try_resume(); + + public: + constexpr cipher_write_awaitable(session& sess, const void* buffer, size_t len) noexcept + : m_session{sess}, m_buffer{buffer}, m_len{len} {} + + bool await_ready() const noexcept; + bool await_suspend(coroutine_handle<> hdl); + size_t await_resume(); + }; + + class handshake_awaitable { + handshake_awaitable(const handshake_awaitable&) = delete; + handshake_awaitable(handshake_awaitable&&) = delete; + handshake_awaitable& operator=(const handshake_awaitable&) = delete; + handshake_awaitable& operator=(handshake_awaitable&&) = delete; + friend class session; + + session& m_session; + int m_result{}; + coroutine_handle<> m_handle{}; + handshake_awaitable* m_next{}; + + bool try_resume(); + + public: + constexpr handshake_awaitable(session& sess) noexcept : m_session{sess} {} + + bool await_ready() const noexcept; + bool await_suspend(coroutine_handle<> hdl); + void await_resume(); + }; + + [[nodiscard]] inline handshake_awaitable session::handshake() noexcept { return handshake_awaitable(*this); } + + [[nodiscard]] inline plain_write_awaitable session::write(const void* buffer, size_t len) noexcept { + return plain_write_awaitable(*this, buffer, len); + } + + [[nodiscard]] inline plain_read_awaitable session::read(void* buffer, size_t len) noexcept { + return plain_read_awaitable(*this, buffer, len); + } + + [[nodiscard]] inline cipher_write_awaitable session::cipher_write(const void* buffer, size_t len) noexcept { + return cipher_write_awaitable(*this, buffer, len); + } + + [[nodiscard]] inline cipher_read_awaitable session::cipher_read(void* buffer, size_t len) noexcept { + return cipher_read_awaitable(*this, buffer, len); + } + +} // namespace asyncpp::io::tls diff --git a/src/address.cpp b/src/address.cpp new file mode 100644 index 0000000..289f148 --- /dev/null +++ b/src/address.cpp @@ -0,0 +1,178 @@ +#include +#include + +#include +#include + +#ifndef _WIN32 +#include +#include +#else +#include +#include +#endif + +namespace asyncpp::io { + ipv4_address::ipv4_address(const sockaddr_storage& addr) { + if (addr.ss_family != AF_INET) throw std::invalid_argument("addr does not contain a valid ipv4 ip"); + *this = ipv4_address(*reinterpret_cast(&addr)); + } + + ipv4_address::ipv4_address(const sockaddr_in& addr) noexcept + : ipv4_address(addr.sin_addr.s_addr, std::endian::native) {} + + std::pair ipv4_address::to_sockaddr() const noexcept { + sockaddr_storage res{}; + res.ss_family = AF_INET; + memcpy(&reinterpret_cast(&res)->sin_addr.s_addr, data().data(), 4); + return {res, sizeof(sockaddr_in)}; + } + + std::pair ipv4_address::to_sockaddr_in() const noexcept { + sockaddr_in res{}; + res.sin_family = AF_INET; + memcpy(&res.sin_addr.s_addr, data().data(), 4); + return {res, sizeof(sockaddr_in)}; + } + + ipv6_address::ipv6_address(const sockaddr_storage& addr) { + if (addr.ss_family != AF_INET6) throw std::invalid_argument("addr does not contain a valid ipv6 ip"); + *this = ipv6_address(*reinterpret_cast(&addr)); + } + + ipv6_address::ipv6_address(const sockaddr_in6& addr) noexcept : ipv6_address(addr.sin6_addr.s6_addr) {} + + std::pair ipv6_address::to_sockaddr() const noexcept { + sockaddr_storage res{}; + res.ss_family = AF_INET6; + memcpy(&reinterpret_cast(&res)->sin6_addr.s6_addr, data().data(), 16); + return {res, sizeof(sockaddr_in6)}; + } + + std::pair ipv6_address::to_sockaddr_in6() const noexcept { + sockaddr_in6 res{}; + res.sin6_family = AF_INET6; + memcpy(&res.sin6_addr.s6_addr, data().data(), 16); + return {res, sizeof(sockaddr_in6)}; + } + +#ifndef _WIN32 + uds_address::uds_address(const sockaddr_storage& addr, size_t len) { + if (addr.ss_family != AF_UNIX) throw std::invalid_argument("addr does not contain a valid ipv6 ip"); + *this = uds_address(*reinterpret_cast(&addr), len); + } + + uds_address::uds_address(const sockaddr_un& addr, size_t len) noexcept { + memcpy(m_data.data(), addr.sun_path, (std::min)(sizeof(sockaddr_un::sun_path), m_data.size())); + m_len = len - offsetof(struct sockaddr_un, sun_path); + // If it is not abstract remove trailing zeros. This is the same behavior the linux kernel has. + if (m_data[0] != '\0') { + while (m_len && m_data[m_len - 1] == '\0') + m_len--; + } + } + + std::pair uds_address::to_sockaddr() const noexcept { + sockaddr_storage res{}; + res.ss_family = AF_UNIX; + memcpy(reinterpret_cast(&res)->sun_path, m_data.data(), + (std::min)(sizeof(sockaddr_un::sun_path), m_len)); + return {res, sizeof(sockaddr_un)}; + } + + std::pair uds_address::to_sockaddr_un() const noexcept { + sockaddr_un res{}; + res.sun_family = AF_UNIX; + memcpy(res.sun_path, m_data.data(), (std::min)(sizeof(sockaddr_un::sun_path), m_len)); + return {res, sizeof(sockaddr_un)}; + } +#endif + + address::address(const sockaddr_storage& addr, size_t len) { + if (addr.ss_family == AF_INET) + *this = address(ipv4_address(*reinterpret_cast(&addr))); + else if (addr.ss_family == AF_INET6) + *this = address(ipv6_address(*reinterpret_cast(&addr))); +#ifndef _WIN32 + else if (addr.ss_family == AF_UNIX) + *this = address(uds_address(*reinterpret_cast(&addr), len)); +#endif + else + throw std::invalid_argument("addr is not af_inet or af_inet6"); + } + + std::pair address::to_sockaddr() const noexcept { + return is_ipv4() ? ipv4().to_sockaddr() : ipv6().to_sockaddr(); + } + + ipv4_endpoint::ipv4_endpoint(const sockaddr_storage& addr) { + if (addr.ss_family != AF_INET) throw std::invalid_argument("addr does not contain a valid ipv4 ip"); + auto ep = reinterpret_cast(&addr); + m_ip = ipv4_address(*ep); + m_port = htons(ep->sin_port); + } + + ipv4_endpoint::ipv4_endpoint(const sockaddr_in& addr) noexcept { + m_ip = ipv4_address(addr); + m_port = htons(addr.sin_port); + } + + std::pair ipv4_endpoint::to_sockaddr() const noexcept { + auto res = m_ip.to_sockaddr(); + reinterpret_cast(&res.first)->sin_port = htons(m_port); + return res; + } + + std::pair ipv4_endpoint::to_sockaddr_in() const noexcept { + auto res = m_ip.to_sockaddr_in(); + res.first.sin_port = htons(m_port); + return res; + } + + ipv6_endpoint::ipv6_endpoint(const sockaddr_storage& addr) { + if (addr.ss_family != AF_INET6) throw std::invalid_argument("addr does not contain a valid ipv6 ip"); + auto ep = reinterpret_cast(&addr); + m_ip = ipv6_address(*ep); + m_port = htons(ep->sin6_port); + } + + ipv6_endpoint::ipv6_endpoint(const sockaddr_in6& addr) noexcept { + m_ip = ipv6_address(addr); + m_port = htons(addr.sin6_port); + } + + std::pair ipv6_endpoint::to_sockaddr() const noexcept { + auto res = m_ip.to_sockaddr(); + reinterpret_cast(&res.first)->sin6_port = htons(m_port); + return res; + } + + std::pair ipv6_endpoint::to_sockaddr_in6() const noexcept { + auto res = m_ip.to_sockaddr_in6(); + res.first.sin6_port = htons(m_port); + return res; + } + + endpoint::endpoint(const sockaddr_storage& addr, size_t len) { + if (addr.ss_family == AF_INET) + *this = endpoint(ipv4_endpoint(*reinterpret_cast(&addr))); + else if (addr.ss_family == AF_INET6) + *this = endpoint(ipv6_endpoint(*reinterpret_cast(&addr))); +#ifndef _WIN32 + else if (addr.ss_family == AF_UNIX) + *this = endpoint(uds_endpoint(*reinterpret_cast(&addr), len)); +#endif + else + throw std::invalid_argument("addr is not af_inet or af_inet6"); + } + + std::pair endpoint::to_sockaddr() const noexcept { + switch (m_type) { + case address_type::ipv4: return m_ipv4.to_sockaddr(); + case address_type::ipv6: return m_ipv6.to_sockaddr(); + case address_type::uds: return m_uds.to_sockaddr(); + } + return {}; + } + +} // namespace asyncpp::io diff --git a/src/block_allocator.h b/src/block_allocator.h new file mode 100644 index 0000000..d57901c --- /dev/null +++ b/src/block_allocator.h @@ -0,0 +1,106 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include + +namespace asyncpp::io::detail { + + template + class block_allocator { + + struct page { + page* next_page{}; + uint64_t usage{}; + alignas(T) std::array storage{}; + }; + + TMutex m_mtx{}; + page* m_first_page{}; + + public: + constexpr block_allocator() noexcept = default; + block_allocator(const block_allocator&) = delete; + block_allocator& operator=(const block_allocator&) = delete; + ~block_allocator() noexcept { + auto p = m_first_page; + while (p != nullptr) { + auto ptr = p; + p = p->next_page; + assert(ptr->usage == 0); + delete ptr; + } + } + void* allocate() noexcept { + std::unique_lock lck{m_mtx}; + page* p = m_first_page; + page** page_ptr = &m_first_page; + while (p != nullptr) { + if (p->usage != std::numeric_limits::max()) { + auto free_block = std::countr_one(p->usage); + assert(free_block < 64 && free_block >= 0); + p->usage |= (static_cast(1) << free_block); +#if ASYNCPP_HAS_ASAN + __asan_unpoison_memory_region(p->storage.data() + sizeof(T) * free_block, sizeof(T)); +#endif + return p->storage.data() + sizeof(T) * free_block; + } + page_ptr = &p->next_page; + p = p->next_page; + } + // No free blocks left + p = *page_ptr = new (std::nothrow) page{}; + if (p == nullptr) return nullptr; + p->usage |= 1; +#if ASYNCPP_HAS_ASAN + __asan_poison_memory_region(p->storage.data() + sizeof(T), p->storage.size() - sizeof(T)); +#endif + return p->storage.data(); + } + void deallocate(void* ptr) noexcept { + std::unique_lock lck{m_mtx}; + page* p = m_first_page; + while (p != nullptr) { + if (ptr >= p->storage.data() && ptr < p->storage.data() + p->storage.size()) { +#if ASYNCPP_HAS_ASAN + __asan_poison_memory_region(ptr, sizeof(T)); +#endif + const auto offset = static_cast(ptr) - p->storage.data(); + assert(offset % sizeof(T) == 0); + assert(offset < sizeof(T) * 64); + const auto idx = offset / sizeof(T); + assert((p->usage & static_cast(1) << idx) != 0); + p->usage &= ~(static_cast(1) << idx); + return; + } + p = p->next_page; + } + } + template + T* create(Args&&... args) { + auto ptr = allocate(); + if (ptr == nullptr) return nullptr; + if constexpr (std::is_nothrow_constructible_v) { + return new (ptr) T(std::forward(args)...); + } else { + try { + return new (ptr) T(std::forward(args)...); + } catch (...) { + this->deallocate(ptr); + throw; + } + } + // unreachable + } + void destroy(T* obj) { + if (obj != nullptr) { + obj->~T(); + this->deallocate(obj); + } + } + }; +} // namespace asyncpp::io::detail diff --git a/src/dns.cpp b/src/dns.cpp new file mode 100644 index 0000000..02245b6 --- /dev/null +++ b/src/dns.cpp @@ -0,0 +1,1057 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define FAIL_WITH(x, y) \ + { \ + ec = make_error_code(x); \ + return y; \ + } +#define FAIL(x) \ + { \ + ec = make_error_code(x); \ + return; \ + } + +namespace { + template + std::array calculate_md5(const TKey& key, const T&... vals) { + auto mdctx = EVP_MD_CTX_new(); + asyncpp::scope_guard cleanup{[mdctx]() noexcept { EVP_MD_CTX_free(mdctx); }}; + if (EVP_DigestInit_ex(mdctx, EVP_md5(), NULL) != 1) throw std::runtime_error("failed to init mac"); + if (((EVP_DigestUpdate(mdctx, vals.data(), vals.size()) != 1) || ...)) + throw std::runtime_error("failed to update mac"); + std::array md{}; + if (EVP_DigestFinal_ex(mdctx, md.data(), NULL) != 1) throw std::runtime_error("failed to finish mac"); + return md; + } +} // namespace + +namespace asyncpp::io::dns { + + const std::error_category& error_category() noexcept { + class dns_category final : public std::error_category { + const char* name() const noexcept override { return "dns"; } + std::string message(int code) const override { + switch (static_cast(code)) { + case api_error::ok: return "ok"; + case api_error::not_enough_space: return "not enough space"; + case api_error::label_invalid: return "label invalid"; + case api_error::label_too_long: return "label too long"; + case api_error::incomplete_message: return "incomplete or invalid message"; + case api_error::recursion_limit_exceeded: return "recursion limit exceeded"; + case api_error::extra_data: return "extra data in message"; + case api_error::duplicate_id: return "duplicate_id"; + case api_error::no_id: return "no_id"; + case api_error::cancelled: return "cancelled"; + case api_error::timeout: return "timeout"; + case api_error::internal: return "internal"; + default: return ""; + } + } + }; + static const dns_category instance; + return instance; + } + + size_t convert_name(void* out, size_t outlen, std::string_view name, std::error_code& ec) noexcept { + const size_t required_space = (name.empty() ? 1 : 2) + name.size(); + if (outlen < required_space) FAIL_WITH(api_error::not_enough_space, required_space); + auto size = static_cast(out); + *size = 0; + auto ptr = size + 1; + for (auto e : name) { + if (ptr >= static_cast(out) + outlen) FAIL_WITH(api_error::not_enough_space, required_space); + if (e == '.') { + if (*size == 0) FAIL_WITH(api_error::label_invalid, required_space); + size = ptr; + *size = 0; + } else { + (*size)++; + if (*size > max_label_size) { FAIL_WITH(api_error::label_too_long, required_space); } + *ptr = e; + } + ptr++; + } + *ptr = 0; + return required_space; + } + + const std::byte* parse_label(std::span msg, const std::byte* plabel, std::string& res, + std::error_code& ec) noexcept { + auto pmsg = msg.data(); + + const std::byte* first_label_end = nullptr; + + size_t depth = 0; + + while (plabel < pmsg + msg.size()) { + if (static_cast(*plabel) == 0) { + return first_label_end ? first_label_end : (plabel + 1); + } else if (static_cast(*plabel) & 0xc0) { + if (plabel + 1 >= pmsg + msg.size()) FAIL_WITH(api_error::incomplete_message, nullptr); + auto offset = raw_get(plabel) & 0x3fff; + if (pmsg + offset >= plabel) FAIL_WITH(api_error::incomplete_message, nullptr); + first_label_end = plabel + 2; + plabel = pmsg + offset; + if (depth++ == 128) FAIL_WITH(api_error::recursion_limit_exceeded, nullptr); + } else { + if (plabel + static_cast(*plabel) >= pmsg + msg.size()) + FAIL_WITH(api_error::incomplete_message, nullptr); + if (!res.empty()) res += '.'; + res.resize(res.size() + static_cast(*plabel)); + memcpy(res.data() + res.size() - static_cast(*plabel), plabel + 1, + static_cast(*plabel)); + plabel += static_cast(*plabel) + 1; + } + } + FAIL_WITH(api_error::incomplete_message, nullptr); + } + + ipv4_address parse_a(const_buffer data, const_buffer msg, std::error_code& ec) noexcept { + if (data.size() < 4) { + ec = make_error_code(api_error::incomplete_message); + return ipv4_address(); + } else { + return ipv4_address(static_cast(data[0]), static_cast(data[1]), + static_cast(data[2]), static_cast(data[3])); + } + } + + ipv6_address parse_aaaa(const_buffer data, const_buffer msg, std::error_code& ec) noexcept { + if (data.size() < 16) { + ec = make_error_code(api_error::incomplete_message); + return ipv6_address(); + } else { + return ipv6_address( + std::span(reinterpret_cast(data.data()), data.size())); + } + } + + std::string parse_txt(const_buffer data, const_buffer msg, std::error_code& ec) noexcept { + auto ptr = data.data(); + auto end = data.data() + data.size(); + std::string res; + while (ptr < end) { + auto len = static_cast(*ptr); + if (ptr + 1 + len > end) { + res.append(reinterpret_cast(ptr + 1), std::distance(ptr + 1, end)); + ec = make_error_code(api_error::incomplete_message); + return res; + } else { + res.append(reinterpret_cast(ptr + 1), len); + } + ptr += 1 + len; + } + return res; + } + + std::string parse_cname(const_buffer data, const_buffer msg, std::error_code& ec) noexcept { + std::string res; + parse_label(msg, data.data(), res, ec); + return res; + } + + std::string parse_ns(const_buffer data, const_buffer msg, std::error_code& ec) noexcept { + std::string res; + parse_label(msg, data.data(), res, ec); + return res; + } + + mx_record parse_mx(const_buffer data, const_buffer msg, std::error_code& ec) { + mx_record res; + if (data.size() < 3) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + res.preference = raw_get(data.data()); + parse_label(msg, data.data() + 2, res.name, ec); + return res; + } + + std::string parse_ptr(const_buffer data, const_buffer msg, std::error_code& ec) noexcept { + std::string res; + parse_label(msg, data.data(), res, ec); + return res; + } + + soa_record parse_soa(const_buffer data, const_buffer msg, std::error_code& ec) { + soa_record res; + if (data.size() < 20) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + auto ptr = parse_label(msg, data.data(), res.name, ec); + if (ec) return res; + ptr = parse_label(msg, ptr, res.rname, ec); + if (ec) return res; + if (std::distance(ptr, data.data() + data.size()) < 20) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + res.serial = raw_get(ptr); + res.refresh = raw_get(ptr + 4); + res.retry = raw_get(ptr + 8); + res.expire = raw_get(ptr + 12); + res.minimum = raw_get(ptr + 16); + return res; + } + + srv_record parse_srv(const_buffer data, const_buffer msg, std::error_code& ec) { + srv_record res; + if (data.size() < 7) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + res.priority = raw_get(data.data()); + res.weight = raw_get(data.data() + 2); + res.port = raw_get(data.data() + 4); + auto ptr = parse_label(msg, data.data() + 6, res.target, ec); + if (ptr != data.data() + data.size()) ec = make_error_code(api_error::extra_data); + + return res; + } + + tsig_record parse_tsig(const_buffer data, const_buffer msg, std::error_code& ec) { + tsig_record res; + if (data.size() < 17) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + auto ptr = parse_label(msg, data.data(), res.algorithm, ec); + if (std::distance(ptr, data.data() + data.size()) < 16) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + uint64_t ts{}; + ts |= static_cast(*ptr++) << 40; + ts |= static_cast(*ptr++) << 32; + ts |= static_cast(*ptr++) << 24; + ts |= static_cast(*ptr++) << 16; + ts |= static_cast(*ptr++) << 8; + ts |= static_cast(*ptr++); + res.timestamp = std::chrono::system_clock::from_time_t(ts); + res.fudge = raw_get(ptr); + res.mac.resize(raw_get(ptr + 2)); + ptr += 4; + if (std::distance(ptr, data.data() + data.size()) < 6 + res.mac.size()) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + memcpy(res.mac.data(), ptr, res.mac.size()); + ptr += res.mac.size(); + res.original_id = raw_get(ptr); + res.error = static_cast(raw_get(ptr + 2)); + res.other.resize(raw_get(ptr + 4)); + if (std::distance(ptr + 6, data.data() + data.size()) < res.other.size()) { + ec = make_error_code(api_error::incomplete_message); + return res; + } + memcpy(res.other.data(), ptr + 6, res.other.size()); + + return res; + } + + std::vector build_a(ipv4_address addr) { + const auto data = addr.data(); + return {data.begin(), data.end()}; + } + + std::vector build_aaaa(ipv6_address addr) { + const auto data = addr.data(); + return {data.begin(), data.end()}; + } + + std::vector build_txt(std::string_view str) { + std::vector res; + while (!str.empty()) { + auto part = str.substr(0, 255); + res.push_back(part.size()); + res.insert(res.end(), part.begin(), part.end()); + str.remove_prefix(part.size()); + } + return res; + } + + std::vector build_cname(std::string_view str) { + std::vector res(256); + std::error_code ec; + auto len = convert_name(res.data(), res.size(), str, ec); + if (ec) throw std::system_error(ec); + res.resize(len); + return res; + } + + std::vector build_ns(std::string_view str) { + std::vector res(256); + std::error_code ec; + auto len = convert_name(res.data(), res.size(), str, ec); + if (ec) throw std::system_error(ec); + res.resize(len); + return res; + } + + std::vector build_mx(const mx_record& rr) { + std::vector res(2 + 256); + if (rr.name.size() > 255) throw std::system_error(api_error::label_too_long); + raw_set(res.data(), rr.preference); + std::error_code ec; + auto len = convert_name(res.data() + 2, res.size(), rr.name, ec); + if (ec) throw std::system_error(ec); + res.resize(2 + len); + return res; + } + + std::vector build_ptr(std::string_view str) { + std::vector res(256); + std::error_code ec; + auto len = convert_name(res.data(), res.size(), str, ec); + if (ec) throw std::system_error(ec); + res.resize(len); + return res; + } + + std::vector build_soa(const soa_record& rr) { + std::vector res(256 + 256 + 20); + if (rr.name.size() > 255) throw std::system_error(api_error::label_too_long); + if (rr.rname.size() > 255) throw std::system_error(api_error::label_too_long); + std::error_code ec; + auto len = convert_name(res.data(), res.size(), rr.name, ec); + if (ec) throw std::system_error(ec); + len += convert_name(res.data() + len, res.size() - len, rr.rname, ec); + if (ec) throw std::system_error(ec); + if (res.size() - len < 20) throw std::system_error(api_error::not_enough_space); + raw_set(res.data() + len, rr.serial); + raw_set(res.data() + len + 4, rr.refresh); + raw_set(res.data() + len + 8, rr.retry); + raw_set(res.data() + len + 12, rr.expire); + raw_set(res.data() + len + 16, rr.minimum); + res.resize(len + 20); + return res; + } + + std::vector build_srv(const srv_record& rr) { + std::vector res(6 + 256); + if (rr.target.size() > 255) throw std::system_error(api_error::label_too_long); + raw_set(res.data(), rr.priority); + raw_set(res.data() + 2, rr.weight); + raw_set(res.data() + 4, rr.port); + std::error_code ec; + auto len = convert_name(res.data() + 6, res.size() - 6, rr.target, ec); + if (ec) throw std::system_error(ec); + res.resize(6 + len); + return res; + } + + std::vector build_tsig(const tsig_record& rr) { + std::vector res(256); + if (rr.algorithm.size() > 255) throw std::system_error(api_error::label_too_long); + std::error_code ec; + auto len = convert_name(res.data(), res.size(), rr.algorithm, ec); + if (ec) throw std::system_error(ec); + res.resize(len + 10 + rr.mac.size() + rr.other.size()); + auto offset = len; + uint64_t ts = std::chrono::system_clock::to_time_t(rr.timestamp); + res[offset++] = ((ts >> 40) & 0xff); + res[offset++] = ((ts >> 32) & 0xff); + res[offset++] = ((ts >> 24) & 0xff); + res[offset++] = ((ts >> 16) & 0xff); + res[offset++] = ((ts >> 8) & 0xff); + res[offset++] = (ts & 0xff); + raw_set(res.data() + offset, rr.fudge); + raw_set(res.data() + offset + 2, rr.mac.size()); + memcpy(res.data() + offset + 4, rr.mac.data(), rr.mac.size()); + raw_set(res.data() + offset + rr.mac.size() + 4, rr.original_id); + raw_set(res.data() + offset + rr.mac.size() + 6, static_cast(rr.error)); + raw_set(res.data() + offset + rr.mac.size() + 8, rr.other.size()); + memcpy(res.data() + offset + rr.mac.size() + 10, rr.other.data(), rr.other.size()); + return res; + } + + std::vector question::serialize() const { + std::error_code ec; + std::vector record; + record.resize(name.size() + (name.empty() ? 1 : 2) + 4); + convert_name(record.data(), record.size() - 4, name, ec); + if (ec) throw std::system_error(ec); + auto ptr = record.data() + name.size() + (name.empty() ? 1 : 2); + raw_set(ptr, static_cast(qtype)); + raw_set(ptr + 2, static_cast(qclass)); + return record; + } + + const std::byte* question::parse(std::span msg, const std::byte* const rr) { + name.clear(); + const auto pmsg = msg.data(); + const auto pend = pmsg + msg.size(); + std::error_code ec; + const auto pfixed = parse_label(msg, rr, name, ec); + if (ec) throw std::system_error(ec); + if (pfixed == nullptr || (pfixed + 4) > (pend)) throw std::runtime_error("invalid label"); + qtype = raw_get(pfixed); + qclass = raw_get(pfixed + 2); + return pfixed + 4; + } + + std::vector resource_record::serialize() const { + std::error_code ec; + std::vector record; + record.resize(name.size() + (name.empty() ? 1 : 2) + 10 + rdata.size()); + convert_name(record.data(), record.size() - (10 + rdata.size()), name, ec); + if (ec) throw std::system_error(ec); + auto ptr = record.data() + name.size() + (name.empty() ? 1 : 2); + raw_set(ptr, rtype); + raw_set(ptr + 2, rclass); + raw_set(ptr + 4, ttl); + raw_set(ptr + 8, rdata.size()); + memcpy(ptr + 10, rdata.data(), rdata.size()); + return record; + } + + const std::byte* resource_record::parse(std::span msg, const std::byte* const rr) { + name.clear(); + const auto pmsg = msg.data(); + const auto pend = pmsg + msg.size(); + std::error_code ec; + const auto pfixed = parse_label(msg, rr, name, ec); + if (ec) throw std::system_error(ec); + if (pfixed == nullptr || (pfixed + 10) > (pend)) throw std::runtime_error("invalid record"); + rtype = raw_get(pfixed); + rclass = raw_get(pfixed + 2); + ttl = raw_get(pfixed + 4); + auto rdata_len = raw_get(pfixed + 8); + if (pfixed + 10 + rdata_len > pend) throw std::runtime_error("invalid record"); + rdata.resize(rdata_len); + memcpy(rdata.data(), pfixed + 10, rdata_len); + return pfixed + 10 + rdata_len; + } + + message_builder& message_builder::add_tsig_signature(std::string_view keyname, std::span key) { + std::array tsigdata{}; + std::error_code ec; + auto offset = convert_name(tsigdata.data(), tsigdata.size(), keyname, ec); + assert(offset == 6); + // Class (always any) + tsigdata[offset++] = 0x00; + tsigdata[offset++] = 0xff; + // TTL (always 0) + tsigdata[offset++] = 0x00; + tsigdata[offset++] = 0x00; + tsigdata[offset++] = 0x00; + tsigdata[offset++] = 0x00; + const size_t rr_offset = offset; + // Algorithm name + offset += convert_name(tsigdata.data() + offset, tsigdata.size() - offset, "hmac-md5.sig-alg.reg.int", ec); + auto ts = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + static_assert(sizeof(ts) >= 6); + // Timestamp (48bit unix seconds) + tsigdata[offset++] = (ts >> 40) & 0xff; + tsigdata[offset++] = (ts >> 32) & 0xff; + tsigdata[offset++] = (ts >> 24) & 0xff; + tsigdata[offset++] = (ts >> 16) & 0xff; + tsigdata[offset++] = (ts >> 8) & 0xff; + tsigdata[offset++] = ts & 0xff; + // Fudge + tsigdata[offset++] = 0x01; + tsigdata[offset++] = 0x2c; + // Error (always 0 on request) + tsigdata[offset++] = 0x00; + tsigdata[offset++] = 0x00; + // Other len (always 0 on request) + tsigdata[offset++] = 0x00; + tsigdata[offset++] = 0x00; + + // Signing with md5 + // TODO: Support for sha + auto md = calculate_md5(key, std::span(m_start, bytes_used()), std::span(tsigdata.data(), offset)); + + resource_record rr; + rr.name = keyname; + rr.rtype = qtype::tsig; + rr.rclass = qclass::any; + rr.ttl = 0; + // Copy over the info from tsigdata + rr.rdata.insert(rr.rdata.end(), tsigdata.data() + rr_offset, tsigdata.data() + offset - 4); + // Mac size + rr.rdata.push_back((md.size() >> 8) & 0xff); + rr.rdata.push_back(md.size() & 0xff); + // Mac + rr.rdata.insert(rr.rdata.end(), md.data(), md.data() + md.size()); + // Original id + rr.rdata.push_back(m_start[0]); + rr.rdata.push_back(m_start[1]); + // Error + rr.rdata.push_back(0); + rr.rdata.push_back(0); + // Other len + rr.rdata.push_back(0); + rr.rdata.push_back(0); + return add_additional(rr); + } + + std::vector message::serialize() const { + std::vector res; + res.resize(std::numeric_limits::max()); + serialize(res.data(), res.size()); + return res; + } + + void message::serialize(void* buf, size_t bufsize) const { + message_builder b(buf, bufsize); + b.set_id(id); + b.set_qr(is_response); + b.set_authoritative(is_authoritative); + b.set_truncated(is_truncated); + b.set_recursion_desired(is_recursion_desired); + b.set_recursion_available(is_recursion_available); + b.set_opcode(opcode); + b.set_rcode(rcode); + for (auto& e : questions) + b.add_question(e); + for (auto& e : answers) + b.add_answer(e); + for (auto& e : authorities) + b.add_authority(e); + for (auto& e : additional) + b.add_additional(e); + } + + void message::parse(std::span msg) { + const auto pmsg = msg.data(); + const auto pend = pmsg + msg.size(); + if (msg.size() < 12) throw std::runtime_error("invalid message"); + id = raw_get(pmsg); + is_response = static_cast(pmsg[2]) & 0x80; + opcode = static_cast((static_cast(pmsg[2]) >> 3) & 0x0f); + is_authoritative = static_cast(pmsg[2]) & 0x04; + is_truncated = static_cast(pmsg[2]) & 0x02; + is_recursion_desired = static_cast(pmsg[2]) & 0x01; + is_recursion_available = static_cast(pmsg[3]) & 0x80; + rcode = static_cast(static_cast(pmsg[3]) & 0x0f); + auto qcount = raw_get(pmsg + 4); + auto answercount = raw_get(pmsg + 6); + auto authoritycount = raw_get(pmsg + 8); + auto additionalcount = raw_get(pmsg + 10); + auto* ptr = pmsg + 12; + questions.reserve(qcount); + for (size_t i = 0; i < qcount; i++) + ptr = questions.emplace_back().parse(msg, ptr); + answers.reserve(answercount); + for (size_t i = 0; i < answercount; i++) + ptr = answers.emplace_back().parse(msg, ptr); + authorities.reserve(authoritycount); + for (size_t i = 0; i < authoritycount; i++) + ptr = authorities.emplace_back().parse(msg, ptr); + additional.reserve(additionalcount); + for (size_t i = 0; i < additionalcount; i++) + ptr = additional.emplace_back().parse(msg, ptr); + if (ptr != pend) throw std::runtime_error("extra garbage after message"); + } + + std::ostream& operator<<(std::ostream& s, rcode r) { + switch (r) { + case rcode::no_error: return s << "no_error"; + case rcode::form_error: return s << "form_error"; + case rcode::server_failure: return s << "server_failure"; + case rcode::nx_domain: return s << "nx_domain"; + case rcode::not_implemented: return s << "not_implemented"; + case rcode::refused: return s << "refused"; + case rcode::domain_exists: return s << "domain_exists"; + case rcode::rrset_exists: return s << "rrset_exists"; + case rcode::nx_rrset: return s << "nx_rrset"; + case rcode::not_authoritative: return s << "not_authoritative"; + case rcode::not_zone: return s << "not_zone"; + case rcode::bad_signature: return s << "bad_signature"; + case rcode::bad_key: return s << "bad_key"; + case rcode::bad_time: return s << "bad_time"; + default: return s << static_cast(r); + } + } + + std::ostream& operator<<(std::ostream& s, opcode o) { + switch (o) { + case opcode::query: return s << "QUERY"; + case opcode::iquery: return s << "IQUERY"; + case opcode::status: return s << "STATUS"; + case opcode::update: return s << "UPDATE"; + default: return s << static_cast(o); + } + } + + std::ostream& operator<<(std::ostream& s, qtype t) { + switch (t) { + case qtype::a: return s << "A"; + case qtype::ns: return s << "NS"; + case qtype::md: return s << "MD"; + case qtype::mf: return s << "MF"; + case qtype::cname: return s << "CNAME"; + case qtype::soa: return s << "SOA"; + case qtype::mb: return s << "MB"; + case qtype::mg: return s << "MG"; + case qtype::mr: return s << "MR"; + case qtype::null: return s << "NULL"; + case qtype::wks: return s << "WKS"; + case qtype::ptr: return s << "PTR"; + case qtype::hinfo: return s << "HINFO"; + case qtype::minfo: return s << "MINFO"; + case qtype::mx: return s << "MX"; + case qtype::txt: return s << "TXT"; + case qtype::rp: return s << "RP"; + case qtype::afsdb: return s << "AFSDB"; + case qtype::x25: return s << "X25"; + case qtype::isdn: return s << "ISDN"; + case qtype::rt: return s << "RT"; + case qtype::nsap: return s << "NSAP"; + case qtype::nsap_ptr: return s << "NSAP_PTR"; + case qtype::sig: return s << "SIG"; + case qtype::key: return s << "KEY"; + case qtype::px: return s << "PX"; + case qtype::gpos: return s << "GPOS"; + case qtype::aaaa: return s << "AAAA"; + case qtype::loc: return s << "LOC"; + case qtype::nxt: return s << "NXT"; + case qtype::eid: return s << "EID"; + case qtype::nimloc: return s << "NIMLOC"; + case qtype::srv: return s << "SRV"; + case qtype::atma: return s << "ATMA"; + case qtype::naptr: return s << "NAPTR"; + case qtype::kx: return s << "KX"; + case qtype::cert: return s << "CERT"; + case qtype::a6: return s << "A6"; + case qtype::dname: return s << "DNAME"; + case qtype::sink: return s << "SINK"; + case qtype::opt: return s << "OPT"; + case qtype::apl: return s << "APL"; + case qtype::ds: return s << "DS"; + case qtype::sshfp: return s << "SSHFP"; + case qtype::ipseckey: return s << "IPSECKEY"; + case qtype::rrsig: return s << "RRSIG"; + case qtype::nsec: return s << "NSEC"; + case qtype::dnskey: return s << "DNSKEY"; + case qtype::dhcid: return s << "DHCID"; + case qtype::nsec3: return s << "NSEC3"; + case qtype::nsec3param: return s << "NSEC3PARAM"; + case qtype::tlsa: return s << "TLSA"; + case qtype::smimea: return s << "SMIMEA"; + case qtype::hip: return s << "HIP"; + case qtype::ninfo: return s << "NINFO"; + case qtype::rkey: return s << "RKEY"; + case qtype::talink: return s << "TALINK"; + case qtype::cds: return s << "CDS"; + case qtype::cdnskey: return s << "CDNSKEY"; + case qtype::openpgpkey: return s << "OPENPGPKEY"; + case qtype::csync: return s << "CSYNC"; + case qtype::spf: return s << "SPF"; + case qtype::uinfo: return s << "UINFO"; + case qtype::uid: return s << "UID"; + case qtype::gid: return s << "GID"; + case qtype::unspec: return s << "UNSPEC"; + case qtype::nid: return s << "NID"; + case qtype::l32: return s << "L32"; + case qtype::l64: return s << "L64"; + case qtype::lp: return s << "LP"; + case qtype::eui48: return s << "EUI48"; + case qtype::eui64: return s << "EUI64"; + case qtype::tkey: return s << "TKEY"; + case qtype::tsig: return s << "TSIG"; + case qtype::ixfr: return s << "IXFR"; + case qtype::axfr: return s << "AXFR"; + case qtype::mailb: return s << "MAILB"; + case qtype::maila: return s << "MAILA"; + case qtype::any: return s << "ANY"; + case qtype::uri: return s << "URI"; + case qtype::caa: return s << "CAA"; + case qtype::avc: return s << "AVC"; + case qtype::ta: return s << "TA"; + case qtype::dlv: return s << "DLV"; + default: return s << static_cast(t); + } + } + + std::ostream& operator<<(std::ostream& s, qclass c) { + switch (c) { + case qclass::in: return s << "IN"; + case qclass::csnet: return s << "CSNET"; + case qclass::chaos: return s << "CHAOS"; + case qclass::hs: return s << "HS"; + case qclass::any: return s << "ANY"; + default: return s << static_cast(c); + } + } + + bool print_message_visitor::on_header(const message_header& hdr, asyncpp::io::const_buffer msg) noexcept { + m_message = msg; + m_is_update = hdr.opcode() == opcode::update; + (*m_out) << ";; opcode: " << hdr.opcode() << ", status: " << hdr.rcode() << " id: " << hdr.id() << "\n"; + (*m_out) << ";; flags: "; + if (hdr.qr()) (*m_out) << "qr "; + if (hdr.authoritative()) (*m_out) << "aa "; + if (hdr.truncated()) (*m_out) << "tc "; + if (hdr.recursion_desired()) (*m_out) << "rd "; + if (hdr.recursion_available()) (*m_out) << "ra "; + (*m_out) << ", query: " << hdr.query_count() << ", answer: " << hdr.answer_count() + << ", authority: " << hdr.authoritative_count() << ", additional: " << hdr.additional_count() << "\n"; + return true; + } + + bool print_message_visitor::on_question(std::string_view name, qtype qtype, qclass qclass) noexcept { + if (!m_question_header_done) { + if (m_is_update) + (*m_out) << "\n;; ZONE SECTION:\n"; + else + (*m_out) << "\n;; QUESTION SECTION:\n"; + m_question_header_done = true; + } + (*m_out) << name << "\t" << qclass << "\t" << qtype << "\n"; + return true; + } + + bool print_message_visitor::on_answer(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + std::span rdata) noexcept { + if (!m_answer_header_done) { + if (m_is_update) + (*m_out) << "\n;; PREREQUISITE SECTION:\n"; + else + (*m_out) << "\n;; ANSWER SECTION:\n"; + m_answer_header_done = true; + } + print_rr(name, rtype, rclass, ttl, rdata); + return true; + } + + bool print_message_visitor::on_authority(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + std::span rdata) noexcept { + if (!m_authority_header_done) { + if (m_is_update) + (*m_out) << "\n;; UPDATE SECTION:\n"; + else + (*m_out) << "\n;; AUTHORITY SECTION:\n"; + m_authority_header_done = true; + } + print_rr(name, rtype, rclass, ttl, rdata); + return true; + } + + bool print_message_visitor::on_additional(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + std::span rdata) noexcept { + if (!m_additional_header_done) { + (*m_out) << "\n;; ADDITIONAL SECTION:\n"; + m_additional_header_done = true; + } + print_rr(name, rtype, rclass, ttl, rdata); + return true; + } + + void print_message_visitor::print_rr(std::string_view name, qtype rtype, qclass rclass, uint32_t ttl, + std::span rdata) noexcept { + (*m_out) << name << "\t" << rclass << "\t" << rtype << "\t" << ttl << "\t"; + std::error_code ec; + switch (rtype) { + case qtype::a: + if (auto addr = parse_a(rdata, m_message, ec); !ec) (*m_out) << addr.to_string() << "\n"; + break; + case qtype::aaaa: + if (auto addr = parse_aaaa(rdata, m_message, ec); !ec) (*m_out) << addr.to_string() << "\n"; + break; + case qtype::txt: + if (auto txt = parse_txt(rdata, m_message, ec); !ec) (*m_out) << txt << "\n"; + break; + case qtype::cname: + if (auto cname = parse_cname(rdata, m_message, ec); !ec) (*m_out) << cname << "\n"; + break; + case qtype::ns: + if (auto ns = parse_ns(rdata, m_message, ec); !ec) (*m_out) << ns << "\n"; + break; + case qtype::mx: + if (auto mx = parse_mx(rdata, m_message, ec); !ec) (*m_out) << mx.preference << "\t" << mx.name << "\n"; + break; + case qtype::ptr: + if (auto ptr = parse_ptr(rdata, m_message, ec); !ec) (*m_out) << ptr << "\n"; + break; + case qtype::soa: + if (auto soa = parse_soa(rdata, m_message, ec); !ec) + (*m_out) << soa.name << "\t" << soa.rname << "\t" << soa.serial << "\t" << soa.refresh << "\t" + << soa.retry << "\t" << soa.expire << "\t" << soa.minimum << "\n"; + break; + case qtype::srv: + if (auto srv = parse_srv(rdata, m_message, ec); !ec) + (*m_out) << srv.priority << "\t" << srv.weight << "\t" << srv.port << "\t" << srv.target << "\n"; + break; + case qtype::tsig: + if (auto srv = parse_tsig(rdata, m_message, ec); !ec) + (*m_out) << srv.algorithm << "\t" << std::chrono::system_clock::to_time_t(srv.timestamp) << "\t" + << srv.fudge << "\tmacsize=" << srv.mac.size() << "\t" << srv.original_id << "\t" << srv.error + << "\tothersize=" << srv.other.size() << "\n"; + break; + default: (*m_out) << rdata.size() << " bytes\n"; + }; + if (ec) (*m_out) << rdata.size() << " bytes\n"; + } + + struct client::request { + uint16_t id{}; + std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now(); + size_t tries{0}; + std::vector request_data{}; + std::function callback{}; + size_t send_count{}; + }; + + client::client(asyncpp::io::io_service& service) + : m_socket_ipv4(socket::create_udp(service, address_type::ipv4)), + m_socket_ipv6(socket::create_udp(service, address_type::ipv6)) { + launch([](client* that) -> task<> { + auto token = that->m_stop.get_token(); + std::array buf; + while (!token.stop_requested()) { + try { + auto [size, src] = co_await that->m_socket_ipv4.recv_from(buf.data(), buf.size(), token); + if (size < 12) continue; + // TODO: Maybe check if src is a valid nameserver + auto id = raw_get(buf.data()); + std::unique_lock lck{that->m_mutex}; + auto it = that->m_inflight.find(id); + if (it == that->m_inflight.end()) continue; + auto e = std::move(it->second); + that->m_inflight.erase(it); + lck.unlock(); + e.callback(api_error::ok, const_buffer(buf.data(), size)); + } catch (...) {} + } + }(this)); + launch([](client* that) -> task<> { + auto token = that->m_stop.get_token(); + std::array buf; + while (!token.stop_requested()) { + try { + auto [size, src] = co_await that->m_socket_ipv6.recv_from(buf.data(), buf.size(), token); + if (size < 12) continue; + // TODO: Maybe check if src is a valid nameserver + auto id = raw_get(buf.data()); + std::unique_lock lck{that->m_mutex}; + auto it = that->m_inflight.find(id); + if (it == that->m_inflight.end()) continue; + auto e = std::move(it->second); + that->m_inflight.erase(it); + lck.unlock(); + e.callback(api_error::ok, const_buffer(buf.data(), size)); + } catch (...) {} + } + }(this)); + launch([](client* that) -> task<> { + auto token = that->m_stop.get_token(); + auto wait_time = std::chrono::steady_clock::now() + that->m_timeout; + while (!token.stop_requested()) { + co_await timer::get_default().wait(wait_time, that->m_stop_timer.get_token()); + if (token.stop_requested()) break; + auto now = std::chrono::steady_clock::now(); + wait_time = now + that->m_timeout; + std::unique_lock lck{that->m_mutex}; + if (that->m_nameservers.empty()) continue; + for (auto it = that->m_inflight.begin(); it != that->m_inflight.end();) { + auto& e = it->second; + if (now - e.start > that->m_timeout) { + if (++e.tries > that->m_retries) { + auto cb = std::move(e.callback); + it = that->m_inflight.erase(it); + lck.unlock(); + cb(api_error::timeout, {}); + lck.lock(); + continue; + } else { + e.start = now; + that->send_requests(e); + } + } + wait_time = (std::min)(wait_time, e.start + that->m_timeout); + it++; + } + } + }(this)); + } + + client::~client() { this->stop(); } + + void client::set_timeout(std::chrono::milliseconds timeout) { + if (timeout.count() < 1) return; + std::unique_lock lck{m_mutex}; + m_timeout = timeout; + m_stop_timer.request_stop(); + m_stop_timer = {}; + } + + std::optional client::get_free_id() const noexcept { + std::unique_lock lck{m_mutex}; + if (m_inflight.size() == std::numeric_limits::max() + 1) return std::nullopt; + uint16_t id = rand(); + for (size_t i = 0; i < 10; i++) { + if (!m_inflight.contains(id)) return id; + id = rand(); + } + // Fallback for a linear search for the first free id + for (size_t i = 0; i <= std::numeric_limits::max(); i++) { + if (!m_inflight.contains(id)) return id; + } + // This shouldn't be possible + return std::nullopt; + } + + void client::query(asyncpp::io::const_buffer query, std::function callback) { + if (query.size() < 12) { + callback(api_error::incomplete_message, {}); + return; + } + std::unique_lock lck{m_mutex}; + request r; + if (const auto origid = raw_get(query.data()); origid == 0) { + auto found = get_free_id(); + if (!found) throw std::system_error(api_error::no_id); + r.id = *found; + } else if (m_inflight.contains(origid)) { + throw std::system_error(api_error::duplicate_id); + } else + r.id = origid; + + r.callback = std::move(callback); + r.request_data.resize(query.size()); + memcpy(r.request_data.data(), query.data(), query.size()); + raw_set(r.request_data.data(), r.id); + r.send_count = m_nameservers.size(); + auto it = m_inflight.emplace(r.id, std::move(r)).first; + + send_requests(it->second); + } + + void client::query(std::string_view name, dns::qtype qtype, dns::qclass qclass, + std::function callback) { + std::array buffer; + auto used = message_builder(buffer.data(), buffer.size()) + .set_opcode(opcode::query) + .set_recursion_desired(true) + .add_question(name, qtype, qclass) + .bytes_used(); + query(asyncpp::io::const_buffer(buffer.data(), used), std::move(callback)); + } + + void client::send_requests(request& req) { + req.send_count = m_nameservers.size(); + for (auto& e : m_nameservers) { + auto sock = e.is_ipv6() ? &m_socket_ipv6 : &m_socket_ipv4; + sock->send_to( + req.request_data.data(), req.request_data.size(), e, + [id = req.id, this](size_t, std::error_code ec) { + std::unique_lock lck{m_mutex}; + auto it = m_inflight.find(id); + if (it == m_inflight.end()) return; + if (ec) { + if (--it->second.send_count != 0) return; + auto e = std::move(it->second); + m_inflight.erase(it); + lck.unlock(); + e.callback(api_error::internal, {}); + return; + } + }, + m_stop.get_token()); + } + if (m_nameservers.empty()) { + std::unique_lock lck{m_mutex}; + auto it = m_inflight.find(req.id); + auto cb = std::move(it->second.callback); + m_inflight.erase(it); + lck.unlock(); + cb(api_error::internal, {}); + } + } + + struct resolver { + client* parent; + std::string current_name; + dns::qtype type; + std::function res)> callback; + size_t max_depth{10}; + + void next() { + std::array query; + auto size = message_builder(query.data(), query.size()) + .set_opcode(opcode::query) + .set_recursion_desired(true) + .add_question(current_name, type, qclass::in) + .bytes_used(); + parent->query(const_buffer(query.data(), size), [this](api_error error, const_buffer response) { + if (error == api_error::ok) { + std::vector
res; + bool cname_found = false; + try { + visit_answer(response, [&](std::string_view rname, qtype rtype, qclass rclass, uint32_t ttl, + asyncpp::io::const_buffer rdata) { + if (current_name != rname || rclass != qclass::in) return true; + std::error_code ec; + if (rtype == qtype::a) { + auto rr = parse_a(rdata, response, ec); + if (!ec) res.push_back(address(rr)); + } else if (rtype == qtype::aaaa) { + auto rr = parse_aaaa(rdata, response, ec); + if (!ec) res.push_back(address(rr)); + } else if (rtype == qtype::cname) { + auto rr = parse_cname(rdata, response, ec); + if (!ec) { + current_name = std::move(rr); + cname_found = true; + } + } + return true; + }); + } catch (...) { + callback({}); + delete this; + return; + } + max_depth--; + // We found at least some ips + if (cname_found && max_depth != 0) { + // Nothing yet, retry with the cname name + this->next(); + return; + } else { + // No results and no cname + callback(std::move(res)); + delete this; + return; + } + } else { + callback({}); + delete this; + return; + } + }); + } + }; + + void client::resolve(std::string name, dns::qtype type, std::function res)> callback) { + auto res = new resolver{this, std::move(name), type, std::move(callback)}; + res->next(); + } + + void client::stop() { + std::unique_lock lck{m_mutex}; + m_stop.request_stop(); + m_stop_timer.request_stop(); + for (auto& e : m_inflight) { + e.second.callback(api_error::cancelled, {}); + } + } +} // namespace asyncpp::io::dns diff --git a/src/file.cpp b/src/file.cpp new file mode 100644 index 0000000..659c448 --- /dev/null +++ b/src/file.cpp @@ -0,0 +1,96 @@ +#include + +#ifndef _WIN32 +#include +#include +#include +#include +#else +#include +#endif + +namespace asyncpp::io { + file::file(io_service& io) : m_io(&io), m_fd(detail::io_engine::invalid_file_handle) {} + file::file(io_service& io, detail::io_engine::file_handle_t fd) : m_io(&io), m_fd(fd) {} + file::file(io_service& io, const char* filename, std::ios_base::openmode mode) : file(io) { open(filename, mode); } + file::file(io_service& io, const std::string& filename, std::ios_base::openmode mode) : file(io) { + open(filename, mode); + } + file::file(io_service& io, const std::filesystem::path& filename, std::ios_base::openmode mode) : file(io) { + open(filename, mode); + } + file::file(file&& other) noexcept + : m_io(std::exchange(other.m_io, nullptr)), + m_fd(std::exchange(other.m_fd, detail::io_engine::invalid_file_handle)) {} + file& file::operator=(file&& other) { + close(); + m_io = std::exchange(other.m_io, nullptr); + m_fd = std::exchange(other.m_fd, detail::io_engine::invalid_file_handle); + return *this; + } + file::~file() { close(); } + + void file::open(const char* filename, std::ios_base::openmode mode) { +#ifndef _WIN32 + if ((mode & std::ios_base::ate) == std::ios_base::ate) throw std::logic_error("unsupported flag"); + int m = 0; + if ((mode & std::ios_base::app) == std::ios_base::app) m |= O_APPEND; + if ((mode & std::ios_base::in) == std::ios_base::in) + m |= ((mode & std::ios_base::out) == std::ios_base::out) ? O_RDWR : O_RDONLY; + else if ((mode & std::ios_base::out) == std::ios_base::out) + m |= O_WRONLY; + else + throw std::invalid_argument("neither std::ios::in, nor std::ios::out was specified"); + if ((mode & std::ios_base::trunc) == std::ios_base::trunc) m |= O_TRUNC; + auto res = ::open(filename, m, 0660); + if (res < 0) throw std::system_error(errno, std::system_category()); +#else + DWORD access_mode = 0; + if ((mode & std::ios_base::in) == std::ios_base::in) access_mode |= GENERIC_READ; + if ((mode & std::ios_base::out) == std::ios_base::out) access_mode |= GENERIC_WRITE; + if ((mode & (std::ios_base::in | std::ios_base::out)) == 0) + throw std::invalid_argument("neither std::ios::in, nor std::ios::out was specified"); + HANDLE h = CreateFileA(filename, access_mode, 0, NULL, CREATE_NEW, FILE_ATTRIBUTE_NORMAL, NULL); + // TODO: Remaining code +#endif + close(); + m_fd = res; + } + void file::open(const std::string& filename, std::ios_base::openmode mode) { return open(filename.c_str(), mode); } + void file::open(const std::filesystem::path& filename, std::ios_base::openmode mode) { + return open(filename.c_str(), mode); + } + + bool file::is_open() const noexcept { return m_io != nullptr && m_fd != detail::io_engine::invalid_file_handle; } + + void file::close() { + if (m_fd != detail::io_engine::invalid_file_handle) { +#ifndef _WIN32 + ::close(m_fd); +#else + ::CloseHandle(m_fd); +#endif + m_fd = detail::io_engine::invalid_file_handle; + } + } + + void file::swap(file& other) { + std::swap(m_io, other.m_io); + std::swap(m_fd, other.m_fd); + } + + uint64_t file::size() { +#ifdef __APPLE__ + struct stat info {}; + auto res = fstat(m_fd, &info); +#elif defined(_WIN32) + struct _stat64 info {}; + auto res = _fstat64(m_fd, &info); +#else + struct stat64 info {}; + auto res = fstat64(m_fd, &info); +#endif + if (res < 0) throw std::system_error(errno, std::system_category()); + return info.st_size; + } +} // namespace asyncpp::io diff --git a/src/io_engine.cpp b/src/io_engine.cpp new file mode 100644 index 0000000..d2d3516 --- /dev/null +++ b/src/io_engine.cpp @@ -0,0 +1,24 @@ +#include + +#include + +namespace asyncpp::io::detail { + // Select is always supported + std::unique_ptr create_io_engine_select(); + // Only supported on Linux on kernel 5.1+ + std::unique_ptr create_io_engine_uring(); + + std::unique_ptr create_io_engine() { + if (const auto env = getenv("ASYNCPP_IO_ENGINE"); env != nullptr) { + std::string_view engine = env; + if (engine == "uring") + return create_io_engine_uring(); + else if (engine == "select") + return create_io_engine_select(); + else if (!engine.empty()) + throw std::runtime_error("unknown io engine " + std::string(engine)); + } + if (auto uring = create_io_engine_uring(); uring != nullptr) return uring; + return create_io_engine_select(); + } +} // namespace asyncpp::io::detail diff --git a/src/io_engine_select.cpp b/src/io_engine_select.cpp new file mode 100644 index 0000000..0557d86 --- /dev/null +++ b/src/io_engine_select.cpp @@ -0,0 +1,449 @@ +#include + +#include +#include +#include + +#ifndef _WIN32 +#include +#include +#include +#include +#include +#else +#include +#include +#endif + +#ifdef __linux__ +#define USE_EVENTFD +#endif + +#ifdef USE_EVENTFD +#include +#endif + +namespace asyncpp::io::detail { + namespace { + enum class op { connect, accept, recv, send, recv_from, send_to }; + struct entry { + op operation; + io_engine::socket_handle_t socket; + io_engine::completion_data* done; + union { + struct { + void* buf; + size_t len; + } recv; + struct { + const void* buf; + size_t len; + } send; + struct { + void* buf; + size_t len; + endpoint* source; + } recv_from; + struct { + const void* buf; + size_t len; + endpoint destination; + } send_to; + } state; + }; + } // namespace + + class io_engine_select : public io_engine { + public: + io_engine_select(); + io_engine_select(const io_engine_select&) = delete; + io_engine_select& operator=(const io_engine_select&) = delete; + ~io_engine_select(); + + std::string_view name() const noexcept override; + + size_t run(bool nowait) override; + void wake() override; + + bool enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) override; + bool enqueue_accept(socket_handle_t socket, completion_data* cd) override; + bool enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) override; + bool enqueue_send(socket_handle_t socket, const void* buf, size_t len, completion_data* cd) override; + bool enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, + completion_data* cd) override; + bool enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, + completion_data* cd) override; + + bool enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) override; + bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, completion_data* cd) override; + bool enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) override; + + bool cancel(completion_data* cd) override; + + private: + socket_handle_t m_wake_fd; +#ifndef USE_EVENTFD + socket_handle_t m_wake_fd_write; +#endif + std::mutex m_inflight_mtx; + std::vector m_inflight; + std::vector m_done_callbacks; + + enum { RDY_READ = 1, RDY_WRITE = 2, RDY_ERR = 4 }; + bool handle_io(entry& e, int state); + }; + + std::unique_ptr create_io_engine_select() { return std::make_unique(); } + + io_engine_select::io_engine_select() { +#ifdef USE_EVENTFD + m_wake_fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + if (m_wake_fd < 0) throw std::system_error(errno, std::system_category(), "eventfd failed"); +#else + int fds[2]; + auto res = pipe(fds); + if (res < 0) throw std::system_error(errno, std::system_category(), "pipe failed"); + int flags0 = fcntl(fds[0], F_GETFL, 0); + int flags1 = fcntl(fds[1], F_GETFL, 0); + if (flags0 < 0 || flags1 < 0 || // + fcntl(fds[0], F_SETFL, flags0 | O_NONBLOCK) < 0 || fcntl(fds[1], F_SETFL, flags1 | O_NONBLOCK) < 0 || // + fcntl(fds[0], F_SETFD, FD_CLOEXEC) < 0 || fcntl(fds[1], F_SETFD, FD_CLOEXEC) < 0) { + close(fds[0]); + close(fds[1]); + throw std::system_error(errno, std::system_category(), "pipe failed"); + } + m_wake_fd = fds[0]; + m_wake_fd_write = fds[1]; +#endif + } + + io_engine_select::~io_engine_select() { +#ifdef _WIN32 +#else + if (m_wake_fd >= 0) close(m_wake_fd); +#ifndef USE_EVENTFD + if (m_wake_fd_write >= 0) close(m_wake_fd_write); +#endif +#endif + } + + std::string_view io_engine_select::name() const noexcept { return "select"; } + + size_t io_engine_select::run(bool nowait) { + fd_set rd_set{}, wrt_set{}, err_set{}; + int max_fd = m_wake_fd; + FD_SET(m_wake_fd, &rd_set); + std::unique_lock lck{m_inflight_mtx}; + if (nowait && m_inflight.empty()) return m_inflight.size(); + for (auto& e : m_inflight) { + switch (e.operation) { + case op::connect: + case op::send: + case op::send_to: FD_SET(e.socket, &wrt_set); break; + case op::accept: + case op::recv: + case op::recv_from: FD_SET(e.socket, &rd_set); break; + } + max_fd = (std::max)(e.socket, max_fd); + } + lck.unlock(); + struct timeval timeout {}; + if (!nowait) timeout.tv_sec = 10; + auto res = select(max_fd + 1, &rd_set, &wrt_set, &err_set, &timeout); + if (res < 0) throw std::system_error(errno, std::system_category(), "select failed"); + if (FD_ISSET(m_wake_fd, &rd_set)) { + uint64_t val; + [[maybe_unused]] auto rsize = read(m_wake_fd, &val, sizeof(val)); + // Note we ignore the result because its irrelevant + res--; + } + // Note: inflight might have changed in between, but we dont care. + lck.lock(); + if (res == 0) return m_inflight.size(); + for (auto it = m_inflight.begin(); it != m_inflight.end();) { + int state = 0; + state |= FD_ISSET(it->socket, &rd_set) ? RDY_READ : 0; + state |= FD_ISSET(it->socket, &wrt_set) ? RDY_WRITE : 0; + state |= FD_ISSET(it->socket, &err_set) ? RDY_ERR : 0; + if (state == 0 || !handle_io(*it, state)) + it++; + else + it = m_inflight.erase(it); + } + lck.unlock(); + for (auto e : m_done_callbacks) { + e->callback(e->userdata); + } + m_done_callbacks.clear(); + return m_inflight.size(); + } + + bool io_engine_select::handle_io(entry& e, int state) { + switch (e.operation) { + case op::connect: { + if ((state & RDY_WRITE) == 0) return false; + int result; + socklen_t result_len = sizeof(result); + if (getsockopt(e.socket, SOL_SOCKET, SO_ERROR, &result, &result_len) < 0) { + e.done->result = -errno; + } else { + e.done->result = -result; + } + m_done_callbacks.push_back(e.done); + return true; + } + case op::send: { + if ((state & RDY_WRITE) == 0) return false; + auto res = ::send(e.socket, e.state.send.buf, e.state.send.len, 0); + if (res >= 0) { + e.state.send.len -= res; + e.state.send.buf = static_cast(e.state.send.buf) + res; + if (e.state.send.len == 0) { + e.done->result = 0; + m_done_callbacks.push_back(e.done); + return true; + } + } else if (errno != EAGAIN) { + e.done->result = -errno; + m_done_callbacks.push_back(e.done); + return true; + } + return false; + } + case op::accept: { + if ((state & RDY_READ) == 0) return false; + auto res = ::accept(e.socket, nullptr, nullptr); + if (res >= 0 || errno != EAGAIN) { + e.done->result = res >= 0 ? res : -errno; + m_done_callbacks.push_back(e.done); + return true; + } + return false; + } + case op::recv: { + if ((state & RDY_READ) == 0) return false; + auto res = ::recv(e.socket, e.state.recv.buf, e.state.recv.len, 0); + if (res >= 0 || errno != EAGAIN) { + e.done->result = res >= 0 ? res : -errno; + m_done_callbacks.push_back(e.done); + return true; + } + return false; + } + case op::send_to: { + if ((state & RDY_WRITE) == 0) return false; + auto sa = e.state.send_to.destination.to_sockaddr(); + auto res = ::sendto(e.socket, e.state.send.buf, e.state.send.len, 0, reinterpret_cast(&sa.first), + sa.second); + if (res >= 0 || errno != EAGAIN) { + e.done->result = res >= 0 ? res : -errno; + m_done_callbacks.push_back(e.done); + return true; + } + return false; + } + case op::recv_from: { + if ((state & RDY_READ) == 0) return false; + sockaddr_storage sa; + socklen_t sa_len = sizeof(sa); + auto res = + ::recvfrom(e.socket, e.state.recv.buf, e.state.recv.len, 0, reinterpret_cast(&sa), &sa_len); + if (res >= 0 || errno != EAGAIN) { + e.done->result = res >= 0 ? res : -errno; + if (res >= 0 && e.state.recv_from.source) { + if (sa.ss_family == AF_INET || sa.ss_family == AF_INET6 || sa.ss_family == AF_UNIX) + *e.state.recv_from.source = endpoint(sa, sa_len); + else + *e.state.recv_from.source = endpoint{}; + } + m_done_callbacks.push_back(e.done); + return true; + } + return false; + } + default: return true; + } + } + + void io_engine_select::wake() { + uint64_t val = 1; +#ifdef USE_EVENTFD + write(m_wake_fd, &val, sizeof(val)); +#else + write(m_wake_fd_write, &val, sizeof(val)); +#endif + } + + bool io_engine_select::enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) { + auto sa = ep.to_sockaddr(); + auto res = ::connect(socket, reinterpret_cast(&sa.first), sa.second); + if (res == 0 || errno != EINPROGRESS) { + // Succeeded right away + cd->result = res ? -errno : 0; + return true; + } + + entry e{}; + e.operation = op::connect; + e.socket = socket; + e.done = cd; + std::unique_lock lck{m_inflight_mtx}; + m_inflight.push_back(e); + wake(); + return false; + } + + bool io_engine_select::enqueue_accept(socket_handle_t socket, completion_data* cd) { + auto res = ::accept(socket, nullptr, nullptr); + if (res >= 0 || errno != EAGAIN) { + cd->result = res >= 0 ? res : -errno; + return true; + } + + entry e{}; + e.operation = op::accept; + e.socket = socket; + e.done = cd; + std::unique_lock lck{m_inflight_mtx}; + m_inflight.push_back(e); + wake(); + return false; + } + + bool io_engine_select::enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) { + auto res = ::recv(socket, buf, len, 0); + if (res >= 0 || errno != EAGAIN) { + cd->result = res >= 0 ? res : -errno; + return true; + } + + entry e{}; + e.operation = op::recv; + e.socket = socket; + e.done = cd; + e.state.recv.buf = buf; + e.state.recv.len = len; + std::unique_lock lck{m_inflight_mtx}; + m_inflight.push_back(e); + wake(); + return false; + } + + bool io_engine_select::enqueue_send(socket_handle_t socket, const void* buf, size_t len, completion_data* cd) { + auto res = ::send(socket, buf, len, 0); + if (res >= 0) { + len -= res; + buf = static_cast(buf) + res; + } else if (errno != EAGAIN) { + cd->result = -errno; + return true; + } + if (len == 0) { + cd->result = 0; + return true; + } + + entry e{}; + e.operation = op::send; + e.socket = socket; + e.done = cd; + e.state.send.buf = buf; + e.state.send.len = len; + std::unique_lock lck{m_inflight_mtx}; + m_inflight.push_back(e); + wake(); + return false; + } + + bool io_engine_select::enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, + completion_data* cd) { + sockaddr_storage sa; + socklen_t sa_len = sizeof(sa); + auto res = ::recvfrom(socket, buf, len, 0, reinterpret_cast(&sa), &sa_len); + if (res >= 0 || errno != EAGAIN) { + cd->result = res >= 0 ? res : -errno; + if (res >= 0 && source) { + if (sa.ss_family == AF_INET || sa.ss_family == AF_INET6) + *source = endpoint(sa, sa_len); + else + *source = endpoint{}; + } + return true; + } + + entry e{}; + e.operation = op::recv_from; + e.socket = socket; + e.done = cd; + e.state.recv_from.buf = buf; + e.state.recv_from.len = len; + e.state.recv_from.source = source; + std::unique_lock lck{m_inflight_mtx}; + m_inflight.push_back(e); + wake(); + return false; + } + + bool io_engine_select::enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, + completion_data* cd) { + auto sa = dst.to_sockaddr(); + auto res = ::sendto(socket, buf, len, 0, reinterpret_cast(&sa.first), sa.second); + if (res >= 0 || errno != EAGAIN) { + cd->result = res >= 0 ? res : -errno; + return true; + } + + entry e{}; + e.operation = op::send_to; + e.socket = socket; + e.done = cd; + e.state.send_to.buf = buf; + e.state.send_to.len = len; + e.state.send_to.destination = dst; + std::unique_lock lck{m_inflight_mtx}; + m_inflight.push_back(e); + wake(); + return false; + } + + bool io_engine_select::enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) { + // There is no way to do async file io on linux without uring, so just do the read inline + auto res = pread(fd, buf, len, offset); + cd->result = res >= 0 ? res : -errno; + return true; + } + + bool io_engine_select::enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, + completion_data* cd) { + // There is no way to do async file io on linux without uring, so just do the write inline + auto res = pwrite(fd, buf, len, offset); + cd->result = res >= 0 ? res : -errno; + return true; + } + + bool io_engine_select::enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) { +// There is no way to do async file io on linux without uring, so just do the fsync inline +#ifdef __linux__ + auto res = flags == fsync_flags::datasync ? fdatasync(fd) : fsync(fd); +#else + auto res = fsync(fd); +#endif + cd->result = res >= 0 ? res : -errno; + return true; + } + + bool io_engine_select::cancel(completion_data* cd) { + std::unique_lock lck{m_inflight_mtx}; + for (auto it = m_inflight.begin(); it != m_inflight.end(); it++) { + if (it->done == cd) { + it = m_inflight.erase(it); + lck.unlock(); + cd->result = -ECANCELED; + cd->callback(cd->userdata); + return true; + } + } + return false; + } + +} // namespace asyncpp::io::detail diff --git a/src/io_engine_uring.cpp b/src/io_engine_uring.cpp new file mode 100644 index 0000000..795e973 --- /dev/null +++ b/src/io_engine_uring.cpp @@ -0,0 +1,279 @@ +#include + +#if !defined(__linux__) || !defined(ASYNCPP_ENABLE_URING) +namespace asyncpp::io::detail { + std::unique_ptr create_io_engine_uring() { return nullptr; } +} // namespace asyncpp::io::detail +#else + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "block_allocator.h" + +namespace asyncpp::io::detail { + + class io_engine_uring : public io_engine { + public: + io_engine_uring(struct io_uring ring) noexcept; + io_engine_uring(const io_engine_uring&) = delete; + io_engine_uring& operator=(const io_engine_uring&) = delete; + ~io_engine_uring(); + + std::string_view name() const noexcept override; + + size_t run(bool nowait) override; + void wake() override; + + bool enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) override; + bool enqueue_accept(socket_handle_t socket, completion_data* cd) override; + bool enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) override; + bool enqueue_send(socket_handle_t socket, const void* buf, size_t len, completion_data* cd) override; + bool enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, + completion_data* cd) override; + bool enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, + completion_data* cd) override; + + bool enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) override; + bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, completion_data* cd) override; + bool enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) override; + + bool cancel(completion_data* cd) override; + + private: + struct msghdr_info { + struct msghdr hdr {}; + sockaddr_storage sockaddr{}; + iovec data{}; + asyncpp::io::endpoint* real_endpoint{}; + }; + + std::mutex m_sqe_mtx{}; + std::mutex m_cqe_mtx{}; + std::atomic m_inflight_count{}; + struct io_uring m_ring {}; + block_allocator m_state_allocator{}; + }; + + std::unique_ptr create_io_engine_uring() { + // check if the kernel supports uring and return nullptr if not + if (syscall(__NR_io_uring_register, 0, IORING_UNREGISTER_BUFFERS, NULL, 0) && errno == ENOSYS) return nullptr; + io_uring ring{}; + auto res = io_uring_queue_init(256, &ring, 0); + if (res < 0) return nullptr; + std::unique_ptr probe(io_uring_get_probe_ring(&ring), &free); + // Make sure all required opcodes are supported + if (io_uring_opcode_supported(probe.get(), IORING_OP_NOP) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_CONNECT) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_ACCEPT) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_RECV) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_SEND) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_RECVMSG) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_SENDMSG) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_ASYNC_CANCEL) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_READV) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_WRITEV) == 0 || + io_uring_opcode_supported(probe.get(), IORING_OP_FSYNC) == 0) + return nullptr; + return std::make_unique(std::move(ring)); + } + + io_engine_uring::io_engine_uring(struct io_uring ring) noexcept : m_ring(ring) {} + + io_engine_uring::~io_engine_uring() { io_uring_queue_exit(&m_ring); } + + std::string_view io_engine_uring::name() const noexcept { return "uring"; } + + size_t io_engine_uring::run(bool nowait) { + __kernel_timespec timeout{}; + if (!nowait) timeout.tv_sec = 10; + io_uring_cqe* cqe; + std::unique_lock lck{m_cqe_mtx}; + auto res = io_uring_wait_cqe_timeout(&m_ring, &cqe, &timeout); + if (res == -ETIME || res == -EINTR) return m_inflight_count; + if (res < 0) throw std::system_error(-res, std::system_category(), "uring wait cqe failed"); + auto* info = static_cast(io_uring_cqe_get_data(cqe)); + auto opres = cqe->res; + io_uring_cqe_seen(&m_ring, cqe); + // Wakeup call using wake() + if (info == nullptr) return m_inflight_count; + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + lck.unlock(); + + info->result = opres; + + if (auto extra = static_cast(info->engine_state); extra != nullptr) { + if (extra->real_endpoint != nullptr) { + if (extra->sockaddr.ss_family == AF_INET || extra->sockaddr.ss_family == AF_INET6 || + extra->sockaddr.ss_family == AF_UNIX) + *extra->real_endpoint = endpoint(extra->sockaddr, extra->hdr.msg_namelen); + else + *extra->real_endpoint = endpoint(); + } + m_state_allocator.destroy(extra); + info->engine_state = nullptr; + } + + if (info->callback) info->callback(info->userdata); + + return m_inflight_count; + } + + void io_engine_uring::wake() { + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_nop(sqe); + io_uring_sqe_set_data(sqe, nullptr); + io_uring_submit(&m_ring); + } + + bool io_engine_uring::enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) { + auto sa = ep.to_sockaddr(); + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_connect(sqe, socket, reinterpret_cast(&sa.first), sa.second); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_accept(socket_handle_t socket, completion_data* cd) { + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_accept(sqe, socket, nullptr, nullptr, 0); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) { + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_recv(sqe, socket, buf, len, 0); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_send(socket_handle_t socket, const void* buf, size_t len, completion_data* cd) { + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_send(sqe, socket, buf, len, 0); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, + completion_data* cd) { + auto* info = m_state_allocator.create(); + info->hdr.msg_name = &info->sockaddr; + info->hdr.msg_namelen = sizeof(info->sockaddr); + info->hdr.msg_iov = &info->data; + info->hdr.msg_iovlen = 1; + info->data.iov_base = buf; + info->data.iov_len = len; + info->real_endpoint = source; + + cd->engine_state = info; + + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_recvmsg(sqe, socket, &info->hdr, 0); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, + completion_data* cd) { + auto addr = dst.to_sockaddr(); + auto* info = m_state_allocator.create(); + info->hdr.msg_name = &info->sockaddr; + info->hdr.msg_namelen = addr.second; + info->hdr.msg_iov = &info->data; + info->hdr.msg_iovlen = 1; + info->sockaddr = addr.first; + info->data.iov_base = const_cast(buf); + info->data.iov_len = len; + info->real_endpoint = nullptr; + + cd->engine_state = info; + + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_sendmsg(sqe, socket, &info->hdr, 0); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) { + auto* info = m_state_allocator.create(); + info->data.iov_base = buf; + info->data.iov_len = len; + info->real_endpoint = nullptr; + + cd->engine_state = info; + + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_readv(sqe, fd, &info->data, 1, offset); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, + completion_data* cd) { + auto* info = m_state_allocator.create(); + info->data.iov_base = const_cast(buf); + info->data.iov_len = len; + info->real_endpoint = nullptr; + + cd->engine_state = info; + + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_writev(sqe, fd, &info->data, 1, offset); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) { + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_fsync(sqe, fd, flags == fsync_flags::datasync ? IORING_FSYNC_DATASYNC : 0); + io_uring_sqe_set_data(sqe, cd); + io_uring_submit(&m_ring); + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + return false; + } + + bool io_engine_uring::cancel(completion_data* cd) { + std::lock_guard lck{m_sqe_mtx}; + struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); + io_uring_prep_cancel(sqe, cd, 0); + io_uring_sqe_set_data(sqe, nullptr); + io_uring_submit(&m_ring); + return true; + } + +} // namespace asyncpp::io::detail + +#endif diff --git a/src/io_service.cpp b/src/io_service.cpp new file mode 100644 index 0000000..789b3d5 --- /dev/null +++ b/src/io_service.cpp @@ -0,0 +1,72 @@ +#include +#include + +#include + +namespace asyncpp::io { + + io_service::io_service() : m_engine(detail::create_io_engine()) {} + + io_service::~io_service() noexcept(false) {} + + bool io_service::run(run_mode mode) { + if (mode == run_mode::until_stopped) { + bool had_tasks = false; + while (!m_stopped) + had_tasks = run(run_mode::once); + return had_tasks; + } else if (mode == run_mode::while_active) { + bool had_tasks = true; + while (had_tasks) + had_tasks = run(run_mode::once); + return true; + } + + auto old_disp = dispatcher::current(this); + auto had_tasks = m_engine->run(mode == run_mode::nowait) != 0; + auto task = m_dispatched.pop(); + while (task.has_value()) { + had_tasks = true; + task.value()(); + task = m_dispatched.pop(); + } + dispatcher::current(old_disp); + return had_tasks; + } + + void io_service::stop() { + m_stopped = true; + if (dispatcher::current() != this) m_engine->wake(); + } + + void io_service::push(std::function fn) { + m_dispatched.push(std::move(fn)); + if (dispatcher::current() != this) m_engine->wake(); + } + + namespace { + class default_io_service final : public io_service { + public: + default_io_service() { + m_thread = std::thread([this]() { +#ifdef __linux__ + pthread_setname_np(pthread_self(), "dflt_io_srv"); +#endif + this->run(run_mode::until_stopped); + }); + } + ~default_io_service() { + this->stop(); + if (m_thread.joinable()) m_thread.join(); + } + + private: + std::thread m_thread; + }; + } // namespace + + std::shared_ptr io_service::get_default() { + static auto instance = std::make_shared(); + return instance; + } +} // namespace asyncpp::io diff --git a/src/socket.cpp b/src/socket.cpp new file mode 100644 index 0000000..5f7d292 --- /dev/null +++ b/src/socket.cpp @@ -0,0 +1,270 @@ +#include + +#include + +#ifndef _WIN32 +#include +#include +#include +#include +#else +#include +#include +#endif + +namespace { + + std::system_error sys_error(int code) { + return std::system_error(std::make_error_code(static_cast(code))); + } + +} // namespace + +namespace asyncpp::io { + + socket socket::create_tcp(io_service& io, address_type addrtype) { + int domain = -1; + switch (addrtype) { + case address_type::ipv4: domain = AF_INET; break; + case address_type::ipv6: domain = AF_INET6; break; + case address_type::uds: domain = AF_UNIX; break; + } + if (domain == -1) throw sys_error(ENOTSUP); +#ifndef __APPLE__ + auto fd = ::socket(domain, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0); + if (fd < 0) throw sys_error(errno); +#else + auto fd = ::socket(domain, SOCK_STREAM, 0); + if (fd < 0) throw sys_error(errno); + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0 || fcntl(fd, F_SETFD, FD_CLOEXEC) < 0) { + close(fd); + throw std::system_error(errno, std::system_category(), "fcntl failed"); + } +#endif + if (addrtype == address_type::ipv6) { + int opt = 0; + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) + throw std::system_error(errno, std::system_category(), "setsockopt failed"); + } + return socket(&io, fd); + } + + socket_create_and_connect_awaitable socket::create_connected_tcp(io_service& io, endpoint ep) { + return socket_create_and_connect_awaitable(io, ep); + } + + socket_create_and_connect_cancellable_awaitable socket::create_connected_tcp(io_service& io, endpoint ep, + asyncpp::stop_token token) { + return socket_create_and_connect_cancellable_awaitable(std::move(token), io, ep); + } + + socket socket::create_udp(io_service& io, address_type addrtype) { + int domain = -1; + switch (addrtype) { + case address_type::ipv4: domain = AF_INET; break; + case address_type::ipv6: domain = AF_INET6; break; + case address_type::uds: domain = AF_UNIX; break; + } + if (domain == -1) throw sys_error(ENOTSUP); +#ifndef __APPLE__ + auto fd = ::socket(domain, SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0); + if (fd < 0) throw sys_error(errno); +#else + auto fd = ::socket(domain, SOCK_DGRAM, 0); + if (fd < 0) throw sys_error(errno); + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0 || fcntl(fd, F_SETFD, FD_CLOEXEC) < 0) { + close(fd); + throw std::system_error(errno, std::system_category(), "fcntl failed"); + } +#endif + if (addrtype == address_type::ipv6) { + int opt = 0; + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) + throw std::system_error(errno, std::system_category(), "setsockopt failed"); + } + return socket(&io, fd); + } + + socket socket::create_and_bind_tcp(io_service& io, const endpoint& ep) { + auto sock = create_tcp(io, ep.type()); + sock.bind(ep); + return sock; + } + + socket socket::create_and_bind_udp(io_service& io, const endpoint& ep) { + auto sock = create_udp(io, ep.type()); + sock.bind(ep); + return sock; + } + + socket socket::from_fd(io_service& io, detail::io_engine::socket_handle_t fd) { + if (fd < 0) throw std::logic_error("invalid socket"); +#ifdef _WIN32 + unsigned long mode = blocking ? 0 : 1; + if (ioctlsocket(fd, FIONBIO, &mode) != SOCKET_ERROR) + throw std::system_error(std::make_error_code(std::errc::io_error), "ioctlsocket failed"); +#else + int flags = fcntl(fd, F_GETFL, 0); + if (flags == -1) throw sys_error(errno); + if ((flags & O_NONBLOCK) != O_NONBLOCK && fcntl(fd, F_SETFL, flags | O_NONBLOCK) != 0) throw sys_error(errno); +#endif + socket sock(&io, fd); + sock.update_endpoint_info(); + return sock; + } + +#ifndef _WIN32 + std::pair socket::connected_pair_tcp(io_service& io, address_type addrtype) { + int domain = -1; + switch (addrtype) { + case address_type::ipv4: domain = AF_INET; break; + case address_type::ipv6: domain = AF_INET6; break; + case address_type::uds: domain = AF_UNIX; break; + } + if (domain == -1) throw sys_error(ENOTSUP); + + int socks[2]; +#ifndef __APPLE__ + if (socketpair(domain, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0, socks) != 0) throw sys_error(errno); +#else + if (socketpair(domain, SOCK_STREAM, 0, socks) != 0) throw sys_error(errno); + int flags0 = fcntl(socks[0], F_GETFL, 0); + int flags1 = fcntl(socks[1], F_GETFL, 0); + if (flags0 < 0 || flags1 < 0 || // + fcntl(socks[0], F_SETFL, flags0 | O_NONBLOCK) < 0 || fcntl(socks[1], F_SETFL, flags1 | O_NONBLOCK) < 0 || // + fcntl(socks[0], F_SETFD, FD_CLOEXEC) < 0 || fcntl(socks[1], F_SETFD, FD_CLOEXEC) < 0) { + close(socks[0]); + close(socks[1]); + throw std::system_error(errno, std::system_category(), "pipe failed"); + } +#endif + std::pair res{socket(&io, socks[0]), socket(&io, socks[1])}; + res.first.update_endpoint_info(); + res.second.update_endpoint_info(); + return res; + } + + std::pair socket::connected_pair_udp(io_service& io, address_type addrtype) { + int domain = -1; + switch (addrtype) { + case address_type::ipv4: domain = AF_INET; break; + case address_type::ipv6: domain = AF_INET6; break; + case address_type::uds: domain = AF_UNIX; break; + } + if (domain == -1) throw sys_error(ENOTSUP); + + int socks[2]; +#ifndef __APPLE__ + if (socketpair(domain, SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0, socks) != 0) throw sys_error(errno); +#else + if (socketpair(domain, SOCK_DGRAM, 0, socks) != 0) throw sys_error(errno); + int flags0 = fcntl(socks[0], F_GETFL, 0); + int flags1 = fcntl(socks[1], F_GETFL, 0); + if (flags0 < 0 || flags1 < 0 || // + fcntl(socks[0], F_SETFL, flags0 | O_NONBLOCK) < 0 || fcntl(socks[1], F_SETFL, flags1 | O_NONBLOCK) < 0 || // + fcntl(socks[0], F_SETFD, FD_CLOEXEC) < 0 || fcntl(socks[1], F_SETFD, FD_CLOEXEC) < 0) { + close(socks[0]); + close(socks[1]); + throw std::system_error(errno, std::system_category(), "pipe failed"); + } +#endif + return {socket(&io, socks[0]), socket(&io, socks[1])}; + } +#endif + + socket::socket(io_service* io, int fd) noexcept : m_io{io}, m_fd{fd}, m_remote_ep{}, m_local_ep{} {} + + socket::socket(socket&& other) noexcept + : m_io{other.m_io}, m_fd{other.m_fd}, m_remote_ep{other.m_remote_ep}, m_local_ep{other.m_local_ep} { + other.m_io = nullptr; + other.m_fd = -1; + other.m_local_ep = {}; + other.m_remote_ep = {}; + } + + socket& socket::operator=(socket&& other) noexcept { + if (m_fd >= 0) { + close(m_fd); + // TODO: Log errors returned from close ? + m_fd = -1; + } + m_io = other.m_io; + other.m_io = nullptr; + m_fd = other.m_fd; + other.m_fd = -1; + m_local_ep = other.m_local_ep; + other.m_local_ep = {}; + m_remote_ep = other.m_remote_ep; + other.m_remote_ep = {}; + + return *this; + } + + socket::~socket() { + if (m_fd >= 0) { + close(m_fd); + // TODO: Log errors returned from close ? + m_fd = -1; + } + } + + void socket::bind(const endpoint& ep) { + if (m_fd < 0) throw std::logic_error("invalid socket"); + + auto sa = ep.to_sockaddr(); + auto res = ::bind(m_fd, reinterpret_cast(&sa.first), sa.second); + if (res < 0) throw sys_error(errno); + + update_endpoint_info(); + } + + void socket::listen(std::uint32_t backlog) { + if (m_fd < 0) throw std::logic_error("invalid socket"); + + if (backlog == 0) backlog = 20; + auto res = ::listen(m_fd, backlog); + if (res < 0) throw sys_error(errno); + } + + void socket::allow_broadcast(bool enable) { + if (m_fd < 0) throw std::logic_error("invalid socket"); + + int opt = enable ? 1 : 0; + auto res = setsockopt(m_fd, SOL_SOCKET, SO_BROADCAST, &opt, sizeof(opt)); + if (res < 0) throw sys_error(errno); + } + + void socket::close_send() { + if (m_fd < 0) throw std::logic_error("invalid socket"); + + auto res = ::shutdown(m_fd, SHUT_WR); + if (res < 0 && errno != ENOTCONN) throw sys_error(errno); + } + + void socket::close_recv() { + if (m_fd < 0) throw std::logic_error("invalid socket"); + + auto res = ::shutdown(m_fd, SHUT_RD); + if (res < 0 && errno != ENOTCONN) throw sys_error(errno); + } + + void socket::update_endpoint_info() { + sockaddr_storage sa; + socklen_t sa_size = sizeof(sa); + auto res = getpeername(m_fd, reinterpret_cast(&sa), &sa_size); + if (res >= 0) + m_remote_ep = endpoint(sa, sa_size); + else if (res < 0 && errno != ENOTCONN) + throw sys_error(errno); + else + m_remote_ep = {}; + + sa_size = sizeof(sa); + res = getsockname(m_fd, reinterpret_cast(&sa), &sa_size); + if (res < 0) throw sys_error(errno); + m_local_ep = endpoint(sa, sa_size); + } + +} // namespace asyncpp::io diff --git a/src/tls.cpp b/src/tls.cpp new file mode 100644 index 0000000..0fa58aa --- /dev/null +++ b/src/tls.cpp @@ -0,0 +1,932 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace asyncpp::io::tls { + namespace { +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + const SSL_METHOD* ossl_method_from_enum(method meth, mode m) { + switch (meth) { + case method::tls: return m == mode::server ? TLS_server_method() : TLS_client_method(); +#ifndef OPENSSL_NO_SSL3_METHOD + case method::sslv3: return m == mode::server ? SSLv3_server_method() : SSLv3_client_method(); +#endif + case method::tlsv1: return m == mode::server ? TLSv1_server_method() : TLSv1_client_method(); + case method::tlsv1_1: return m == mode::server ? TLSv1_1_server_method() : TLSv1_1_client_method(); + case method::tlsv1_2: return m == mode::server ? TLSv1_2_server_method() : TLSv1_2_client_method(); + case method::dtls: return m == mode::server ? DTLS_server_method() : DTLS_client_method(); + case method::dtlsv1: return m == mode::server ? DTLSv1_server_method() : DTLSv1_client_method(); + case method::dtlsv1_2: return m == mode::server ? DTLSv1_2_server_method() : DTLSv1_2_client_method(); + default: throw std::logic_error("invalid method"); + } + } +#pragma GCC diagnostic pop + + void throw_ossl_error() { + auto error = ERR_get_error(); + ERR_clear_error(); + char buf[128]; + ERR_error_string_n(error, buf, sizeof(buf)); + throw std::runtime_error("Openssl failed: " + std::to_string(error) + " " + std::string(buf)); + } + +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + int set_null_on_dup(CRYPTO_EX_DATA* to, const CRYPTO_EX_DATA*, void**, int idx, long, void*) { +#else + int set_null_on_dup(CRYPTO_EX_DATA* to, const CRYPTO_EX_DATA*, void*, int idx, long, void*) { +#endif + CRYPTO_set_ex_data(to, idx, nullptr); + return 1; + } + + int context_udi() { + static int index = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, set_null_on_dup, nullptr); + if (index < 0) throw std::runtime_error("Failed to register custom data"); + return index; + } + + int ssl_udi() { + static int index = SSL_get_ex_new_index(0, nullptr, nullptr, set_null_on_dup, nullptr); + if (index < 0) throw std::runtime_error("Failed to register custom data"); + return index; + } + } // namespace + + context::context(method meth, mode m) : m_method(meth), m_mode(m) { + auto ossl_method = ossl_method_from_enum(meth, m); + if (ossl_method == nullptr) throw_ossl_error(); + const auto udi = context_udi(); + auto ctx = SSL_CTX_new(ossl_method); + if (ctx == nullptr) throw_ossl_error(); + m_ctx = ctx; + SSL_CTX_set_ex_data(ctx, udi, this); + SSL_CTX_set_options(ctx, SSL_OP_ALL); + SSL_CTX_set_default_verify_paths(ctx); + if (m_mode == mode::client) SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr); + } + + context::~context() { + if (m_ctx) { + assert(SSL_CTX_get_ex_data(static_cast(m_ctx), context_udi()) == this); + SSL_CTX_free(static_cast(m_ctx)); + } + } + + void context::use_certificate(const std::string& file, file_type type) { + if (SSL_CTX_use_certificate_chain_file(static_cast(m_ctx), file.c_str()) != 1) throw_ossl_error(); + } + + void context::use_privatekey(const std::string& file, file_type type) { + if (SSL_CTX_use_PrivateKey_file(static_cast(m_ctx), file.c_str(), SSL_FILETYPE_PEM) != 1) + throw_ossl_error(); + if (!SSL_CTX_check_private_key(static_cast(m_ctx))) throw_ossl_error(); + } + + void context::set_passwd_callback(std::function cb) { + m_passwd_cb = std::move(cb); + SSL_CTX_set_default_passwd_cb(static_cast(m_ctx), [](char* buf, int len, int rw, void* udata) -> int { + auto that = static_cast(udata); + if (that->m_passwd_cb) return that->m_passwd_cb(buf, len, rw == 1); + buf[0] = '\0'; + return 0; + }); + SSL_CTX_set_default_passwd_cb_userdata(static_cast(m_ctx), this); + } + + void context::set_passwd(std::string passwd) { + set_passwd_callback([passwd = std::move(passwd)](char* buf, size_t len, bool encrypt) -> size_t { + strncpy(buf, passwd.c_str(), len); + buf[len - 1] = '\0'; + return strlen(buf); + }); + } + + void context::set_client_hello_callback(std::function cb) { + m_client_hello_cb = std::move(cb); + SSL_CTX_set_client_hello_cb( + static_cast(m_ctx), + [](SSL* ssl, int* al, void* udata) -> int { + auto that = static_cast(udata); + client_hello hello(ssl); + if (that->m_client_hello_cb) + return that->m_client_hello_cb(hello, *al) ? SSL_CLIENT_HELLO_SUCCESS : SSL_CLIENT_HELLO_ERROR; + return SSL_CLIENT_HELLO_SUCCESS; + }, + this); + } + + void context::set_verify(verify_mode mode) { + auto sslmode = SSL_VERIFY_NONE; + if (mode & verify_mode::peer) sslmode = SSL_VERIFY_PEER; + if (mode & verify_mode::fail_if_no_cert) sslmode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + if (mode & verify_mode::verify_once) sslmode |= SSL_VERIFY_CLIENT_ONCE; + if (mode & verify_mode::verify_post_handshake) sslmode |= SSL_VERIFY_POST_HANDSHAKE; + SSL_CTX_set_verify(static_cast(m_ctx), sslmode, nullptr); + } + + void context::set_default_verify_paths() { + if (SSL_CTX_set_default_verify_paths(static_cast(m_ctx)) != 1) throw_ossl_error(); + } + + void context::set_default_verify_dir() { + if (SSL_CTX_set_default_verify_dir(static_cast(m_ctx)) != 1) throw_ossl_error(); + } + + void context::set_default_verify_file() { + if (SSL_CTX_set_default_verify_file(static_cast(m_ctx)) != 1) throw_ossl_error(); + } + + void context::load_verify_locations(const std::string& file, const std::string& path) { + if (SSL_CTX_load_verify_locations(static_cast(m_ctx), file.empty() ? nullptr : file.c_str(), + path.empty() ? nullptr : path.c_str()) != 1) + throw_ossl_error(); + } + + std::vector context::ciphers() const { + auto ciphers = SSL_CTX_get_ciphers(static_cast(m_ctx)); + if (ciphers == nullptr) return {}; + std::vector result(sk_SSL_CIPHER_num(ciphers)); + for (size_t i = 0; i < result.size(); i++) + result[i] = cipher(sk_SSL_CIPHER_value(ciphers, i)); + return result; + } + + void context::set_alpn_protos(const std::vector& protos) { + if (protos.empty()) throw std::runtime_error("alpn list must not be empty"); + size_t list_length = 0; + for (auto& e : protos) { + if (e.empty()) continue; + if (e.size() > 255) throw std::runtime_error("alpn too large"); + list_length += 1 + e.size(); + } + std::vector list(list_length); + for (size_t i = 0; auto& e : protos) { + if (e.empty()) continue; + list[i] = e.size(); + memcpy(&list[i + 1], e.data(), e.size()); + i += 1 + e.size(); + } + if (SSL_CTX_set_alpn_protos(static_cast(m_ctx), list.data(), list.size()) != 0) + throw std::runtime_error("failed to set alpn protocols"); + } + + void context::set_alpn_select_callback( + std::function& protocols)> + cb) { + m_alpn_select_cb = std::move(cb); + SSL_CTX_set_alpn_select_cb( + static_cast(m_ctx), + [](SSL* ssl, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, + void* arg) -> int { + const auto that = static_cast(arg); + const auto wrapper = static_cast(SSL_get_ex_data(ssl, ssl_udi())); + // Early out if no callback is set + if (that == nullptr || wrapper == nullptr || !that->m_alpn_select_cb) return SSL_TLSEXT_ERR_NOACK; + + // Figure out number of protocols available + size_t num_protocols = 0; + for (unsigned int i = 0; i < inlen; i += 1 + in[i]) { + if (i + in[i] + 1 > inlen) return SSL_TLSEXT_ERR_NOACK; + num_protocols++; + } + + // Allocate and fill string_view array + std::unique_ptr protocols_ptr(new (std::nothrow) std::string_view[num_protocols]); + if (!protocols_ptr) return SSL_TLSEXT_ERR_NOACK; + std::span protocols{protocols_ptr.get(), num_protocols}; + for (unsigned int i = 0, ip = 0; i < inlen; i += 1 + in[i]) + protocols[ip++] = {reinterpret_cast(&in[i + 1]), in[i]}; + + // Call callback + std::string_view res; + try { + if (!that->m_alpn_select_cb(*wrapper, res, protocols)) return SSL_TLSEXT_ERR_NOACK; + } catch (...) { return SSL_TLSEXT_ERR_NOACK; } + + // If the returned value is not withing supplied array (e.g. a constant was assigned) search for it + if (res.data() < reinterpret_cast(in) || + res.data() >= reinterpret_cast(in + inlen)) { + bool found = false; + for (auto& e : protocols) { + if (e == res) { + res = e; + found = true; + break; + } + } + if (!found) return SSL_TLSEXT_ERR_NOACK; + } + *out = reinterpret_cast(res.data()); + *outlen = res.size(); + return SSL_TLSEXT_ERR_OK; + }, + this); + } + + void context::set_certificate_callback(std::function cb) { + m_cert_cb = std::move(cb); + SSL_CTX_set_cert_cb( + static_cast(m_ctx), + [](SSL* ssl, void* arg) -> int { + try { + const auto that = static_cast(arg); + if (!that->m_cert_cb) return 1; + const auto wrapper = static_cast(SSL_get_ex_data(ssl, ssl_udi())); + if (wrapper == nullptr) return 0; + that->m_cert_cb(*wrapper); + } catch (...) { return 0; } + return 1; + }, + this); + } + + std::vector context::get_chain_certs() const { + std::vector res; + STACK_OF(X509)* certs = nullptr; + if (SSL_CTX_get0_chain_certs(static_cast(m_ctx), &certs) != 1) throw_ossl_error(); + if (certs == nullptr) return res; + res.reserve(sk_X509_num(certs)); + for (int i = 0; i < sk_X509_num(certs); i++) { + auto cert = sk_X509_value(certs, i); + X509_up_ref(cert); + res.push_back(cert); + } + return res; + } + + void context::clear_chain_certs() { + if (SSL_CTX_clear_chain_certs(static_cast(m_ctx)) != 1) throw_ossl_error(); + } + + void context::debug() { + STACK_OF(X509)* sks = nullptr; + [[maybe_unused]] auto res = SSL_CTX_get0_chain_certs(static_cast(m_ctx), &sks); + printf("sks=%p\n", sks); + printf("\tsize=%d\n", sk_X509_num(sks)); + for (int i = 0; i < sk_X509_num(sks); i++) { + //std::cout << x509(sk_X509_value(sks, i)).to_der() << std::endl; + } + } + + context::client_hello::client_hello(void* ssl) noexcept : m_ssl(ssl) {} + + bool context::client_hello::is_v2() const noexcept { return SSL_client_hello_isv2(static_cast(m_ssl)) != 0; } + + std::span context::client_hello::random() const noexcept { + const unsigned char* ptr = nullptr; + auto size = SSL_client_hello_get0_random(static_cast(m_ssl), &ptr); + return std::as_bytes(std::span{ptr, size}); + } + + std::span context::client_hello::session_id() const noexcept { + const unsigned char* ptr = nullptr; + auto size = SSL_client_hello_get0_session_id(static_cast(m_ssl), &ptr); + return std::as_bytes(std::span{ptr, size}); + } + + const std::vector& context::client_hello::ciphers() const { + if (!m_ciphers.empty()) return m_ciphers; + + const unsigned char* bytes = nullptr; + STACK_OF(SSL_CIPHER)* sk = NULL, *scsv = NULL; + auto len = SSL_client_hello_get0_ciphers(static_cast(m_ssl), &bytes); + if (SSL_bytes_to_cipher_list(static_cast(m_ssl), bytes, len, + SSL_client_hello_isv2(static_cast(m_ssl)), &sk, &scsv) == 0) + throw std::runtime_error("failed to parse cipher list"); + + scope_guard del([sk, scsv]() noexcept { + sk_SSL_CIPHER_free(sk); + sk_SSL_CIPHER_free(scsv); + }); + m_ciphers.resize(sk_SSL_CIPHER_num(sk)); + for (size_t i = 0; i < m_ciphers.size(); i++) + m_ciphers[i] = cipher(sk_SSL_CIPHER_value(sk, i)); + m_signalling_ciphers.resize(sk_SSL_CIPHER_num(scsv)); + for (size_t i = 0; i < m_signalling_ciphers.size(); i++) + m_signalling_ciphers[i] = cipher(sk_SSL_CIPHER_value(scsv, i)); + return m_ciphers; + } + + const std::vector& context::client_hello::signalling_ciphers() const { + if (m_ciphers.empty() && m_signalling_ciphers.empty()) ciphers(); + return m_signalling_ciphers; + } + + std::span context::client_hello::compression_methods() const noexcept { + const unsigned char* ptr = nullptr; + auto size = SSL_client_hello_get0_compression_methods(static_cast(m_ssl), &ptr); + return std::as_bytes(std::span{ptr, size}); + } + + const std::set& context::client_hello::extensions() const { + if (!m_extensions_preset.empty()) return m_extensions_preset; + int* exts = nullptr; + size_t len = 0; + if (SSL_client_hello_get1_extensions_present(static_cast(m_ssl), &exts, &len) == 0) + throw std::runtime_error("failed to get present extensions"); + if (exts == nullptr) return m_extensions_preset; + for (size_t i = 0; i < len; i++) + m_extensions_preset.emplace(static_cast(exts[i])); + OPENSSL_free(exts); + return m_extensions_preset; + } + + bool context::client_hello::has_extension(unsigned int type) const { return extensions().contains(type); } + + std::span context::client_hello::extension(unsigned int type) const noexcept { + const unsigned char* ptr = nullptr; + size_t len = 0; + if (SSL_client_hello_get0_ext(static_cast(m_ssl), type, &ptr, &len) == 0) return {}; + return std::as_bytes(std::span{ptr, len}); + } + + void context::client_hello::replace_context(context& new_context) const noexcept { + SSL_set_SSL_CTX(static_cast(m_ssl), static_cast(new_context.m_ctx)); + SSL_clear_options(static_cast(m_ssl), 0xFFFFFFFFL); + SSL_set_options(static_cast(m_ssl), SSL_CTX_get_options(static_cast(new_context.m_ctx))); + } + + session* context::client_hello::get_session() const noexcept { + return static_cast(SSL_get_ex_data(static_cast(m_ssl), ssl_udi())); + } + + std::string_view context::client_hello::server_name_indication() const noexcept { + auto ext = extension(TLSEXT_TYPE_server_name); + if (ext.size() <= 2) return ""; + size_t len = static_cast(ext[0]) << 8; + len += static_cast(ext[1]); + if (len + 2 != ext.size()) return ""; + ext = ext.subspan(2, len); + if (ext.size() < 3 || static_cast(ext[0]) != TLSEXT_NAMETYPE_host_name) return ""; + len = static_cast(ext[1]) << 8; + len += static_cast(ext[2]); + if (len + 3 != ext.size()) return ""; + auto ptr = reinterpret_cast(ext.subspan(3).data()); + return {ptr, len}; + } + + session::session(const context& ctx) { + const auto udi = ssl_udi(); + auto ssl = SSL_new(static_cast(ctx.m_ctx)); + if (ssl == nullptr) throw_ossl_error(); + SSL_set_ex_data(ssl, udi, this); + if (ctx.m_mode == mode::server) + SSL_set_accept_state(ssl); + else + SSL_set_connect_state(ssl); + auto ibio = BIO_new(BIO_s_mem()); + if (ibio == nullptr) { + SSL_free(ssl); + throw_ossl_error(); + } + auto obio = BIO_new(BIO_s_mem()); + if (obio == nullptr) { + BIO_free_all(ibio); + SSL_free(ssl); + throw_ossl_error(); + } + BIO_set_mem_eof_return(obio, -1); + + SSL_set_bio(ssl, ibio, obio); + + m_ssl = ssl; + m_input_bio = ibio; + m_output_bio = obio; + } + + session::~session() { + if (m_ssl) { + assert(SSL_get_ex_data(static_cast(m_ssl), ssl_udi()) == this); + SSL_free(static_cast(m_ssl)); + } + } + + int session::try_handshake() { + do { + auto res = SSL_do_handshake(static_cast(m_ssl)); + try_resume_cipher_read(); + try_resume_cipher_write(); + if (res != 1) { + auto error = SSL_get_error(static_cast(m_ssl), res); + if (error == SSL_ERROR_WANT_WRITE && try_resume_cipher_read()) + continue; + else if (error == SSL_ERROR_WANT_READ && try_resume_cipher_write()) + continue; + } + return res; + } while (true); + } + + void session::shutdown() { + SSL_shutdown(static_cast(m_ssl)); + try_resume_cipher_read(); + try_resume_cipher_write(); + try_resume_plain(); + } + + bool session::try_resume_cipher_read() { + bool did_resume = false; + while (m_cipher_readers && m_cipher_readers->try_resume()) + did_resume = true; + return did_resume; + } + + bool session::try_resume_cipher_write() { + bool did_resume = false; + while (m_cipher_writers && m_cipher_writers->try_resume()) + did_resume = true; + return did_resume; + } + + void session::try_resume_plain() { + while (m_plain_readers && m_plain_readers->try_resume()) + ; + while (m_plain_writers && m_plain_writers->try_resume()) + ; + while (m_handshakers && m_handshakers->try_resume()) + ; + } + + bool session::try_read(void* buf, size_t len, size_t& read) { + if (SSL_get_shutdown(static_cast(m_ssl))) { + read = 0; + return true; + } + do { + if (!SSL_is_init_finished(static_cast(m_ssl))) try_handshake(); + auto res = SSL_read_ex(static_cast(m_ssl), buf, len, &read); + if (res == 0) { + auto error = SSL_get_error(static_cast(m_ssl), res); + if (error == SSL_ERROR_WANT_WRITE && try_resume_cipher_read()) continue; + if (error == SSL_ERROR_WANT_READ && try_resume_cipher_write()) continue; + if (error == SSL_ERROR_ZERO_RETURN) { + read = 0; + return true; + } + if (error != SSL_ERROR_WANT_WRITE && error != SSL_ERROR_WANT_READ) + throw std::runtime_error("SSL protocol error"); + } + return res == 1; + } while (true); + } + + bool session::try_write(const void* buf, size_t len, size_t& written) { + if (SSL_get_shutdown(static_cast(m_ssl))) { + written = 0; + return true; + } + do { + if (!SSL_is_init_finished(static_cast(m_ssl))) try_handshake(); + auto res = SSL_write_ex(static_cast(m_ssl), buf, len, &written); + if (res == 0) { + auto error = SSL_get_error(static_cast(m_ssl), res); + if (error == SSL_ERROR_WANT_WRITE && try_resume_cipher_read()) continue; + if (error == SSL_ERROR_WANT_READ && try_resume_cipher_write()) continue; + } + try_resume_cipher_read(); + return res == 1; + } while (true); + } + + cipher session::current_cipher() const noexcept { return cipher(SSL_get_current_cipher(static_cast(m_ssl))); } + + cipher session::pending_cipher() const noexcept { return cipher(SSL_get_pending_cipher(static_cast(m_ssl))); } + + std::vector session::ciphers() const { + auto ciphers = SSL_get_ciphers(static_cast(m_ssl)); + if (ciphers == nullptr) return {}; + std::vector result(sk_SSL_CIPHER_num(ciphers)); + for (size_t i = 0; i < result.size(); i++) + result[i] = cipher(sk_SSL_CIPHER_value(ciphers, i)); + return result; + } + + std::vector session::supported_ciphers() const { + auto ciphers = SSL_get1_supported_ciphers(static_cast(m_ssl)); + if (ciphers == nullptr) return {}; + scope_guard del([ciphers]() noexcept { sk_SSL_CIPHER_free(ciphers); }); + std::vector result(sk_SSL_CIPHER_num(ciphers)); + for (size_t i = 0; i < result.size(); i++) + result[i] = cipher(sk_SSL_CIPHER_value(ciphers, i)); + return result; + } + + std::vector session::client_ciphers() const { + auto ciphers = SSL_get_client_ciphers(static_cast(m_ssl)); + if (ciphers == nullptr) return {}; + std::vector result(sk_SSL_CIPHER_num(ciphers)); + for (size_t i = 0; i < result.size(); i++) + result[i] = cipher(sk_SSL_CIPHER_value(ciphers, i)); + return result; + } + + std::string_view session::get_servername() const noexcept { + auto res = SSL_get_servername(static_cast(m_ssl), TLSEXT_NAMETYPE_host_name); + return res ? res : ""; + } + + void session::set_servername(const std::string& name) { + if (SSL_set_tlsext_host_name(static_cast(m_ssl), name.c_str()) == 0) + throw std::runtime_error("failed to set hostname"); + } + + void session::set_verify(verify_mode mode) { + auto sslmode = SSL_VERIFY_NONE; + if (mode & verify_mode::peer) sslmode = SSL_VERIFY_PEER; + if (mode & verify_mode::fail_if_no_cert) sslmode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + if (mode & verify_mode::verify_once) sslmode |= SSL_VERIFY_CLIENT_ONCE; + if (mode & verify_mode::verify_post_handshake) sslmode |= SSL_VERIFY_POST_HANDSHAKE; + SSL_set_verify(static_cast(m_ssl), sslmode, nullptr); + } + + void session::set_alpn_protos(const std::vector& protos) { + if (protos.empty()) throw std::runtime_error("alpn list must not be empty"); + size_t list_length = 0; + for (auto& e : protos) { + if (e.empty()) continue; + if (e.size() > 255) throw std::runtime_error("alpn too large"); + list_length += 1 + e.size(); + } + std::vector list(list_length); + for (size_t i = 0; auto& e : protos) { + if (e.empty()) continue; + list[i] = e.size(); + memcpy(&list[i + 1], e.data(), e.size()); + i += 1 + e.size(); + } + if (SSL_set_alpn_protos(static_cast(m_ssl), list.data(), list.size()) != 0) + throw std::runtime_error("failed to set alpn protocols"); + } + + std::string_view session::alpn_selected() const noexcept { + const unsigned char* data = nullptr; + unsigned int len = 0; + SSL_get0_alpn_selected(static_cast(m_ssl), &data, &len); + if (data == nullptr || len == 0) return ""; + return {reinterpret_cast(data), len}; + } + + void session::set_certificate_callback(std::function cb) { + m_cert_cb = std::move(cb); + SSL_set_cert_cb( + static_cast(m_ssl), + [](SSL* ssl, void* arg) -> int { + try { + const auto that = static_cast(arg); + if (!that->m_cert_cb) return 1; + that->m_cert_cb(*that); + } catch (...) { return 0; } + return 1; + }, + this); + } + + x509 session::get_peer_certificate() const noexcept { return SSL_get_peer_certificate(static_cast(m_ssl)); } + + cipher::~cipher() { + if (m_description) OPENSSL_free(m_description); + } + + std::string_view cipher::name() const noexcept { + return SSL_CIPHER_get_name(static_cast(m_cipher)); + } + + std::string_view cipher::standard_name() const noexcept { + return SSL_CIPHER_standard_name(static_cast(m_cipher)); + } + + std::string_view cipher::cipher_name() const noexcept { + return OPENSSL_cipher_name(SSL_CIPHER_standard_name(static_cast(m_cipher))); + } + + size_t cipher::bit_count() const noexcept { + return SSL_CIPHER_get_bits(static_cast(m_cipher), nullptr); + } + + std::string_view cipher::version() const noexcept { + return SSL_CIPHER_get_version(static_cast(m_cipher)); + } + + std::string_view cipher::description() const noexcept { + if (!m_cipher) return "(null)"; + if (!m_description) return m_description; + m_description = SSL_CIPHER_description(static_cast(m_cipher), nullptr, 128); + return m_description; + } + + int cipher::cipher_nid() const noexcept { + return SSL_CIPHER_get_cipher_nid(static_cast(m_cipher)); + } + + int cipher::digest_nid() const noexcept { + return SSL_CIPHER_get_digest_nid(static_cast(m_cipher)); + } + + int cipher::kx_nid() const noexcept { return SSL_CIPHER_get_kx_nid(static_cast(m_cipher)); } + + int cipher::auth_nid() const noexcept { return SSL_CIPHER_get_auth_nid(static_cast(m_cipher)); } + + bool cipher::is_aead() const noexcept { return SSL_CIPHER_is_aead(static_cast(m_cipher)) == 1; } + + uint32_t cipher::id() const noexcept { return SSL_CIPHER_get_id(static_cast(m_cipher)); } + + uint32_t cipher::protocol_id() const noexcept { + return SSL_CIPHER_get_protocol_id(static_cast(m_cipher)); + } + + std::ostream& operator<<(std::ostream& str, const cipher& cipher) { return str << cipher.description(); } + + x509::~x509() { + if (m_x509) X509_free(static_cast(m_x509)); + } + + std::string x509::to_der() const { + unsigned char* out = nullptr; + auto len = i2d_X509(static_cast(m_x509), &out); + if (out == nullptr || len == 0) throw std::runtime_error("failed to convert to der"); + std::string res; + res.assign(reinterpret_cast(out), len); + OPENSSL_free(out); + return res; + } + + std::string x509::to_pem() const { + auto bio = BIO_new(BIO_s_mem()); + scope_guard guard{[bio]() noexcept { BIO_free(bio); }}; + if (PEM_write_bio_X509(bio, static_cast(m_x509)) != 1) + throw std::runtime_error("failed to convert to pem"); + BUF_MEM* bptr = nullptr; + BIO_get_mem_ptr(bio, &bptr); + return std::string{bptr->data, bptr->length}; + } + + x509 x509::from_der(const void* ptr, size_t len) { + auto u8ptr = reinterpret_cast(ptr); + auto res = d2i_X509(nullptr, &u8ptr, len); + return x509(res); + } + + x509 x509::from_pem(const void* ptr, size_t len) { + auto bio = BIO_new(BIO_s_mem()); + scope_guard guard{[bio]() noexcept { BIO_free(bio); }}; + BIO_write(bio, ptr, len); + auto res = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); + return x509(res); + } + + std::chrono::system_clock::time_point x509::not_before() const noexcept { + int day = 0, sec = 0; + ASN1_TIME_diff(&day, &sec, nullptr, X509_get0_notBefore(static_cast(m_x509))); + auto now = std::chrono::system_clock::now(); + return now + std::chrono::days(day) + std::chrono::seconds(sec); + } + + std::chrono::system_clock::time_point x509::not_after() const noexcept { + int day = 0, sec = 0; + ASN1_TIME_diff(&day, &sec, nullptr, X509_get0_notAfter(static_cast(m_x509))); + auto now = std::chrono::system_clock::now(); + return now + std::chrono::days(day) + std::chrono::seconds(sec); + } + + std::string x509::subject() const { + auto name = X509_get_subject_name(static_cast(m_x509)); + auto bio = BIO_new(BIO_s_mem()); + scope_guard guard{[bio]() noexcept { BIO_free(bio); }}; + if (X509_NAME_print_ex(bio, name, 0, XN_FLAG_ONELINE & ~ASN1_STRFLGS_ESC_MSB) == -1) throw_ossl_error(); + BUF_MEM* bptr = nullptr; + BIO_get_mem_ptr(bio, &bptr); + return std::string{bptr->data, bptr->length}; + } + + std::string x509::issuer() const { + auto name = X509_get_issuer_name(static_cast(m_x509)); + auto bio = BIO_new(BIO_s_mem()); + scope_guard guard{[bio]() noexcept { BIO_free(bio); }}; + if (X509_NAME_print_ex(bio, name, 0, XN_FLAG_ONELINE & ~ASN1_STRFLGS_ESC_MSB) == -1) throw_ossl_error(); + BUF_MEM* bptr = nullptr; + BIO_get_mem_ptr(bio, &bptr); + return std::string{bptr->data, bptr->length}; + } + + std::strong_ordering operator<=>(const x509& lhs, const x509& rhs) noexcept { + if (lhs.m_x509 == rhs.m_x509) return std::strong_ordering::equal; + if (lhs.m_x509 == nullptr) return std::strong_ordering::less; + if (rhs.m_x509 == nullptr) return std::strong_ordering::greater; + const auto res = X509_cmp(static_cast(lhs.m_x509), static_cast(rhs.m_x509)); + if (res < 0) + return std::strong_ordering::less; + else if (res > 0) + return std::strong_ordering::greater; + else + return std::strong_ordering::equal; + } + + bool operator==(const x509& lhs, const x509& rhs) noexcept { return (lhs <=> rhs) == std::strong_ordering::equal; } + + bool operator!=(const x509& lhs, const x509& rhs) noexcept { return (lhs <=> rhs) != std::strong_ordering::equal; } + + bool plain_read_awaitable::try_resume() { + auto res = m_session.try_read(m_buffer, m_len, m_result); + if (res) { + m_handle.resume(); + return true; + } + return false; + } + + bool plain_read_awaitable::await_ready() const noexcept { return false; } + + bool plain_read_awaitable::await_suspend(coroutine_handle<> hdl) { + auto res = m_session.try_read(m_buffer, m_len, m_result); + if (!res) { + m_handle = hdl; + plain_read_awaitable** last = &m_session.m_plain_readers; + while (*last != nullptr) + last = &(*last)->m_next; + *last = this; + } + return !res; + } + + size_t plain_read_awaitable::await_resume() { + if (m_handle) { + assert(m_session.m_plain_readers == this); + m_session.m_plain_readers = m_next; + } + return m_result; + } + + bool plain_write_awaitable::try_resume() { + auto res = m_session.try_write(m_buffer, m_len, m_result); + if (res && m_result != 0) { + m_handle.resume(); + return true; + } + return false; + } + + bool plain_write_awaitable::await_ready() const noexcept { return false; } + + bool plain_write_awaitable::await_suspend(coroutine_handle<> hdl) { + auto res = m_session.try_write(m_buffer, m_len, m_result); + if (!res) { + m_handle = hdl; + plain_write_awaitable** last = &m_session.m_plain_writers; + while (*last != nullptr) + last = &(*last)->m_next; + *last = this; + } + return !res; + } + + size_t plain_write_awaitable::await_resume() { + if (m_handle) { + assert(m_session.m_plain_writers == this); + m_session.m_plain_writers = m_next; + } + return m_result; + } + + bool cipher_read_awaitable::try_resume() { + if (SSL_get_shutdown(static_cast(m_session.m_ssl))) { + m_result = 0; + m_handle.resume(); + return true; + } + auto res = BIO_read_ex(static_cast(m_session.m_output_bio), m_buffer, m_len, &m_result); + if (res == 1 && m_result != 0) { + m_handle.resume(); + return true; + } + return false; + } + + bool cipher_read_awaitable::await_ready() const noexcept { return false; } + + bool cipher_read_awaitable::await_suspend(coroutine_handle<> hdl) { + if (SSL_get_shutdown(static_cast(m_session.m_ssl))) { + m_result = 0; + return false; + } + auto res = BIO_read_ex(static_cast(m_session.m_output_bio), m_buffer, m_len, &m_result); + bool suspend = !(res == 1 && m_result != 0); + if (suspend) { + m_handle = hdl; + cipher_read_awaitable** last = &m_session.m_cipher_readers; + while (*last != nullptr) + last = &(*last)->m_next; + *last = this; + } + return suspend; + } + + size_t cipher_read_awaitable::await_resume() { + if (m_handle) { + assert(m_session.m_cipher_readers == this); + m_session.m_cipher_readers = m_next; + } + return m_result; + } + + bool cipher_write_awaitable::try_resume() { + if (BIO_ctrl_pending(static_cast(m_session.m_input_bio)) > 256 * 1024) { + if (m_len == 0) m_session.shutdown(); + size_t len; + auto res = BIO_write_ex(static_cast(m_session.m_input_bio), m_buffer, m_len, &len); + m_session.try_resume_plain(); + // Write + assert(res == 1); + assert(len == m_len); + m_handle.resume(); + return true; + } + return false; + } + + bool cipher_write_awaitable::await_ready() const noexcept { return false; } + + bool cipher_write_awaitable::await_suspend(coroutine_handle<> hdl) { + if (BIO_ctrl_pending(static_cast(m_session.m_input_bio)) > 256 * 1024) { + // Memory bios can grow indefinitly, to avoid buffering to much data we suspend if the amount + // exceeds 256K. + m_handle = hdl; + cipher_write_awaitable** last = &m_session.m_cipher_writers; + while (*last != nullptr) + last = &(*last)->m_next; + *last = this; + return true; + } + if (m_len == 0) m_session.shutdown(); + auto res = BIO_write_ex(static_cast(m_session.m_input_bio), m_buffer, m_len, &m_result); + // Write + assert(m_result == 0 || res == 1); + assert(m_result == m_len); + m_session.try_resume_plain(); + return false; + } + + size_t cipher_write_awaitable::await_resume() { + if (m_handle) { + assert(m_session.m_cipher_writers == this); + m_session.m_cipher_writers = m_next; + } + return m_result; + } + + bool handshake_awaitable::try_resume() { + auto res = m_session.try_handshake(); + if (res != 1) { + auto error = SSL_get_error(static_cast(m_session.m_ssl), res); + if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ) return false; + m_result = error; + } + m_handle.resume(); + return true; + } + + bool handshake_awaitable::await_ready() const noexcept { + return SSL_is_init_finished(static_cast(m_session.m_ssl)); + } + + bool handshake_awaitable::await_suspend(coroutine_handle<> hdl) { + auto res = m_session.try_handshake(); + if (res != 1) { + auto error = SSL_get_error(static_cast(m_session.m_ssl), res); + if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ) { + m_handle = hdl; + handshake_awaitable** last = &m_session.m_handshakers; + while (*last != nullptr) + last = &(*last)->m_next; + *last = this; + return true; + } + m_result = error; + } + return false; + } + + void handshake_awaitable::await_resume() { + if (m_handle) { + assert(m_session.m_handshakers == this); + m_session.m_handshakers = m_next; + } + if (m_result != 0) + throw std::runtime_error("SSL Handshake failed: " + + std::to_string(SSL_get_verify_result(static_cast(m_session.m_ssl)))); + } + +} // namespace asyncpp::io::tls diff --git a/test/address.cpp b/test/address.cpp new file mode 100644 index 0000000..14ab8b9 --- /dev/null +++ b/test/address.cpp @@ -0,0 +1,176 @@ +#include + +#include + +using namespace asyncpp::io; + +TEST(ASYNCPP_IO, IPv4ConstructAndAccess) { + ASSERT_EQ(ipv4_address().integer(), 0) << "Default constructor should result in 0.0.0.0"; + ASSERT_EQ(ipv4_address(127, 0, 0, 1).integer(), 0x7f000001u) << "Piecewise construction"; + ASSERT_EQ(ipv4_address(127, 0, 0, 1).integer(std::endian::little), 0x0100007fu) + << "Piecewise construction is big endian"; + ASSERT_EQ(ipv4_address(0x7f000001u).integer(), 0x7f000001u) << "uint32_t construction"; + ASSERT_EQ(ipv4_address(0x0100007fu, std::endian::little).integer(), 0x7f000001u) + << "uint32_t construction (little endian)"; + ASSERT_EQ(ipv4_address(std::array{127, 0, 0, 1}), ipv4_address(127, 0, 0, 1)) + << "Array of uint8_t construction"; + uint8_t data[] = {127, 0, 0, 1}; + ASSERT_EQ(ipv4_address(data), ipv4_address(127, 0, 0, 1)) << "Array of uint8_t construction"; +} + +TEST(ASYNCPP_IO, IPv4AddrType) { + ASSERT_EQ(ipv4_address::loopback(), ipv4_address(127, 0, 0, 1)) << "loopback() returns 127.0.0.1"; + ASSERT_EQ(ipv4_address::any(), ipv4_address(0, 0, 0, 0)) << "any() returns 0.0.0.0"; + ASSERT_TRUE(ipv4_address::loopback().is_loopback()); + ASSERT_FALSE(ipv4_address::any().is_loopback()); + ASSERT_FALSE(ipv4_address::loopback().is_private()); + ASSERT_FALSE(ipv4_address::any().is_private()); + ASSERT_TRUE(ipv4_address(10, 0, 0, 1).is_private()); + ASSERT_FALSE(ipv4_address(8, 8, 8, 8).is_private()); +} + +TEST(ASYNCPP_IO, IPv4Parse) { + ASSERT_EQ(ipv4_address::parse(""), std::nullopt); + ASSERT_EQ(ipv4_address::parse("1"), std::nullopt); + ASSERT_EQ(ipv4_address::parse("1.0.01"), std::nullopt); + ASSERT_EQ(ipv4_address::parse("10.0.0.1 "), std::nullopt); + ASSERT_EQ(ipv4_address::parse(" 10.0.0.1"), std::nullopt); + ASSERT_EQ(ipv4_address::parse("256.0.0.1"), std::nullopt); + ASSERT_EQ(ipv4_address::parse("1.0.0.1"), ipv4_address(1, 0, 0, 1)); + ASSERT_EQ(ipv4_address::parse("10.0.0.1"), ipv4_address(10, 0, 0, 1)); + ASSERT_EQ(ipv4_address::parse("100.0.0.1"), ipv4_address(100, 0, 0, 1)); + + static constexpr auto static_parse = ipv4_address::parse("10.0.0.1"); + static_assert(static_parse.has_value()); + static_assert(static_parse.value() == ipv4_address(10, 0, 0, 1)); +} + +TEST(ASYNCPP_IO, IPv4ToString) { + ASSERT_EQ(ipv4_address(0, 0, 0, 0).to_string(), "0.0.0.0"); + ASSERT_EQ(ipv4_address(1, 0, 0, 0).to_string(), "1.0.0.0"); + ASSERT_EQ(ipv4_address(10, 0, 0, 0).to_string(), "10.0.0.0"); + ASSERT_EQ(ipv4_address(100, 0, 0, 0).to_string(), "100.0.0.0"); + ASSERT_EQ(ipv4_address(0, 0, 0, 0).to_string(), "0.0.0.0"); + ASSERT_EQ(ipv4_address(0, 1, 0, 0).to_string(), "0.1.0.0"); + ASSERT_EQ(ipv4_address(0, 10, 0, 0).to_string(), "0.10.0.0"); + ASSERT_EQ(ipv4_address(0, 100, 0, 0).to_string(), "0.100.0.0"); +} + +TEST(ASYNCPP_IO, IPv6Parse) { + ASSERT_EQ(ipv6_address::parse(""), std::nullopt); + ASSERT_EQ(ipv6_address::parse("123"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("foo"), std::nullopt); + ASSERT_EQ(ipv6_address::parse(":1234"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("0102:0304:0506:0708:090a:0b0c:0d0e:0f10 "), std::nullopt); + ASSERT_EQ(ipv6_address::parse(" 0102:0304:0506:0708:090a:0b0c:0d0e:0f10"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("0102:0304:0506:0708:090a:0b0c:0d0e:0f10:"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("0102:0304:0506:0708:090a:0b0c:0d0e"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("01022:0304:0506:0708:090a:0b0c:0d0e:0f10"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("0102:0304:0506:192.168.0.1:0b0c:0d0e:0f10"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("::"), ipv6_address(0, 0)); + ASSERT_EQ(ipv6_address::parse("::1"), ipv6_address::loopback()); + ASSERT_EQ(ipv6_address::parse("::01"), ipv6_address::loopback()); + ASSERT_EQ(ipv6_address::parse("::001"), ipv6_address::loopback()); + ASSERT_EQ(ipv6_address::parse("::0001"), ipv6_address::loopback()); + ASSERT_EQ(ipv6_address::parse("0102:0304:0506:0708:090a:0b0c:0d0e:0f10"), + ipv6_address(0x0102030405060708, 0x090A0B0C0D0E0F10)); + ASSERT_EQ(ipv6_address::parse("0002:0304:0506:0708:090a:0b0c:0d0e:0f10"), + ipv6_address(0x0002030405060708, 0x090A0B0C0D0E0F10)); + ASSERT_EQ(ipv6_address::parse("0000:0304:0506:0708:090a:0b0c:0d0e:0f10"), + ipv6_address(0x0000030405060708, 0x090A0B0C0D0E0F10)); + ASSERT_EQ(ipv6_address::parse("::0506:0708:090a:0b0c:0d0e:0f10"), + ipv6_address(0x0000000005060708, 0x090A0B0C0D0E0F10)); + ASSERT_EQ(ipv6_address::parse("0102:0304::0b0c:0d0e:0f10"), ipv6_address(0x0102030400000000, 0x00000B0C0D0E0F10)); + ASSERT_EQ(ipv6_address::parse("0102:0304:0506:0708:090a:0b0c::"), + ipv6_address(0x0102030405060708, 0x090A0B0C00000000)); + ASSERT_EQ(ipv6_address::parse("2001:db8:85a3:8d3:1319:8a2e:370:7348"), + ipv6_address(0x20010db885a308d3, 0x13198a2e03707348)); + + ASSERT_EQ(ipv6_address::parse("::ffff:192.168.0.1"), ipv6_address(0x0, 0xffffc0a80001)); + // https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5.2 requires a ipv4 mapped address to be in the form of ::ffff:xxxx:xxxx + ASSERT_EQ(ipv6_address::parse("0102:0304::128.69.32.17"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("0102:0304::128.69.32.17"), std::nullopt); + + // Hexadecimal chars in dotted decimal part + ASSERT_EQ(ipv6_address::parse("64:ff9b::12f.100.30.1"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("64:ff9b::123.10a.30.1"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("64:ff9b::123.100.3d.1"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("64:ff9b::12f.100.30.f4"), std::nullopt); + + // Overflow of individual parts of dotted decimal notation + ASSERT_EQ(ipv6_address::parse("::ffff:456.12.45.30"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("::ffff:45.256.45.30"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("::ffff:45.25.677.30"), std::nullopt); + ASSERT_EQ(ipv6_address::parse("::ffff:123.12.45.301"), std::nullopt); + + static constexpr auto static_parse = ipv6_address::parse("::1"); + static_assert(static_parse.has_value()); + static_assert(static_parse.value() == ipv6_address::loopback()); +} + +TEST(ASYNCPP_IO, IPv6MapIPv4) { + ASSERT_EQ(ipv6_address(ipv4_address(10, 0, 0, 1)).to_string(), "::ffff:a00:1"); + ASSERT_TRUE(ipv6_address(ipv4_address(10, 0, 0, 1)).is_ipv4_mapped()); + ASSERT_EQ(ipv6_address(ipv4_address(10, 0, 0, 1)).mapped_ipv4(), ipv4_address(10, 0, 0, 1)); + ASSERT_FALSE(ipv6_address(0x0102030405060708, 0x090A0B0C00000000).is_ipv4_mapped()); +} + +TEST(ASYNCPP_IO, IPv6ToString) { + ASSERT_EQ(ipv6_address(0, 0).to_string(), "::"); + ASSERT_EQ(ipv6_address::loopback().to_string(), "::1"); + ASSERT_EQ(ipv6_address::loopback().to_string(true), "0000:0000:0000:0000:0000:0000:0000:0001"); + + ASSERT_EQ(ipv6_address(0x0102030405060708, 0x090A0B0C0D0E0F10).to_string(), "102:304:506:708:90a:b0c:d0e:f10"); + ASSERT_EQ(ipv6_address(0x0001001001001000, 0x0).to_string(), "1:10:100:1000::"); + ASSERT_EQ(ipv6_address(0x0002030405060708, 0x090A0B0C0D0E0F10).to_string(), "2:304:506:708:90a:b0c:d0e:f10"); + ASSERT_EQ(ipv6_address(0x0000030405060708, 0x090A0B0C0D0E0F10).to_string(), "0:304:506:708:90a:b0c:d0e:f10"); + ASSERT_EQ(ipv6_address(0x0000000005060708, 0x090A0B0C0D0E0F10).to_string(), "::506:708:90a:b0c:d0e:f10"); + ASSERT_EQ(ipv6_address(0x0102030400000000, 0x00000B0C0D0E0F10).to_string(), "102:304::b0c:d0e:f10"); + ASSERT_EQ(ipv6_address(0x0102030405060708, 0x090A0B0C0D0E0000).to_string(), "102:304:506:708:90a:b0c:d0e:0"); + ASSERT_EQ(ipv6_address(0x0102030405060708, 0x090A0B0C00000000).to_string(), "102:304:506:708:90a:b0c::"); + + // Check that it contracts the first of multiple equal-length zero runs. + ASSERT_EQ(ipv6_address(0x0102030400000000, 0x090A0B0C00000000).to_string(), "102:304::90a:b0c:0:0"); +} + +TEST(ASYNCPP_IO, UDSParse) { + ASSERT_EQ(uds_address::parse(std::string_view("\0", 1)), std::nullopt); + ASSERT_EQ(uds_address::parse(std::string_view("@")), std::nullopt); + ASSERT_EQ(uds_address::parse(std::string_view("@\0", 2)), std::nullopt); + ASSERT_EQ(uds_address::parse(std::string_view(" test.sock")), std::nullopt); + ASSERT_EQ(uds_address::parse(std::string_view("test.sock ")), std::nullopt); + ASSERT_EQ(uds_address::parse(std::string(109, 's')), std::nullopt); + auto addr = uds_address::parse("uds.socket"); + ASSERT_TRUE(addr.has_value()); + ASSERT_EQ(addr->data().size(), 10); + ASSERT_TRUE(memcmp(addr->data().data(), "uds.socket", 10) == 0); + addr = uds_address::parse("./uds.socket"); + ASSERT_TRUE(addr.has_value()); + ASSERT_EQ(addr->data().size(), 12); + ASSERT_TRUE(memcmp(addr->data().data(), "./uds.socket", 12) == 0); + addr = uds_address::parse("@uds.socket"); + ASSERT_TRUE(addr.has_value()); + ASSERT_EQ(addr->data().size(), 11); + ASSERT_TRUE(memcmp(addr->data().data(), "\0uds.socket", 11) == 0); + addr = uds_address::parse(""); + ASSERT_TRUE(addr.has_value()); + ASSERT_EQ(addr->data().size(), 0); + + static constexpr auto static_parse = uds_address::parse("@uds.socket"); + static_assert(static_parse.has_value()); +} + +TEST(ASYNCPP_IO, UDSToString) { + ASSERT_EQ(uds_address("@uds").to_string(), "@uds"); + ASSERT_EQ(uds_address(std::string_view("\0uds", 4)).to_string(), "@uds"); + ASSERT_EQ(uds_address("./uds").to_string(), "./uds"); +} + +TEST(ASYNCPP_IO, UDSTypes) { + ASSERT_TRUE(uds_address("@uds").is_abstract()); + ASSERT_FALSE(uds_address("@uds").is_unnamed()); + ASSERT_FALSE(uds_address("uds").is_abstract()); + ASSERT_FALSE(uds_address("uds").is_unnamed()); + ASSERT_FALSE(uds_address("").is_abstract()); + ASSERT_TRUE(uds_address("").is_unnamed()); +} diff --git a/test/dns.cpp b/test/dns.cpp new file mode 100644 index 0000000..164cf0f --- /dev/null +++ b/test/dns.cpp @@ -0,0 +1,73 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace asyncpp::io; + +TEST(ASYNCPP_IO, DNSResolve) { + io_service service; + dns::client client(service); + + client.add_nameserver(address::parse("1.1.1.1").value()); + client.set_retries(3); + client.set_timeout(std::chrono::milliseconds(250)); + + dns::api_error error = dns::api_error::timeout; + client.query("thalhammer.it", dns::qtype::a, dns::qclass::in, + [&client, &error](dns::api_error e, const_buffer res) { + if (e == dns::api_error::ok) { + dns::visit_answer(res, [&](std::string_view name, dns::qtype rtype, dns::qclass rclass, + uint32_t ttl, asyncpp::io::const_buffer rdata) { + if (rtype == dns::qtype::a && rclass == dns::qclass::in) { + std::error_code ec; + auto rr = dns::parse_a(rdata, res, ec); + if (!ec) std::cout << rr.to_string() << std::endl; + } + return true; + }); + } else + std::cout << e << std::endl; + + error = e; + client.stop(); + }); + + service.run(io_service::run_mode::while_active); + + ASSERT_EQ(error, dns::api_error::ok); +} + +TEST(ASYNCPP_IO, DNSResolveTimeout) { + io_service service; + dns::client client(service); + + client.add_nameserver(address::parse("2.2.2.2").value()); + client.set_retries(0); + client.set_timeout(std::chrono::milliseconds(100)); + + auto now = std::chrono::steady_clock::now(); + std::chrono::milliseconds dur; + + dns::api_error error = dns::api_error::timeout; + client.query("thalhammer.it", dns::qtype::a, dns::qclass::in, [&](dns::api_error e, const_buffer res) { + error = e; + dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - now); + client.stop(); + service.stop(); + }); + + service.run(io_service::run_mode::while_active); + + ASSERT_EQ(error, dns::api_error::timeout); + ASSERT_LE(dur, std::chrono::milliseconds(500)) << "Query took too long to react to timeout"; +} diff --git a/test/endpoint.cpp b/test/endpoint.cpp new file mode 100644 index 0000000..8c9928f --- /dev/null +++ b/test/endpoint.cpp @@ -0,0 +1,26 @@ +#include + +#include + +using namespace asyncpp::io; + +TEST(ASYNCPP_IO, IPv4EndpointParse) { + ASSERT_EQ(ipv4_endpoint::parse("1.0.0.1"), ipv4_endpoint(ipv4_address(1, 0, 0, 1), 0)); + ASSERT_EQ(ipv4_endpoint::parse("1.0.0.1:"), std::nullopt); + ASSERT_EQ(ipv4_endpoint::parse("1.0.0.1:1"), ipv4_endpoint(ipv4_address(1, 0, 0, 1), 1)); +} + +TEST(ASYNCPP_IO, IPv6EndpointParse) { + ASSERT_EQ(ipv6_endpoint::parse("[::1]"), ipv6_endpoint(ipv6_address::loopback(), 0)); + ASSERT_EQ(ipv6_endpoint::parse("[::1]:"), std::nullopt); + ASSERT_EQ(ipv6_endpoint::parse("[::1]:1"), ipv6_endpoint(ipv6_address::loopback(), 1)); +} + +TEST(ASYNCPP_IO, EndpointParse) { + ASSERT_EQ(endpoint::parse("[::1]"), endpoint(ipv6_address::loopback(), 0)); + ASSERT_EQ(endpoint::parse("[::1]:"), std::nullopt); + ASSERT_EQ(endpoint::parse("[::1]:1"), endpoint(ipv6_address::loopback(), 1)); + ASSERT_EQ(endpoint::parse("1.0.0.1"), endpoint(ipv4_address(1, 0, 0, 1), 0)); + ASSERT_EQ(endpoint::parse("1.0.0.1:"), std::nullopt); + ASSERT_EQ(endpoint::parse("1.0.0.1:1"), endpoint(ipv4_address(1, 0, 0, 1), 1)); +} diff --git a/test/file.cpp b/test/file.cpp new file mode 100644 index 0000000..43fa6c1 --- /dev/null +++ b/test/file.cpp @@ -0,0 +1,210 @@ +#include +#include +#include +#include + +#include + +using namespace asyncpp; + +namespace { + std::string read_file(const std::string& name) { + std::ifstream file(name, std::ios::in | std::ios::binary); + if (!file) throw std::runtime_error("failed to open file"); + std::ostringstream ss; + ss << file.rdbuf(); + return ss.str(); + } +} // namespace + +#if ASYNCPP_IO_HANDLE_FROM_FILEBUF +TEST(ASYNCPP_IO, FileFreeGetHandle) { + std::ofstream file("test.bin", std::ios::binary | std::ios::trunc); + auto hdl = io::detail::get_file_handle_from_filebuf(file.rdbuf()); + ASSERT_GE(hdl, 3); +} + +TEST(ASYNCPP_IO, FileFreeRead) { + std::fstream file("test.bin", std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file); + ASSERT_TRUE(file.write("Hello World", 11)); + ASSERT_TRUE(file.flush()); + + io::io_service service; + std::string read; + async_launch_scope scope; + scope.invoke([&service, &file, &read]() -> task<> { + read.resize(128); + auto hdl = io::detail::get_file_handle_from_filebuf(file.rdbuf()); + auto size = co_await io::read(*service.engine(), hdl, read.data(), read.size(), 0); + read.resize(size); + service.stop(); + }); + + service.run(); + + ASSERT_TRUE(scope.all_done()); + ASSERT_EQ(read.size(), 11); + ASSERT_EQ(read, "Hello World"); +} + +TEST(ASYNCPP_IO, FileFreeWrite) { + std::fstream file("test.bin", std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file); + + io::io_service service; + size_t write_size; + async_launch_scope scope; + scope.invoke([&service, &file, &write_size]() -> task<> { + auto hdl = io::detail::get_file_handle_from_filebuf(file.rdbuf()); + write_size = co_await io::write(*service.engine(), hdl, "Hello World", 11, 0); + service.stop(); + }); + + service.run(); + + ASSERT_EQ(write_size, 11); + + ASSERT_TRUE(scope.all_done()); + std::string read(128, '\0'); + ASSERT_EQ(file.read(read.data(), read.size()).gcount(), 11); + read.resize(11); + ASSERT_EQ(read, "Hello World"); +} +#endif + +TEST(ASYNCPP_IO, FileCreate) { + io::io_service service; + + { + io::file file(service, "test.bin", std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); + } + { + io::file file(service, std::string("test.bin"), + std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); + } + { + io::file file(service, std::filesystem::path("test.bin"), + std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); + } + { + io::file file(service); + ASSERT_FALSE(file.is_open()); + ASSERT_FALSE(file); + ASSERT_TRUE(!file); + } + { + io::file file(service); + file.open("test.bin", std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); + } + { + io::file file(service); + file.open(std::string("test.bin"), std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); + } + { + io::file file(service); + file.open(std::filesystem::path("test.bin"), + std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); + } + { + io::file file(service); + ASSERT_THROW(file.open(""), std::system_error); + ASSERT_FALSE(file.is_open()); + ASSERT_FALSE(file); + ASSERT_TRUE(!file); + } +} + +TEST(ASYNCPP_IO, FileWrite) { + io::io_service service; + + io::file file(service, "test.bin", std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file); + + size_t write_size; + async_launch_scope scope; + scope.invoke([&file, &write_size, &service]() -> task<> { + write_size = co_await file.write("Hello World", 11, 0); + service.stop(); + }); + + service.run(); + ASSERT_TRUE(scope.all_done()); + + ASSERT_EQ(write_size, 11); + ASSERT_EQ(file.size(), 11); + auto content = read_file("test.bin"); + ASSERT_EQ(content.size(), 11); + ASSERT_EQ(content, "Hello World"); +} + +TEST(ASYNCPP_IO, FileRead) { + std::ofstream("test.bin", std::ios::out | std::ios::trunc | std::ios::binary) << "Hello World"; + + io::io_service service; + io::file file(service, "test.bin", std::ios_base::in | std::ios::binary); + ASSERT_TRUE(file); + + size_t read_size; + std::string read; + read.resize(file.size()); + async_launch_scope scope; + scope.invoke([&]() -> task<> { + read_size = co_await file.read(read.data(), read.size(), 0); + service.stop(); + }); + + service.run(); + ASSERT_TRUE(scope.all_done()); + + ASSERT_EQ(read_size, 11); + ASSERT_EQ(read, "Hello World"); + + scope.invoke([&]() -> task<> { + read_size = co_await file.read(read.data(), read.size(), 6); + service.stop(); + }); + service.run(); + ASSERT_TRUE(scope.all_done()); + + ASSERT_EQ(read_size, 5); + read.resize(read_size); + ASSERT_EQ(read, "World"); +} + +TEST(ASYNCPP_IO, FileOpenClose) { + io::io_service service; + + io::file file(service, "test.bin", std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); + + file.close(); + ASSERT_FALSE(file.is_open()); + ASSERT_FALSE(file); + ASSERT_TRUE(!file); + + file.open("test.bin", std::ios_base::in | std::ios_base::out | std::ios::binary | std::ios::trunc); + ASSERT_TRUE(file.is_open()); + ASSERT_TRUE(file); + ASSERT_FALSE(!file); +} diff --git a/test/network.cpp b/test/network.cpp new file mode 100644 index 0000000..83896e6 --- /dev/null +++ b/test/network.cpp @@ -0,0 +1,51 @@ +#include +#include + +#include + +using namespace asyncpp::io; + +TEST(ASYNCPP_IO, IPv4NetworkContains) { + ASSERT_FALSE(ipv4_network(ipv4_address(10, 0, 0, 0), 24).contains(ipv4_address(10, 0, 1, 1))); + ASSERT_TRUE(ipv4_network(ipv4_address(10, 0, 0, 0), 24).contains(ipv4_address(10, 0, 0, 1))); + ASSERT_TRUE(ipv4_network(ipv4_address(10, 0, 0, 0), 0).contains(ipv4_address(233, 1, 22, 5))); + ASSERT_FALSE(ipv4_network(ipv4_address(10, 0, 0, 0), 1).contains(ipv4_address(233, 1, 22, 5))); + ASSERT_TRUE(ipv4_network(ipv4_address(10, 0, 0, 1), 32).contains(ipv4_address(10, 0, 0, 1))); +} + +TEST(ASYNCPP_IO, IPv6NetworkContains) { + const ipv6_address ip(0x0102030405060708, 0x090A0B0C0D0E0F10); + const ipv6_address ip2(0x090A0B0C0D0E0F10, 0x0102030405060708); + for (size_t i = 0; i <= 128; i++) { + ASSERT_TRUE(ipv6_network(ip, i).contains(ip)); + } + ASSERT_TRUE(ipv6_network(ip, 1).contains(ip2)); + ASSERT_TRUE(ipv6_network(ip, 2).contains(ip2)); + ASSERT_TRUE(ipv6_network(ip, 3).contains(ip2)); + ASSERT_TRUE(ipv6_network(ip, 4).contains(ip2)); + for (size_t i = 5; i <= 128; i++) { + ASSERT_FALSE(ipv6_network(ip, i).contains(ip2)); + ASSERT_FALSE(ipv6_network(ip2, i).contains(ip)); + } +} + +TEST(ASYNCPP_IO, IPv4Network) { + const ipv4_address ip(10, 0, 0, 22); + ASSERT_EQ(ipv4_network(ip, 24).canonical(), ipv4_address(10, 0, 0, 0)); + ASSERT_EQ(ipv4_network(ip, 24).broadcast(), ipv4_address(10, 0, 0, 255)); +} + +TEST(ASYNCPP_IO, IPv6Network) { + const ipv6_address ip(0x0102030405060708, 0x090A0B0C0D0E0F10); + ASSERT_EQ(ipv6_network(ip, 64).canonical(), ipv6_address(0x0102030405060708, 0)); + ASSERT_EQ(ipv6_network(ip, 64).broadcast(), ipv6_address(0x0102030405060708, 0xffffffffffffffff)); +} + +TEST(ASYNCPP_IO, IPv6Test) { + auto ip = ipv6_network(ipv6_address::parse("2003::").value(), 19); + auto ip2 = ipv6_network(ipv6_address::parse("2003:8:f401::").value(), 48); + auto ip3 = ipv6_network(ipv6_address::parse("2003:8:f40e::").value(), 48); + ASSERT_TRUE(ip2 > ip); + ASSERT_FALSE(ip2 < ip); + ASSERT_TRUE(ip3 > ip2); +} diff --git a/test/so_compat.cpp b/test/so_compat.cpp new file mode 100644 index 0000000..eaca507 --- /dev/null +++ b/test/so_compat.cpp @@ -0,0 +1,5 @@ +#ifdef ASYNCPP_SO_COMPAT +#define ASYNCPP_SO_COMPAT_IMPL +#endif +#include +#include diff --git a/test/socket.cpp b/test/socket.cpp new file mode 100644 index 0000000..c805e7e --- /dev/null +++ b/test/socket.cpp @@ -0,0 +1,145 @@ +#include +#include +#include +#include +#include + +#include + +using namespace asyncpp::io; +using asyncpp::launch; +using asyncpp::task; + +asyncpp::stop_token timeout(std::chrono::nanoseconds ts) { + asyncpp::stop_source source; + asyncpp::timer::get_default().schedule([source](bool) { source.request_stop(); }, ts); + return source.get_token(); +} + +TEST(ASYNCPP_IO, IOService) { + io_service service; + service.run(io_service::run_mode::nowait); +} + +TEST(ASYNCPP_IO, IOServicePush) { + static bool did_trigger = false; + auto service = io_service::get_default(); + service->push([]() { did_trigger = true; }); + for (size_t i = 0; i < 10 && !did_trigger; i++) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + ASSERT_TRUE(did_trigger) << "io_service did not wake up within 100ms"; +} + +TEST(ASYNCPP_IO, Socket) { + std::string result; + io_service service; + launch([](io_service& service, std::string& result) -> task<> { + const auto ip = endpoint::parse("194.36.147.124:80").value(); + auto client = socket::create_tcp(service, ip.type()); + co_await client.connect(ip); + constexpr std::string_view req = "GET / HTTP/1.0\r\n\r\n"; + co_await client.send(req.data(), req.size()); + while (true) { + char buf[1024]; + auto res = co_await client.recv(buf, 1024); + result.append(buf, res); + if (res == 0 || result.size() >= 1024) break; + } + }(service, result)); + + service.run(); + + ASSERT_FALSE(result.empty()); + ASSERT_TRUE(result.starts_with("HTTP")); +} + +TEST(ASYNCPP_IO, SocketSelf) { + io_service service; + std::string received; + auto stop = timeout(std::chrono::seconds(2)); + launch([](io_service& service, std::string& received, asyncpp::stop_token st) -> task<> { + // Launch a tcp server that accepts a single connection and sends "HELLO" + auto server = socket::create_and_bind_tcp(service, endpoint(ipv4_address::any(), 0)); + launch([](io_service& service, socket& server, asyncpp::stop_token st) -> task<> { + server.listen(); + auto client = co_await server.accept(st); + co_await client.send("HELLO", 5, st); + }(service, server, st)); + // Connect to said server + auto client = socket::create_tcp(service, server.local_endpoint().type()); + co_await client.connect(server.local_endpoint(), st); + // and read until connection is closed + while (true) { + char buf[128]; + auto res = co_await client.recv(&buf, 128, st); + if (res == 0) break; + received.append(buf, res); + } + }(service, received, stop)); + + service.run(); + + ASSERT_EQ(received.size(), 5); + ASSERT_EQ(received, "HELLO"); +} + +TEST(ASYNCPP_IO, SocketUDP) { + io_service service; + launch([](io_service& service) -> task<> { + const auto ip = endpoint::parse("185.194.142.4:10070").value(); + auto client = socket::create_udp(service, ip.type()); + constexpr uint8_t buf[] = {0x00, 0x00, 0xe4, 0x00}; + co_await client.send_to(buf, sizeof(buf), ip); + uint8_t receive_buf[128]; + auto [res, source] = + co_await client.recv_from(&receive_buf, sizeof(receive_buf), timeout(std::chrono::seconds(2))); + printf("got %zu bytes from %s\n", res, source.to_string().c_str()); + }(service)); + + service.run(); +} + +TEST(ASYNCPP_IO, SocketValid) { + auto service = io_service::get_default(); + + socket sock; + ASSERT_FALSE(sock); + sock = socket::create_tcp(*service, address_type::ipv4); + ASSERT_TRUE(sock); + socket sock2 = std::move(sock); + ASSERT_FALSE(sock); + ASSERT_FALSE(sock.valid()); + ASSERT_TRUE(sock2); + ASSERT_TRUE(sock2.valid()); + auto fd = sock2.release(); + ASSERT_FALSE(sock2); + close(fd); +} + +#ifdef __linux__ +TEST(ASYNCPP_IO, SocketPair) { + io_service service; + std::string received; + asyncpp::async_launch_scope scope; + scope.invoke([&service, &received]() -> task<> { + auto stop = timeout(std::chrono::seconds(2)); + auto pair = socket::connected_pair_tcp(service, address_type::uds); + co_await pair.first.send("Hello", 5, stop); + pair.first.close_send(); + while (true) { + char buf[128]; + auto res = co_await pair.second.recv(&buf, 128, stop); + if (res == 0) break; + received.append(buf, res); + } + service.stop(); + }); + + service.run(); + + ASSERT_TRUE(scope.all_done()); + ASSERT_EQ(received.size(), 5); + ASSERT_EQ(received, "Hello"); +} +#endif diff --git a/test/tls.cpp b/test/tls.cpp new file mode 100644 index 0000000..4100b05 --- /dev/null +++ b/test/tls.cpp @@ -0,0 +1,161 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace asyncpp::io; + +TEST(ASYNCPP_IO, TLSContext) { + tls::context ctx; + ASSERT_EQ(ctx.get_method(), tls::method::tls); + ASSERT_EQ(ctx.get_mode(), tls::mode::client); +} + +TEST(ASYNCPP_IO, TLSRoundtrip) { + std::cout.sync_with_stdio(true); + // Generate cert if missing + if (!std::filesystem::exists("ssl.crt") || !std::filesystem::exists("ssl.key")) { + std::cout << "Generating temporary cert..." << std::endl; + system("openssl req -x509 -newkey rsa:2048 -keyout ssl.key -out ssl.crt -sha256 -days 2 -nodes -subj " + "\"/C=XX/ST=StateName/L=SomeCity/O=ASYNCPP/OU=ASYNCPP-TEST/CN=server1\""); + atexit([]() { + unlink("ssl.key"); + unlink("ssl.crt"); + }); + } + + tls::context ctx_client(tls::method::tls, tls::mode::client); + tls::context ctx_server(tls::method::tls, tls::mode::server); + ctx_server.use_certificate("ssl.crt"); + //ctx_server.use_certificate("sample.pem"); + ctx_server.debug(); + for (auto& e : ctx_server.get_chain_certs()) { + std::cout << e.to_pem() << std::endl; + } + ctx_server.use_privatekey("ssl.key"); + ctx_server.set_client_hello_callback([](const tls::context::client_hello& hello, int& alert) { + std::cout << "sni: '" << hello.server_name_indication() << "'" << std::endl; + return true; + }); + ctx_server.set_verify(tls::verify_mode::none); + ctx_server.set_alpn_select_callback( + [](tls::session&, std::string_view& res, const std::span& available) { + std::cout << "alpn:"; + for (auto e : available) + std::cout << " '" << e << "'"; + std::cout << std::endl; + res = available.front(); + return true; + }); + ctx_client.set_verify(tls::verify_mode::none); + ctx_client.set_alpn_protos({"http/1.1", "h2"}); + + tls::session session_client(ctx_client); + tls::session session_server(ctx_server); + + session_client.set_servername("server1"); + + launch([](tls::session& client, tls::session& server) -> asyncpp::task<> { + char buffer[1024]; + while (true) { + auto len = co_await client.cipher_read(buffer, sizeof(buffer)); + co_await server.cipher_write(buffer, len); + if (len == 0) break; + } + }(session_client, session_server)); + + launch([](tls::session& client, tls::session& server) -> asyncpp::task<> { + char buffer[1024]; + while (true) { + auto len = co_await server.cipher_read(buffer, sizeof(buffer)); + co_await client.cipher_write(buffer, len); + if (len == 0) break; + } + }(session_client, session_server)); + + bool done = false; + launch([](tls::session& server, bool& done) -> asyncpp::task<> { + while (!done) { + char buf[1024]; + auto res = co_await server.read(buf, sizeof(buf)); + if (res > 0) { + std::cout << std::string_view(buf, res) << std::endl; + done = true; + } + } + server.shutdown(); + }(session_server, done)); + + while (!done) { + const char* test = "Hello World\n"; + size_t size{}; + [[maybe_unused]] auto res = session_client.try_write(test, strlen(test), size); + } + + auto cert = session_client.get_peer_certificate(); + std::cout << cert.to_pem() << std::endl; + std::cout << "nbf: " << std::chrono::system_clock::to_time_t(cert.not_before()) << std::endl; + std::cout << "naf: " << std::chrono::system_clock::to_time_t(cert.not_after()) << std::endl; + std::cout << "subject:" << cert.subject() << std::endl; + std::cout << "issuer: " << cert.issuer() << std::endl; + + ASSERT_EQ("http/1.1", session_server.alpn_selected()); + ASSERT_EQ("http/1.1", session_client.alpn_selected()); + ASSERT_EQ("server1", session_server.get_servername()); + ASSERT_EQ("server1", session_client.get_servername()); +} + +TEST(ASYNCPP_IO, TLSClient) { + std::cout.sync_with_stdio(true); + tls::context ctx_client(tls::method::tls, tls::mode::client); + //ctx_client.set_verify(tls::verify_mode::none); + ctx_client.load_verify_locations("", "/etc/ssl/certs/"); + ctx_client.set_alpn_protos({"http/1.1"}); + tls::session ssl_client(ctx_client); + io_service service; + asyncpp::async_launch_scope scope; + auto sock = socket::create_tcp(service, address_type::ipv4); + ssl_client.set_servername("thalhammer.it"); + + scope.invoke([&ssl_client, &sock, &service, &scope]() -> asyncpp::task<> { + const auto ip = endpoint::parse("194.36.147.124:443").value(); + co_await sock.connect(ip); + + scope.invoke([&ssl_client, &sock]() -> asyncpp::task<> { + char buffer[64 * 1024]; + while (true) { + auto len = co_await ssl_client.cipher_read(buffer, sizeof(buffer)); + if (len == 0) break; + co_await sock.send(buffer, len); + } + }); + + scope.invoke([&ssl_client, &sock]() -> asyncpp::task<> { + char buffer[64 * 1024]; + while (true) { + auto len = co_await sock.recv(buffer, sizeof(buffer)); + co_await ssl_client.cipher_write(buffer, len); + if (len == 0) break; + } + }); + + co_await ssl_client.handshake(); + constexpr std::string_view req = "GET / HTTP/1.1\r\nHost: thalhammer.it\r\nConnection: close\r\n\r\n"; + co_await ssl_client.write(req.data(), req.size()); + while (true) { + char buf[64 * 1024]; + auto res = co_await ssl_client.read(buf, sizeof(buf)); + if (res == 0) break; + if (res > 0) { std::cout << std::string_view(buf, res) << std::endl; } + } + service.stop(); + }); + + while (!scope.all_done()) + service.run(io_service::run_mode::once); +}