diff --git a/.github/workflows/gcc.yml b/.github/workflows/gcc.yml index 42ba1b416a..0050386039 100644 --- a/.github/workflows/gcc.yml +++ b/.github/workflows/gcc.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - gcc: [12, 13, 14] + gcc: [13, 14] build_type: [Debug] std: [20, 23] diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b0d49c008..7843ad72b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,12 +4,15 @@ include(cmake/prelude.cmake) project( glaze - VERSION 2.9.5 + VERSION 2.9.4 LANGUAGES CXX ) include(cmake/project-is-top-level.cmake) include(cmake/variables.cmake) +include(cmake/stdexec.cmake) + +fetch_stdexec() add_library(glaze_glaze INTERFACE) add_library(glaze::glaze ALIAS glaze_glaze) @@ -19,7 +22,7 @@ if (MSVC) string(REGEX MATCH "\/cl(.exe)?$" matched_cl ${CMAKE_CXX_COMPILER}) if (matched_cl) # for a C++ standards compliant preprocessor, not needed for clang-cl - target_compile_options(glaze_glaze INTERFACE "/Zc:preprocessor" /permissive- /Zc:lambda) + target_compile_options(glaze_glaze INTERFACE "/Zc:preprocessor" /permissive- /Zc:lambda /Zc:__cplusplus) if(PROJECT_IS_TOP_LEVEL) target_compile_options(glaze_glaze INTERFACE @@ -40,6 +43,7 @@ target_compile_features(glaze_glaze INTERFACE cxx_std_20) target_include_directories( glaze_glaze ${warning_guard} INTERFACE "$" + "$" ) if(NOT CMAKE_SKIP_INSTALL_RULES) diff --git a/README.md b/README.md index d5487b22c2..549f10551a 100644 --- a/README.md +++ b/README.md @@ -172,11 +172,11 @@ auto ec = glz::write_file_json(obj, "./obj.json", std::string{}); - Only tested on 64bit systems, but should run on 32bit systems - Only supports little-endian systems -[Actions](https://github.com/stephenberry/glaze/actions) build and test with [Clang](https://clang.llvm.org) (15+), [MSVC](https://visualstudio.microsoft.com/vs/features/cplusplus/) (2022), and [GCC](https://gcc.gnu.org) (12+) on apple, windows, and linux. +[Actions](https://github.com/stephenberry/glaze/actions) build and test with [Clang](https://clang.llvm.org) (15+), [MSVC](https://visualstudio.microsoft.com/vs/features/cplusplus/) (2022), and [GCC](https://gcc.gnu.org) (13+) on apple, windows, and linux. ![clang build](https://github.com/stephenberry/glaze/actions/workflows/clang.yml/badge.svg) ![gcc build](https://github.com/stephenberry/glaze/actions/workflows/gcc.yml/badge.svg) ![msvc build](https://github.com/stephenberry/glaze/actions/workflows/msvc.yml/badge.svg) -> Glaze seeks to maintain compatibility with the latest three versions of GCC and Clang, as well as the latest version of MSVC and Apple Clang. +> Glaze seeks to maintain compatibility with the latest three versions of GCC and Clang, as well as the latest version of MSVC and Apple Clang. As an exception, GCC 12 is not supported due to lack of `std::format`. ## How To Use Glaze diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake new file mode 100644 index 0000000000..b273c3bba4 --- /dev/null +++ b/cmake/CPM.cmake @@ -0,0 +1,1225 @@ +# CPM.cmake - CMake's missing package manager +# =========================================== +# See https://github.com/cpm-cmake/CPM.cmake for usage and update instructions. 
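+#
+# A minimal usage sketch (the package names and versions below are illustrative only):
+#
+#   include(cmake/CPM.cmake)
+#   CPMAddPackage("gh:nlohmann/json@3.11.3")                              # shorthand: scheme:owner/repo@version
+#   CPMAddPackage(NAME fmt GITHUB_REPOSITORY fmtlib/fmt GIT_TAG 10.2.1)   # keyword form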
+# +# MIT License +# ----------- +#[[ + Copyright (c) 2019-2023 Lars Melchior and contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +]] + +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +# Initialize logging prefix +if(NOT CPM_INDENT) + set(CPM_INDENT + "CPM:" + CACHE INTERNAL "" + ) +endif() + +if(NOT COMMAND cpm_message) + function(cpm_message) + message(${ARGV}) + endfunction() +endif() + +set(CURRENT_CPM_VERSION 1.0.0-development-version) + +get_filename_component(CPM_CURRENT_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}" REALPATH) +if(CPM_DIRECTORY) + if(NOT CPM_DIRECTORY STREQUAL CPM_CURRENT_DIRECTORY) + if(CPM_VERSION VERSION_LESS CURRENT_CPM_VERSION) + message( + AUTHOR_WARNING + "${CPM_INDENT} \ +A dependency is using a more recent CPM version (${CURRENT_CPM_VERSION}) than the current project (${CPM_VERSION}). \ +It is recommended to upgrade CPM to the most recent version. \ +See https://github.com/cpm-cmake/CPM.cmake for more information." + ) + endif() + if(${CMAKE_VERSION} VERSION_LESS "3.17.0") + include(FetchContent) + endif() + return() + endif() + + get_property( + CPM_INITIALIZED GLOBAL "" + PROPERTY CPM_INITIALIZED + SET + ) + if(CPM_INITIALIZED) + return() + endif() +endif() + +if(CURRENT_CPM_VERSION MATCHES "development-version") + message( + WARNING "${CPM_INDENT} Your project is using an unstable development version of CPM.cmake. \ +Please update to a recent release if possible. \ +See https://github.com/cpm-cmake/CPM.cmake for details." + ) +endif() + +set_property(GLOBAL PROPERTY CPM_INITIALIZED true) + +macro(cpm_set_policies) + # the policy allows us to change options without caching + cmake_policy(SET CMP0077 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + # the policy allows us to change set(CACHE) without caching + if(POLICY CMP0126) + cmake_policy(SET CMP0126 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0126 NEW) + endif() + + # The policy uses the download time for timestamp, instead of the timestamp in the archive. 
This + # allows for proper rebuilds when a projects url changes + if(POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0135 NEW) + endif() + + # treat relative git repository paths as being relative to the parent project's remote + if(POLICY CMP0150) + cmake_policy(SET CMP0150 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0150 NEW) + endif() +endmacro() +cpm_set_policies() + +option(CPM_USE_LOCAL_PACKAGES "Always try to use `find_package` to get dependencies" + $ENV{CPM_USE_LOCAL_PACKAGES} +) +option(CPM_LOCAL_PACKAGES_ONLY "Only use `find_package` to get dependencies" + $ENV{CPM_LOCAL_PACKAGES_ONLY} +) +option(CPM_DOWNLOAD_ALL "Always download dependencies from source" $ENV{CPM_DOWNLOAD_ALL}) +option(CPM_DONT_UPDATE_MODULE_PATH "Don't update the module path to allow using find_package" + $ENV{CPM_DONT_UPDATE_MODULE_PATH} +) +option(CPM_DONT_CREATE_PACKAGE_LOCK "Don't create a package lock file in the binary path" + $ENV{CPM_DONT_CREATE_PACKAGE_LOCK} +) +option(CPM_INCLUDE_ALL_IN_PACKAGE_LOCK + "Add all packages added through CPM.cmake to the package lock" + $ENV{CPM_INCLUDE_ALL_IN_PACKAGE_LOCK} +) +option(CPM_USE_NAMED_CACHE_DIRECTORIES + "Use additional directory of package name in cache on the most nested level." + $ENV{CPM_USE_NAMED_CACHE_DIRECTORIES} +) + +set(CPM_VERSION + ${CURRENT_CPM_VERSION} + CACHE INTERNAL "" +) +set(CPM_DIRECTORY + ${CPM_CURRENT_DIRECTORY} + CACHE INTERNAL "" +) +set(CPM_FILE + ${CMAKE_CURRENT_LIST_FILE} + CACHE INTERNAL "" +) +set(CPM_PACKAGES + "" + CACHE INTERNAL "" +) +set(CPM_DRY_RUN + OFF + CACHE INTERNAL "Don't download or configure dependencies (for testing)" +) + +if(DEFINED ENV{CPM_SOURCE_CACHE}) + set(CPM_SOURCE_CACHE_DEFAULT $ENV{CPM_SOURCE_CACHE}) +else() + set(CPM_SOURCE_CACHE_DEFAULT OFF) +endif() + +set(CPM_SOURCE_CACHE + ${CPM_SOURCE_CACHE_DEFAULT} + CACHE PATH "Directory to download CPM dependencies" +) + +if(NOT CPM_DONT_UPDATE_MODULE_PATH) + set(CPM_MODULE_PATH + "${CMAKE_BINARY_DIR}/CPM_modules" + CACHE INTERNAL "" + ) + # remove old modules + file(REMOVE_RECURSE ${CPM_MODULE_PATH}) + file(MAKE_DIRECTORY ${CPM_MODULE_PATH}) + # locally added CPM modules should override global packages + set(CMAKE_MODULE_PATH "${CPM_MODULE_PATH};${CMAKE_MODULE_PATH}") +endif() + +if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + set(CPM_PACKAGE_LOCK_FILE + "${CMAKE_BINARY_DIR}/cpm-package-lock.cmake" + CACHE INTERNAL "" + ) + file(WRITE ${CPM_PACKAGE_LOCK_FILE} + "# CPM Package Lock\n# This file should be committed to version control\n\n" + ) +endif() + +include(FetchContent) + +# Try to infer package name from git repository uri (path or url) +function(cpm_package_name_from_git_uri URI RESULT) + if("${URI}" MATCHES "([^/:]+)/?.git/?$") + set(${RESULT} + ${CMAKE_MATCH_1} + PARENT_SCOPE + ) + else() + unset(${RESULT} PARENT_SCOPE) + endif() +endfunction() + +# Try to infer package name and version from a url +function(cpm_package_name_and_ver_from_url url outName outVer) + if(url MATCHES "[/\\?]([a-zA-Z0-9_\\.-]+)\\.(tar|tar\\.gz|tar\\.bz2|zip|ZIP)(\\?|/|$)") + # We matched an archive + set(filename "${CMAKE_MATCH_1}") + + if(filename MATCHES "([a-zA-Z0-9_\\.-]+)[_-]v?(([0-9]+\\.)*[0-9]+[a-zA-Z0-9]*)") + # We matched - (ie foo-1.2.3) + set(${outName} + "${CMAKE_MATCH_1}" + PARENT_SCOPE + ) + set(${outVer} + "${CMAKE_MATCH_2}" + PARENT_SCOPE + ) + elseif(filename MATCHES "(([0-9]+\\.)+[0-9]+[a-zA-Z0-9]*)") + # We couldn't find a name, but we found a version + # + # In many cases (which we don't handle here) the url would look something like + # 
`irrelevant/ACTUAL_PACKAGE_NAME/irrelevant/1.2.3.zip`. In such a case we can't possibly + # distinguish the package name from the irrelevant bits. Moreover if we try to match the + # package name from the filename, we'd get bogus at best. + unset(${outName} PARENT_SCOPE) + set(${outVer} + "${CMAKE_MATCH_1}" + PARENT_SCOPE + ) + else() + # Boldly assume that the file name is the package name. + # + # Yes, something like `irrelevant/ACTUAL_NAME/irrelevant/download.zip` will ruin our day, but + # such cases should be quite rare. No popular service does this... we think. + set(${outName} + "${filename}" + PARENT_SCOPE + ) + unset(${outVer} PARENT_SCOPE) + endif() + else() + # No ideas yet what to do with non-archives + unset(${outName} PARENT_SCOPE) + unset(${outVer} PARENT_SCOPE) + endif() +endfunction() + +function(cpm_find_package NAME VERSION) + string(REPLACE " " ";" EXTRA_ARGS "${ARGN}") + find_package(${NAME} ${VERSION} ${EXTRA_ARGS} QUIET) + if(${CPM_ARGS_NAME}_FOUND) + if(DEFINED ${CPM_ARGS_NAME}_VERSION) + set(VERSION ${${CPM_ARGS_NAME}_VERSION}) + endif() + cpm_message(STATUS "${CPM_INDENT} Using local package ${CPM_ARGS_NAME}@${VERSION}") + CPMRegisterPackage(${CPM_ARGS_NAME} "${VERSION}") + set(CPM_PACKAGE_FOUND + YES + PARENT_SCOPE + ) + else() + set(CPM_PACKAGE_FOUND + NO + PARENT_SCOPE + ) + endif() +endfunction() + +# Create a custom FindXXX.cmake module for a CPM package This prevents `find_package(NAME)` from +# finding the system library +function(cpm_create_module_file Name) + if(NOT CPM_DONT_UPDATE_MODULE_PATH) + # erase any previous modules + file(WRITE ${CPM_MODULE_PATH}/Find${Name}.cmake + "include(\"${CPM_FILE}\")\n${ARGN}\nset(${Name}_FOUND TRUE)" + ) + endif() +endfunction() + +# Find a package locally or fallback to CPMAddPackage +function(CPMFindPackage) + set(oneValueArgs NAME VERSION GIT_TAG FIND_PACKAGE_ARGUMENTS) + + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "" ${ARGN}) + + if(NOT DEFINED CPM_ARGS_VERSION) + if(DEFINED CPM_ARGS_GIT_TAG) + cpm_get_version_from_git_tag("${CPM_ARGS_GIT_TAG}" CPM_ARGS_VERSION) + endif() + endif() + + set(downloadPackage ${CPM_DOWNLOAD_ALL}) + if(DEFINED CPM_DOWNLOAD_${CPM_ARGS_NAME}) + set(downloadPackage ${CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + elseif(DEFINED ENV{CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + set(downloadPackage $ENV{CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + endif() + if(downloadPackage) + CPMAddPackage(${ARGN}) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + cpm_find_package(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}" ${CPM_ARGS_FIND_PACKAGE_ARGUMENTS}) + + if(NOT CPM_PACKAGE_FOUND) + CPMAddPackage(${ARGN}) + cpm_export_variables(${CPM_ARGS_NAME}) + endif() + +endfunction() + +# checks if a package has been added before +function(cpm_check_if_package_already_added CPM_ARGS_NAME CPM_ARGS_VERSION) + if("${CPM_ARGS_NAME}" IN_LIST CPM_PACKAGES) + CPMGetPackageVersion(${CPM_ARGS_NAME} CPM_PACKAGE_VERSION) + if("${CPM_PACKAGE_VERSION}" VERSION_LESS "${CPM_ARGS_VERSION}") + message( + WARNING + "${CPM_INDENT} Requires a newer version of ${CPM_ARGS_NAME} (${CPM_ARGS_VERSION}) than currently included (${CPM_PACKAGE_VERSION})." 
+ ) + endif() + cpm_get_fetch_properties(${CPM_ARGS_NAME}) + set(${CPM_ARGS_NAME}_ADDED NO) + set(CPM_PACKAGE_ALREADY_ADDED + YES + PARENT_SCOPE + ) + cpm_export_variables(${CPM_ARGS_NAME}) + else() + set(CPM_PACKAGE_ALREADY_ADDED + NO + PARENT_SCOPE + ) + endif() +endfunction() + +# Parse the argument of CPMAddPackage in case a single one was provided and convert it to a list of +# arguments which can then be parsed idiomatically. For example gh:foo/bar@1.2.3 will be converted +# to: GITHUB_REPOSITORY;foo/bar;VERSION;1.2.3 +function(cpm_parse_add_package_single_arg arg outArgs) + # Look for a scheme + if("${arg}" MATCHES "^([a-zA-Z]+):(.+)$") + string(TOLOWER "${CMAKE_MATCH_1}" scheme) + set(uri "${CMAKE_MATCH_2}") + + # Check for CPM-specific schemes + if(scheme STREQUAL "gh") + set(out "GITHUB_REPOSITORY;${uri}") + set(packageType "git") + elseif(scheme STREQUAL "gl") + set(out "GITLAB_REPOSITORY;${uri}") + set(packageType "git") + elseif(scheme STREQUAL "bb") + set(out "BITBUCKET_REPOSITORY;${uri}") + set(packageType "git") + # A CPM-specific scheme was not found. Looks like this is a generic URL so try to determine + # type + elseif(arg MATCHES ".git/?(@|#|$)") + set(out "GIT_REPOSITORY;${arg}") + set(packageType "git") + else() + # Fall back to a URL + set(out "URL;${arg}") + set(packageType "archive") + + # We could also check for SVN since FetchContent supports it, but SVN is so rare these days. + # We just won't bother with the additional complexity it will induce in this function. SVN is + # done by multi-arg + endif() + else() + if(arg MATCHES ".git/?(@|#|$)") + set(out "GIT_REPOSITORY;${arg}") + set(packageType "git") + else() + # Give up + message(FATAL_ERROR "${CPM_INDENT} Can't determine package type of '${arg}'") + endif() + endif() + + # For all packages we interpret @... as version. Only replace the last occurrence. Thus URIs + # containing '@' can be used + string(REGEX REPLACE "@([^@]+)$" ";VERSION;\\1" out "${out}") + + # Parse the rest according to package type + if(packageType STREQUAL "git") + # For git repos we interpret #... as a tag or branch or commit hash + string(REGEX REPLACE "#([^#]+)$" ";GIT_TAG;\\1" out "${out}") + elseif(packageType STREQUAL "archive") + # For archives we interpret #... as a URL hash. + string(REGEX REPLACE "#([^#]+)$" ";URL_HASH;\\1" out "${out}") + # We don't try to parse the version if it's not provided explicitly. cpm_get_version_from_url + # should do this at a later point + else() + # We should never get here. This is an assertion and hitting it means there's a problem with the + # code above. A packageType was set, but not handled by this if-else. 
+ message(FATAL_ERROR "${CPM_INDENT} Unsupported package type '${packageType}' of '${arg}'") + endif() + + set(${outArgs} + ${out} + PARENT_SCOPE + ) +endfunction() + +# Check that the working directory for a git repo is clean +function(cpm_check_git_working_dir_is_clean repoPath gitTag isClean) + + find_package(Git REQUIRED) + + if(NOT GIT_EXECUTABLE) + # No git executable, assume directory is clean + set(${isClean} + TRUE + PARENT_SCOPE + ) + return() + endif() + + # check for uncommitted changes + execute_process( + COMMAND ${GIT_EXECUTABLE} status --porcelain + RESULT_VARIABLE resultGitStatus + OUTPUT_VARIABLE repoStatus + OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET + WORKING_DIRECTORY ${repoPath} + ) + if(resultGitStatus) + # not supposed to happen, assume clean anyway + message(WARNING "${CPM_INDENT} Calling git status on folder ${repoPath} failed") + set(${isClean} + TRUE + PARENT_SCOPE + ) + return() + endif() + + if(NOT "${repoStatus}" STREQUAL "") + set(${isClean} + FALSE + PARENT_SCOPE + ) + return() + endif() + + # check for committed changes + execute_process( + COMMAND ${GIT_EXECUTABLE} diff -s --exit-code ${gitTag} + RESULT_VARIABLE resultGitDiff + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_QUIET + WORKING_DIRECTORY ${repoPath} + ) + + if(${resultGitDiff} EQUAL 0) + set(${isClean} + TRUE + PARENT_SCOPE + ) + else() + set(${isClean} + FALSE + PARENT_SCOPE + ) + endif() + +endfunction() + +# Add PATCH_COMMAND to CPM_ARGS_UNPARSED_ARGUMENTS. This method consumes a list of files in ARGN +# then generates a `PATCH_COMMAND` appropriate for `ExternalProject_Add()`. This command is appended +# to the parent scope's `CPM_ARGS_UNPARSED_ARGUMENTS`. +function(cpm_add_patches) + # Return if no patch files are supplied. + if(NOT ARGN) + return() + endif() + + # Find the patch program. + find_program(PATCH_EXECUTABLE patch) + if(WIN32 AND NOT PATCH_EXECUTABLE) + # The Windows git executable is distributed with patch.exe. Find the path to the executable, if + # it exists, then search `../../usr/bin` for patch.exe. + find_package(Git QUIET) + if(GIT_EXECUTABLE) + get_filename_component(extra_search_path ${GIT_EXECUTABLE} DIRECTORY) + get_filename_component(extra_search_path ${extra_search_path} DIRECTORY) + get_filename_component(extra_search_path ${extra_search_path} DIRECTORY) + find_program(PATCH_EXECUTABLE patch HINTS "${extra_search_path}/usr/bin") + endif() + endif() + if(NOT PATCH_EXECUTABLE) + message(FATAL_ERROR "Couldn't find `patch` executable to use with PATCHES keyword.") + endif() + + # Create a temporary + set(temp_list ${CPM_ARGS_UNPARSED_ARGUMENTS}) + + # Ensure each file exists (or error out) and add it to the list. + set(first_item True) + foreach(PATCH_FILE ${ARGN}) + # Make sure the patch file exists, if we can't find it, try again in the current directory. + if(NOT EXISTS "${PATCH_FILE}") + if(NOT EXISTS "${CMAKE_CURRENT_LIST_DIR}/${PATCH_FILE}") + message(FATAL_ERROR "Couldn't find patch file: '${PATCH_FILE}'") + endif() + set(PATCH_FILE "${CMAKE_CURRENT_LIST_DIR}/${PATCH_FILE}") + endif() + + # Convert to absolute path for use with patch file command. + get_filename_component(PATCH_FILE "${PATCH_FILE}" ABSOLUTE) + + # The first patch entry must be preceded by "PATCH_COMMAND" while the following items are + # preceded by "&&". 
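+    # For two patch files the resulting fragment therefore looks like:
+    #   PATCH_COMMAND ${PATCH_EXECUTABLE} -p1 < first.patch && ${PATCH_EXECUTABLE} -p1 < second.patch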
+ if(first_item) + set(first_item False) + list(APPEND temp_list "PATCH_COMMAND") + else() + list(APPEND temp_list "&&") + endif() + # Add the patch command to the list + list(APPEND temp_list "${PATCH_EXECUTABLE}" "-p1" "<" "${PATCH_FILE}") + endforeach() + + # Move temp out into parent scope. + set(CPM_ARGS_UNPARSED_ARGUMENTS + ${temp_list} + PARENT_SCOPE + ) + +endfunction() + +# method to overwrite internal FetchContent properties, to allow using CPM.cmake to overload +# FetchContent calls. As these are internal cmake properties, this method should be used carefully +# and may need modification in future CMake versions. Source: +# https://github.com/Kitware/CMake/blob/dc3d0b5a0a7d26d43d6cfeb511e224533b5d188f/Modules/FetchContent.cmake#L1152 +function(cpm_override_fetchcontent contentName) + cmake_parse_arguments(PARSE_ARGV 1 arg "" "SOURCE_DIR;BINARY_DIR" "") + if(NOT "${arg_UNPARSED_ARGUMENTS}" STREQUAL "") + message(FATAL_ERROR "${CPM_INDENT} Unsupported arguments: ${arg_UNPARSED_ARGUMENTS}") + endif() + + string(TOLOWER ${contentName} contentNameLower) + set(prefix "_FetchContent_${contentNameLower}") + + set(propertyName "${prefix}_sourceDir") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} "${arg_SOURCE_DIR}") + + set(propertyName "${prefix}_binaryDir") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} "${arg_BINARY_DIR}") + + set(propertyName "${prefix}_populated") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} TRUE) +endfunction() + +# Download and add a package from source +function(CPMAddPackage) + cpm_set_policies() + + list(LENGTH ARGN argnLength) + if(argnLength EQUAL 1) + cpm_parse_add_package_single_arg("${ARGN}" ARGN) + + # The shorthand syntax implies EXCLUDE_FROM_ALL and SYSTEM + set(ARGN "${ARGN};EXCLUDE_FROM_ALL;YES;SYSTEM;YES;") + endif() + + set(oneValueArgs + NAME + FORCE + VERSION + GIT_TAG + DOWNLOAD_ONLY + GITHUB_REPOSITORY + GITLAB_REPOSITORY + BITBUCKET_REPOSITORY + GIT_REPOSITORY + SOURCE_DIR + FIND_PACKAGE_ARGUMENTS + NO_CACHE + SYSTEM + GIT_SHALLOW + EXCLUDE_FROM_ALL + SOURCE_SUBDIR + CUSTOM_CACHE_KEY + ) + + set(multiValueArgs URL OPTIONS DOWNLOAD_COMMAND PATCHES) + + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "${multiValueArgs}" "${ARGN}") + + # Set default values for arguments + + if(NOT DEFINED CPM_ARGS_VERSION) + if(DEFINED CPM_ARGS_GIT_TAG) + cpm_get_version_from_git_tag("${CPM_ARGS_GIT_TAG}" CPM_ARGS_VERSION) + endif() + endif() + + if(CPM_ARGS_DOWNLOAD_ONLY) + set(DOWNLOAD_ONLY ${CPM_ARGS_DOWNLOAD_ONLY}) + else() + set(DOWNLOAD_ONLY NO) + endif() + + if(DEFINED CPM_ARGS_GITHUB_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://github.com/${CPM_ARGS_GITHUB_REPOSITORY}.git") + elseif(DEFINED CPM_ARGS_GITLAB_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://gitlab.com/${CPM_ARGS_GITLAB_REPOSITORY}.git") + elseif(DEFINED CPM_ARGS_BITBUCKET_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://bitbucket.org/${CPM_ARGS_BITBUCKET_REPOSITORY}.git") + 
endif() + + if(DEFINED CPM_ARGS_GIT_REPOSITORY) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_REPOSITORY ${CPM_ARGS_GIT_REPOSITORY}) + if(NOT DEFINED CPM_ARGS_GIT_TAG) + set(CPM_ARGS_GIT_TAG v${CPM_ARGS_VERSION}) + endif() + + # If a name wasn't provided, try to infer it from the git repo + if(NOT DEFINED CPM_ARGS_NAME) + cpm_package_name_from_git_uri(${CPM_ARGS_GIT_REPOSITORY} CPM_ARGS_NAME) + endif() + endif() + + set(CPM_SKIP_FETCH FALSE) + + if(DEFINED CPM_ARGS_GIT_TAG) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_TAG ${CPM_ARGS_GIT_TAG}) + # If GIT_SHALLOW is explicitly specified, honor the value. + if(DEFINED CPM_ARGS_GIT_SHALLOW) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_SHALLOW ${CPM_ARGS_GIT_SHALLOW}) + endif() + endif() + + if(DEFINED CPM_ARGS_URL) + # If a name or version aren't provided, try to infer them from the URL + list(GET CPM_ARGS_URL 0 firstUrl) + cpm_package_name_and_ver_from_url(${firstUrl} nameFromUrl verFromUrl) + # If we fail to obtain name and version from the first URL, we could try other URLs if any. + # However multiple URLs are expected to be quite rare, so for now we won't bother. + + # If the caller provided their own name and version, they trump the inferred ones. + if(NOT DEFINED CPM_ARGS_NAME) + set(CPM_ARGS_NAME ${nameFromUrl}) + endif() + if(NOT DEFINED CPM_ARGS_VERSION) + set(CPM_ARGS_VERSION ${verFromUrl}) + endif() + + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS URL "${CPM_ARGS_URL}") + endif() + + # Check for required arguments + + if(NOT DEFINED CPM_ARGS_NAME) + message( + FATAL_ERROR + "${CPM_INDENT} 'NAME' was not provided and couldn't be automatically inferred for package added with arguments: '${ARGN}'" + ) + endif() + + # Check if package has been added before + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + if(CPM_PACKAGE_ALREADY_ADDED) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + # Check for manual overrides + if(NOT CPM_ARGS_FORCE AND NOT "${CPM_${CPM_ARGS_NAME}_SOURCE}" STREQUAL "") + set(PACKAGE_SOURCE ${CPM_${CPM_ARGS_NAME}_SOURCE}) + set(CPM_${CPM_ARGS_NAME}_SOURCE "") + CPMAddPackage( + NAME "${CPM_ARGS_NAME}" + SOURCE_DIR "${PACKAGE_SOURCE}" + EXCLUDE_FROM_ALL "${CPM_ARGS_EXCLUDE_FROM_ALL}" + SYSTEM "${CPM_ARGS_SYSTEM}" + PATCHES "${CPM_ARGS_PATCHES}" + OPTIONS "${CPM_ARGS_OPTIONS}" + SOURCE_SUBDIR "${CPM_ARGS_SOURCE_SUBDIR}" + DOWNLOAD_ONLY "${DOWNLOAD_ONLY}" + FORCE True + ) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + # Check for available declaration + if(NOT CPM_ARGS_FORCE AND NOT "${CPM_DECLARATION_${CPM_ARGS_NAME}}" STREQUAL "") + set(declaration ${CPM_DECLARATION_${CPM_ARGS_NAME}}) + set(CPM_DECLARATION_${CPM_ARGS_NAME} "") + CPMAddPackage(${declaration}) + cpm_export_variables(${CPM_ARGS_NAME}) + # checking again to ensure version and option compatibility + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + return() + endif() + + if(NOT CPM_ARGS_FORCE) + if(CPM_USE_LOCAL_PACKAGES OR CPM_LOCAL_PACKAGES_ONLY) + cpm_find_package(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}" ${CPM_ARGS_FIND_PACKAGE_ARGUMENTS}) + + if(CPM_PACKAGE_FOUND) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + if(CPM_LOCAL_PACKAGES_ONLY) + message( + SEND_ERROR + "${CPM_INDENT} ${CPM_ARGS_NAME} not found via find_package(${CPM_ARGS_NAME} ${CPM_ARGS_VERSION})" + ) + endif() + endif() + endif() + + CPMRegisterPackage("${CPM_ARGS_NAME}" "${CPM_ARGS_VERSION}") + + if(DEFINED CPM_ARGS_GIT_TAG) + set(PACKAGE_INFO "${CPM_ARGS_GIT_TAG}") + elseif(DEFINED 
CPM_ARGS_SOURCE_DIR) + set(PACKAGE_INFO "${CPM_ARGS_SOURCE_DIR}") + else() + set(PACKAGE_INFO "${CPM_ARGS_VERSION}") + endif() + + if(DEFINED FETCHCONTENT_BASE_DIR) + # respect user's FETCHCONTENT_BASE_DIR if set + set(CPM_FETCHCONTENT_BASE_DIR ${FETCHCONTENT_BASE_DIR}) + else() + set(CPM_FETCHCONTENT_BASE_DIR ${CMAKE_BINARY_DIR}/_deps) + endif() + + cpm_add_patches(${CPM_ARGS_PATCHES}) + + if(DEFINED CPM_ARGS_DOWNLOAD_COMMAND) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS DOWNLOAD_COMMAND ${CPM_ARGS_DOWNLOAD_COMMAND}) + elseif(DEFINED CPM_ARGS_SOURCE_DIR) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS SOURCE_DIR ${CPM_ARGS_SOURCE_DIR}) + if(NOT IS_ABSOLUTE ${CPM_ARGS_SOURCE_DIR}) + # Expand `CPM_ARGS_SOURCE_DIR` relative path. This is important because EXISTS doesn't work + # for relative paths. + get_filename_component( + source_directory ${CPM_ARGS_SOURCE_DIR} REALPATH BASE_DIR ${CMAKE_CURRENT_BINARY_DIR} + ) + else() + set(source_directory ${CPM_ARGS_SOURCE_DIR}) + endif() + if(NOT EXISTS ${source_directory}) + string(TOLOWER ${CPM_ARGS_NAME} lower_case_name) + # remove timestamps so CMake will re-download the dependency + file(REMOVE_RECURSE "${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-subbuild") + endif() + elseif(CPM_SOURCE_CACHE AND NOT CPM_ARGS_NO_CACHE) + string(TOLOWER ${CPM_ARGS_NAME} lower_case_name) + set(origin_parameters ${CPM_ARGS_UNPARSED_ARGUMENTS}) + list(SORT origin_parameters) + if(CPM_ARGS_CUSTOM_CACHE_KEY) + # Application set a custom unique directory name + set(download_directory ${CPM_SOURCE_CACHE}/${lower_case_name}/${CPM_ARGS_CUSTOM_CACHE_KEY}) + elseif(CPM_USE_NAMED_CACHE_DIRECTORIES) + string(SHA1 origin_hash "${origin_parameters};NEW_CACHE_STRUCTURE_TAG") + set(download_directory ${CPM_SOURCE_CACHE}/${lower_case_name}/${origin_hash}/${CPM_ARGS_NAME}) + else() + string(SHA1 origin_hash "${origin_parameters}") + set(download_directory ${CPM_SOURCE_CACHE}/${lower_case_name}/${origin_hash}) + endif() + # Expand `download_directory` relative path. This is important because EXISTS doesn't work for + # relative paths. + get_filename_component(download_directory ${download_directory} ABSOLUTE) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS SOURCE_DIR ${download_directory}) + + if(CPM_SOURCE_CACHE) + file(LOCK ${download_directory}/../cmake.lock) + endif() + + if(EXISTS ${download_directory}) + if(CPM_SOURCE_CACHE) + file(LOCK ${download_directory}/../cmake.lock RELEASE) + endif() + + cpm_store_fetch_properties( + ${CPM_ARGS_NAME} "${download_directory}" + "${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-build" + ) + cpm_get_fetch_properties("${CPM_ARGS_NAME}") + + if(DEFINED CPM_ARGS_GIT_TAG AND NOT (PATCH_COMMAND IN_LIST CPM_ARGS_UNPARSED_ARGUMENTS)) + # warn if cache has been changed since checkout + cpm_check_git_working_dir_is_clean(${download_directory} ${CPM_ARGS_GIT_TAG} IS_CLEAN) + if(NOT ${IS_CLEAN}) + message( + WARNING "${CPM_INDENT} Cache for ${CPM_ARGS_NAME} (${download_directory}) is dirty" + ) + endif() + endif() + + cpm_add_subdirectory( + "${CPM_ARGS_NAME}" + "${DOWNLOAD_ONLY}" + "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" + "${${CPM_ARGS_NAME}_BINARY_DIR}" + "${CPM_ARGS_EXCLUDE_FROM_ALL}" + "${CPM_ARGS_SYSTEM}" + "${CPM_ARGS_OPTIONS}" + ) + set(PACKAGE_INFO "${PACKAGE_INFO} at ${download_directory}") + + # As the source dir is already cached/populated, we override the call to FetchContent. 
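+      # (cpm_override_fetchcontent below marks the content as already populated, so any later
+      # FetchContent call for this package reuses the cached source instead of downloading.)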
+ set(CPM_SKIP_FETCH TRUE) + cpm_override_fetchcontent( + "${lower_case_name}" SOURCE_DIR "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" + BINARY_DIR "${${CPM_ARGS_NAME}_BINARY_DIR}" + ) + + else() + # Enable shallow clone when GIT_TAG is not a commit hash. Our guess may not be accurate, but + # it should guarantee no commit hash get mis-detected. + if(NOT DEFINED CPM_ARGS_GIT_SHALLOW) + cpm_is_git_tag_commit_hash("${CPM_ARGS_GIT_TAG}" IS_HASH) + if(NOT ${IS_HASH}) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_SHALLOW TRUE) + endif() + endif() + + # remove timestamps so CMake will re-download the dependency + file(REMOVE_RECURSE ${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-subbuild) + set(PACKAGE_INFO "${PACKAGE_INFO} to ${download_directory}") + endif() + endif() + + cpm_create_module_file(${CPM_ARGS_NAME} "CPMAddPackage(\"${ARGN}\")") + + if(CPM_PACKAGE_LOCK_ENABLED) + if((CPM_ARGS_VERSION AND NOT CPM_ARGS_SOURCE_DIR) OR CPM_INCLUDE_ALL_IN_PACKAGE_LOCK) + cpm_add_to_package_lock(${CPM_ARGS_NAME} "${ARGN}") + elseif(CPM_ARGS_SOURCE_DIR) + cpm_add_comment_to_package_lock(${CPM_ARGS_NAME} "local directory") + else() + cpm_add_comment_to_package_lock(${CPM_ARGS_NAME} "${ARGN}") + endif() + endif() + + cpm_message( + STATUS "${CPM_INDENT} Adding package ${CPM_ARGS_NAME}@${CPM_ARGS_VERSION} (${PACKAGE_INFO})" + ) + + if(NOT CPM_SKIP_FETCH) + cpm_declare_fetch( + "${CPM_ARGS_NAME}" "${CPM_ARGS_VERSION}" "${PACKAGE_INFO}" "${CPM_ARGS_UNPARSED_ARGUMENTS}" + ) + cpm_fetch_package("${CPM_ARGS_NAME}" populated) + if(CPM_SOURCE_CACHE AND download_directory) + file(LOCK ${download_directory}/../cmake.lock RELEASE) + endif() + if(${populated}) + cpm_add_subdirectory( + "${CPM_ARGS_NAME}" + "${DOWNLOAD_ONLY}" + "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" + "${${CPM_ARGS_NAME}_BINARY_DIR}" + "${CPM_ARGS_EXCLUDE_FROM_ALL}" + "${CPM_ARGS_SYSTEM}" + "${CPM_ARGS_OPTIONS}" + ) + endif() + cpm_get_fetch_properties("${CPM_ARGS_NAME}") + endif() + + set(${CPM_ARGS_NAME}_ADDED YES) + cpm_export_variables("${CPM_ARGS_NAME}") +endfunction() + +# Fetch a previously declared package +macro(CPMGetPackage Name) + if(DEFINED "CPM_DECLARATION_${Name}") + CPMAddPackage(NAME ${Name}) + else() + message(SEND_ERROR "${CPM_INDENT} Cannot retrieve package ${Name}: no declaration available") + endif() +endmacro() + +# export variables available to the caller to the parent scope expects ${CPM_ARGS_NAME} to be set +macro(cpm_export_variables name) + set(${name}_SOURCE_DIR + "${${name}_SOURCE_DIR}" + PARENT_SCOPE + ) + set(${name}_BINARY_DIR + "${${name}_BINARY_DIR}" + PARENT_SCOPE + ) + set(${name}_ADDED + "${${name}_ADDED}" + PARENT_SCOPE + ) + set(CPM_LAST_PACKAGE_NAME + "${name}" + PARENT_SCOPE + ) +endmacro() + +# declares a package, so that any call to CPMAddPackage for the package name will use these +# arguments instead. Previous declarations will not be overridden. 
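+#
+# Possible usage sketch (package name and tag are illustrative):
+#   CPMDeclarePackage(fmt "gh:fmtlib/fmt#10.2.1")
+#   ...
+#   CPMGetPackage(fmt)   # adds fmt using the declaration above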
+macro(CPMDeclarePackage Name) + if(NOT DEFINED "CPM_DECLARATION_${Name}") + set("CPM_DECLARATION_${Name}" "${ARGN}") + endif() +endmacro() + +function(cpm_add_to_package_lock Name) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + cpm_prettify_package_arguments(PRETTY_ARGN false ${ARGN}) + file(APPEND ${CPM_PACKAGE_LOCK_FILE} "# ${Name}\nCPMDeclarePackage(${Name}\n${PRETTY_ARGN})\n") + endif() +endfunction() + +function(cpm_add_comment_to_package_lock Name) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + cpm_prettify_package_arguments(PRETTY_ARGN true ${ARGN}) + file(APPEND ${CPM_PACKAGE_LOCK_FILE} + "# ${Name} (unversioned)\n# CPMDeclarePackage(${Name}\n${PRETTY_ARGN}#)\n" + ) + endif() +endfunction() + +# includes the package lock file if it exists and creates a target `cpm-update-package-lock` to +# update it +macro(CPMUsePackageLock file) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + get_filename_component(CPM_ABSOLUTE_PACKAGE_LOCK_PATH ${file} ABSOLUTE) + if(EXISTS ${CPM_ABSOLUTE_PACKAGE_LOCK_PATH}) + include(${CPM_ABSOLUTE_PACKAGE_LOCK_PATH}) + endif() + if(NOT TARGET cpm-update-package-lock) + add_custom_target( + cpm-update-package-lock COMMAND ${CMAKE_COMMAND} -E copy ${CPM_PACKAGE_LOCK_FILE} + ${CPM_ABSOLUTE_PACKAGE_LOCK_PATH} + ) + endif() + set(CPM_PACKAGE_LOCK_ENABLED true) + endif() +endmacro() + +# registers a package that has been added to CPM +function(CPMRegisterPackage PACKAGE VERSION) + list(APPEND CPM_PACKAGES ${PACKAGE}) + set(CPM_PACKAGES + ${CPM_PACKAGES} + CACHE INTERNAL "" + ) + set("CPM_PACKAGE_${PACKAGE}_VERSION" + ${VERSION} + CACHE INTERNAL "" + ) +endfunction() + +# retrieve the current version of the package to ${OUTPUT} +function(CPMGetPackageVersion PACKAGE OUTPUT) + set(${OUTPUT} + "${CPM_PACKAGE_${PACKAGE}_VERSION}" + PARENT_SCOPE + ) +endfunction() + +# declares a package in FetchContent_Declare +function(cpm_declare_fetch PACKAGE VERSION INFO) + if(${CPM_DRY_RUN}) + cpm_message(STATUS "${CPM_INDENT} Package not declared (dry run)") + return() + endif() + + FetchContent_Declare(${PACKAGE} ${ARGN}) +endfunction() + +# returns properties for a package previously defined by cpm_declare_fetch +function(cpm_get_fetch_properties PACKAGE) + if(${CPM_DRY_RUN}) + return() + endif() + + set(${PACKAGE}_SOURCE_DIR + "${CPM_PACKAGE_${PACKAGE}_SOURCE_DIR}" + PARENT_SCOPE + ) + set(${PACKAGE}_BINARY_DIR + "${CPM_PACKAGE_${PACKAGE}_BINARY_DIR}" + PARENT_SCOPE + ) +endfunction() + +function(cpm_store_fetch_properties PACKAGE source_dir binary_dir) + if(${CPM_DRY_RUN}) + return() + endif() + + set(CPM_PACKAGE_${PACKAGE}_SOURCE_DIR + "${source_dir}" + CACHE INTERNAL "" + ) + set(CPM_PACKAGE_${PACKAGE}_BINARY_DIR + "${binary_dir}" + CACHE INTERNAL "" + ) +endfunction() + +# adds a package as a subdirectory if viable, according to provided options +function( + cpm_add_subdirectory + PACKAGE + DOWNLOAD_ONLY + SOURCE_DIR + BINARY_DIR + EXCLUDE + SYSTEM + OPTIONS +) + + if(NOT DOWNLOAD_ONLY AND EXISTS ${SOURCE_DIR}/CMakeLists.txt) + set(addSubdirectoryExtraArgs "") + if(EXCLUDE) + list(APPEND addSubdirectoryExtraArgs EXCLUDE_FROM_ALL) + endif() + if("${SYSTEM}" AND "${CMAKE_VERSION}" VERSION_GREATER_EQUAL "3.25") + # https://cmake.org/cmake/help/latest/prop_dir/SYSTEM.html#prop_dir:SYSTEM + list(APPEND addSubdirectoryExtraArgs SYSTEM) + endif() + if(OPTIONS) + foreach(OPTION ${OPTIONS}) + cpm_parse_option("${OPTION}") + set(${OPTION_KEY} "${OPTION_VALUE}") + endforeach() + endif() + set(CPM_OLD_INDENT "${CPM_INDENT}") + set(CPM_INDENT "${CPM_INDENT} ${PACKAGE}:") + add_subdirectory(${SOURCE_DIR} 
${BINARY_DIR} ${addSubdirectoryExtraArgs}) + set(CPM_INDENT "${CPM_OLD_INDENT}") + endif() +endfunction() + +# downloads a previously declared package via FetchContent and exports the variables +# `${PACKAGE}_SOURCE_DIR` and `${PACKAGE}_BINARY_DIR` to the parent scope +function(cpm_fetch_package PACKAGE populated) + set(${populated} + FALSE + PARENT_SCOPE + ) + if(${CPM_DRY_RUN}) + cpm_message(STATUS "${CPM_INDENT} Package ${PACKAGE} not fetched (dry run)") + return() + endif() + + FetchContent_GetProperties(${PACKAGE}) + + string(TOLOWER "${PACKAGE}" lower_case_name) + + if(NOT ${lower_case_name}_POPULATED) + FetchContent_Populate(${PACKAGE}) + set(${populated} + TRUE + PARENT_SCOPE + ) + endif() + + cpm_store_fetch_properties( + ${CPM_ARGS_NAME} ${${lower_case_name}_SOURCE_DIR} ${${lower_case_name}_BINARY_DIR} + ) + + set(${PACKAGE}_SOURCE_DIR + ${${lower_case_name}_SOURCE_DIR} + PARENT_SCOPE + ) + set(${PACKAGE}_BINARY_DIR + ${${lower_case_name}_BINARY_DIR} + PARENT_SCOPE + ) +endfunction() + +# splits a package option +function(cpm_parse_option OPTION) + string(REGEX MATCH "^[^ ]+" OPTION_KEY "${OPTION}") + string(LENGTH "${OPTION}" OPTION_LENGTH) + string(LENGTH "${OPTION_KEY}" OPTION_KEY_LENGTH) + if(OPTION_KEY_LENGTH STREQUAL OPTION_LENGTH) + # no value for key provided, assume user wants to set option to "ON" + set(OPTION_VALUE "ON") + else() + math(EXPR OPTION_KEY_LENGTH "${OPTION_KEY_LENGTH}+1") + string(SUBSTRING "${OPTION}" "${OPTION_KEY_LENGTH}" "-1" OPTION_VALUE) + endif() + set(OPTION_KEY + "${OPTION_KEY}" + PARENT_SCOPE + ) + set(OPTION_VALUE + "${OPTION_VALUE}" + PARENT_SCOPE + ) +endfunction() + +# guesses the package version from a git tag +function(cpm_get_version_from_git_tag GIT_TAG RESULT) + string(LENGTH ${GIT_TAG} length) + if(length EQUAL 40) + # GIT_TAG is probably a git hash + set(${RESULT} + 0 + PARENT_SCOPE + ) + else() + string(REGEX MATCH "v?([0123456789.]*).*" _ ${GIT_TAG}) + set(${RESULT} + ${CMAKE_MATCH_1} + PARENT_SCOPE + ) + endif() +endfunction() + +# guesses if the git tag is a commit hash or an actual tag or a branch name. +function(cpm_is_git_tag_commit_hash GIT_TAG RESULT) + string(LENGTH "${GIT_TAG}" length) + # full hash has 40 characters, and short hash has at least 7 characters. 
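+  # e.g. "main" and "v1.2.3" fail the hex-only check below, while "a1b2c3d" is treated as a hash.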
+ if(length LESS 7 OR length GREATER 40) + set(${RESULT} + 0 + PARENT_SCOPE + ) + else() + if(${GIT_TAG} MATCHES "^[a-fA-F0-9]+$") + set(${RESULT} + 1 + PARENT_SCOPE + ) + else() + set(${RESULT} + 0 + PARENT_SCOPE + ) + endif() + endif() +endfunction() + +function(cpm_prettify_package_arguments OUT_VAR IS_IN_COMMENT) + set(oneValueArgs + NAME + FORCE + VERSION + GIT_TAG + DOWNLOAD_ONLY + GITHUB_REPOSITORY + GITLAB_REPOSITORY + BITBUCKET_REPOSITORY + GIT_REPOSITORY + SOURCE_DIR + FIND_PACKAGE_ARGUMENTS + NO_CACHE + SYSTEM + GIT_SHALLOW + EXCLUDE_FROM_ALL + SOURCE_SUBDIR + ) + set(multiValueArgs URL OPTIONS DOWNLOAD_COMMAND) + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + foreach(oneArgName ${oneValueArgs}) + if(DEFINED CPM_ARGS_${oneArgName}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + if(${oneArgName} STREQUAL "SOURCE_DIR") + string(REPLACE ${CMAKE_SOURCE_DIR} "\${CMAKE_SOURCE_DIR}" CPM_ARGS_${oneArgName} + ${CPM_ARGS_${oneArgName}} + ) + endif() + string(APPEND PRETTY_OUT_VAR " ${oneArgName} ${CPM_ARGS_${oneArgName}}\n") + endif() + endforeach() + foreach(multiArgName ${multiValueArgs}) + if(DEFINED CPM_ARGS_${multiArgName}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " ${multiArgName}\n") + foreach(singleOption ${CPM_ARGS_${multiArgName}}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " \"${singleOption}\"\n") + endforeach() + endif() + endforeach() + + if(NOT "${CPM_ARGS_UNPARSED_ARGUMENTS}" STREQUAL "") + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " ") + foreach(CPM_ARGS_UNPARSED_ARGUMENT ${CPM_ARGS_UNPARSED_ARGUMENTS}) + string(APPEND PRETTY_OUT_VAR " ${CPM_ARGS_UNPARSED_ARGUMENT}") + endforeach() + string(APPEND PRETTY_OUT_VAR "\n") + endif() + + set(${OUT_VAR} + ${PRETTY_OUT_VAR} + PARENT_SCOPE + ) + +endfunction() diff --git a/cmake/stdexec.cmake b/cmake/stdexec.cmake new file mode 100644 index 0000000000..6e7c8fd5aa --- /dev/null +++ b/cmake/stdexec.cmake @@ -0,0 +1,38 @@ +function (fetch_stdexec) + set(branch_or_tag "main") + set(url "https://github.com/NVIDIA/stdexec.git") + set(target_folder "${CMAKE_BINARY_DIR}/_deps/stdexec-src") + + if (NOT EXISTS ${target_folder}) + execute_process( + COMMAND git clone --depth 1 --branch "${branch_or_tag}" --recursive "${url}" "${target_folder}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + RESULT_VARIABLE exec_process_result + OUTPUT_VARIABLE exec_process_output + ) + if(NOT exec_process_result EQUAL "0") + message(FATAL_ERROR "Git clone failed: ${exec_process_output}") + else() + message(STATUS "Git clone succeeded: ${exec_process_output}") + endif() + endif() + + set(stdexec_SOURCE_DIR ${target_folder} CACHE INTERNAL "stdexec source folder" FORCE) + set(stdexec_INCLUDE_DIR ${target_folder}/include CACHE INTERNAL "stdexec include folder" FORCE) + + #[[ + include(FetchContent) + + FetchContent_Declare( + stdexec + GIT_REPOSITORY https://github.com/NVIDIA/stdexec.git + GIT_TAG main + GIT_SHALLOW TRUE + ) + + set(STDEXEC_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) + + FetchContent_MakeAvailable(stdexec) + #]] + +endfunction() \ No newline at end of file diff --git a/include/glaze/core/error.hpp b/include/glaze/core/error.hpp new file mode 100644 index 0000000000..dae1540f59 --- /dev/null +++ b/include/glaze/core/error.hpp @@ -0,0 +1,27 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once 
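+
+// Adapts glz::error_code to the <system_error> machinery. A possible usage sketch
+// (not part of this header; the enum value is only illustrative):
+//   std::error_code ec{static_cast<int>(glz::error_code::none), glz::error_category::instance()};
+//   ec.message(); // returns the enum value's name via nameof()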
+ +#include "glaze/core/context.hpp" +#include "glaze/core/common.hpp" + +#include + +namespace glz +{ + struct error_category : public std::error_category + { + static const error_category& instance() { + static error_category instance{}; + return instance; + } + + const char* name() const noexcept override { return "glz::error_category"; } + + std::string message(int ec) const override + { + return std::string{nameof(error_code(ec))}; + } + }; +} diff --git a/include/glaze/coroutine.hpp b/include/glaze/coroutine.hpp new file mode 100644 index 0000000000..c4eaccc125 --- /dev/null +++ b/include/glaze/coroutine.hpp @@ -0,0 +1,16 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/coroutine/awaitable.hpp" +#include "glaze/coroutine/generator.hpp" +#include "glaze/coroutine/latch.hpp" +#include "glaze/coroutine/mutex.hpp" +#include "glaze/coroutine/ring_buffer.hpp" +#include "glaze/coroutine/semaphore.hpp" +#include "glaze/coroutine/shared_mutex.hpp" +#include "glaze/coroutine/sync_wait.hpp" +#include "glaze/coroutine/task.hpp" +#include "glaze/coroutine/thread_pool.hpp" +#include "glaze/coroutine/when_all.hpp" diff --git a/include/glaze/coroutine/awaitable.hpp b/include/glaze/coroutine/awaitable.hpp new file mode 100644 index 0000000000..83fbe10b6e --- /dev/null +++ b/include/glaze/coroutine/awaitable.hpp @@ -0,0 +1,106 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include +#include +#include + +namespace glz +{ + template + concept in_types = (std::same_as || ...); + + /** + * This concept declares a type that is required to meet the c++20 coroutine operator co_await() + * retun type. It requires the following three member functions: + * await_ready() -> bool + * await_suspend(std::coroutine_handle<>) -> void|bool|std::coroutine_handle<> + * await_resume() -> decltype(auto) + * Where the return type on await_resume is the requested return of the awaitable. + */ + template + concept awaiter = requires(type t, std::coroutine_handle<> c) { + { + t.await_ready() + } -> std::same_as; + { + t.await_suspend(c) + } -> in_types>; + { + t.await_resume() + }; + }; + + template + concept member_co_await_awaitable = requires(type t) { + { + t.operator co_await() + } -> awaiter; + }; + + template + concept global_co_await_awaitable = requires(type t) { + { + operator co_await(t) + } -> awaiter; + }; + + /** + * This concept declares a type that can be operator co_await()'ed and returns an awaiter_type. 
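+    * For example, glz::event and glz::latch in this library satisfy it through their member
+    * operator co_await().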
+ */ + template + concept awaitable = member_co_await_awaitable || global_co_await_awaitable || awaiter; + + template + concept awaiter_void = awaiter && requires(type t) { + { + t.await_resume() + } -> std::same_as; + }; + + template + concept member_co_await_awaitable_void = requires(type t) { + { + t.operator co_await() + } -> awaiter_void; + }; + + template + concept global_co_await_awaitable_void = requires(type t) { + { + operator co_await(t) + } -> awaiter_void; + }; + + template + concept awaitable_void = + member_co_await_awaitable_void || global_co_await_awaitable_void || awaiter_void; + + template + struct awaitable_traits + {}; + + template + static auto get_awaiter(awaitable&& value) + { + if constexpr (member_co_await_awaitable) + return std::forward(value).operator co_await(); + else if constexpr (global_co_await_awaitable) + return operator co_await(std::forward(value)); + else if constexpr (awaiter) { + return std::forward(value); + } + } + + template + struct awaitable_traits + { + using type = decltype(get_awaiter(std::declval())); + using return_type = decltype(std::declval().await_resume()); + }; +} diff --git a/include/glaze/coroutine/concepts.hpp b/include/glaze/coroutine/concepts.hpp new file mode 100644 index 0000000000..12721a83d8 --- /dev/null +++ b/include/glaze/coroutine/concepts.hpp @@ -0,0 +1,63 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include + +#include "glaze/coroutine/awaitable.hpp" + +namespace glz +{ + /** + * Concept to require that the range contains a specific type of value. + */ + template + concept range_of = std::ranges::range && std::is_same_v>; + + /** + * Concept to require that a sized range contains a specific type of value. 
+ */ + template + concept sized_range_of = std::ranges::sized_range && std::is_same_v>; + + template + concept executor = requires(T t, std::coroutine_handle<> c) { + { + t.schedule() + } -> awaiter; + { + t.yield() + } -> awaiter; + { + t.resume(c) + } -> std::same_as; + }; + + template + concept io_exceutor = executor;/* and requires(T t, std::coroutine_handle<> c, net::file_handle_t fd, glz::poll_op op, + std::chrono::milliseconds timeout) { + { + t.poll(fd, op, timeout) + } -> std::same_as>; + };*/ + + template + concept const_buffer = requires(const T t) + { + { t.empty() } -> std::same_as; + { t.data() } -> std::same_as; + { t.size() } -> std::same_as; + }; + + template + concept mutable_buffer = requires(T t) + { + { t.empty() } -> std::same_as; + { t.data() } -> std::same_as; + { t.size() } -> std::same_as; + }; +} diff --git a/include/glaze/coroutine/delete.hpp b/include/glaze/coroutine/delete.hpp new file mode 100644 index 0000000000..9f1e331307 --- /dev/null +++ b/include/glaze/coroutine/delete.hpp @@ -0,0 +1,4 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once diff --git a/include/glaze/coroutine/event.hpp b/include/glaze/coroutine/event.hpp new file mode 100644 index 0000000000..d829bda544 --- /dev/null +++ b/include/glaze/coroutine/event.hpp @@ -0,0 +1,221 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include + +#include "glaze/coroutine/concepts.hpp" + +namespace glz +{ + enum struct resume_order_policy { + /// Last in first out, this is the default policy and will execute the fastest + /// if you do not need the first waiter to execute first upon the event being set. + lifo, + /// First in first out, this policy has an extra overhead to reverse the order of + /// the waiters but will guarantee the ordering is fifo. + fifo + }; + + /** + * Event is a manully triggered thread safe signal that can be co_await()'ed by multiple awaiters. + * Each awaiter should co_await the event and upon the event being set each awaiter will have their + * coroutine resumed. + * + * The event can be manually reset to the un-set state to be re-used. + * \code + t1: glz::event e; + ... + t2: func(glz::event& e) { ... co_await e; ... } + ... + t1: do_work(); + t1: e.set(); + ... + t2: resume() + * \endcode + */ + struct event + { + struct awaiter + { + /** + * @param e The event to wait for it to be set. + */ + awaiter(const event& e) noexcept : m_event(e) {} + + /** + * @return True if the event is already set, otherwise false to suspend this coroutine. + */ + auto await_ready() const noexcept -> bool { return m_event.is_set(); } + + /** + * Adds this coroutine to the list of awaiters in a thread safe fashion. If the event + * is set while attempting to add this coroutine to the awaiters then this will return false + * to resume execution immediately. + * @return False if the event is already set, otherwise true to suspend this coroutine. + */ + bool await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept + { + const void* const set_state = &m_event; + + m_awaiting_coroutine = awaiting_coroutine; + + // This value will update if other threads write to it via acquire. + void* old_value = m_event.m_state.load(std::memory_order::acquire); + do + { + // Resume immediately if already in the set state. 
+ if (old_value == set_state) + { + return false; + } + + m_next = static_cast(old_value); + } while (!m_event.m_state.compare_exchange_weak( + old_value, this, std::memory_order::release, std::memory_order::acquire)); + + return true; + } + + /** + * Nothing to do on resume. + */ + auto await_resume() noexcept {} + + /// Refernce to the event that this awaiter is waiting on. + const event& m_event; + /// The awaiting continuation coroutine handle. + std::coroutine_handle<> m_awaiting_coroutine; + /// The next awaiter in line for this event, nullptr if this is the end. + awaiter* m_next{nullptr}; + }; + + /** + * Creates an event with the given initial state of being set or not set. + * @param initially_set By default all events start as not set, but if needed this parameter can + * set the event to already be triggered. + */ + explicit event(bool initially_set = false) noexcept + : m_state((initially_set) ? static_cast(this) : nullptr) + {} + + ~event() = default; + + event(const event&) = delete; + event(event&&) = delete; + auto operator=(const event&) -> event& = delete; + auto operator=(event&&) -> event& = delete; + + /** + * @return True if this event is currently in the set state. + */ + auto is_set() const noexcept -> bool { return m_state.load(std::memory_order_acquire) == this; } + + /** + * Sets this event and resumes all awaiters. Note that all waiters will be resumed onto this + * thread of execution. + * @param policy The order in which the waiters should be resumed, defaults to LIFO since it + * is more efficient, FIFO requires reversing the order of the waiters first. + */ + void set(resume_order_policy policy = resume_order_policy::lifo) noexcept + { + // Exchange the state to this, if the state was previously not this, then traverse the list + // of awaiters and resume their coroutines. + void* old_value = m_state.exchange(this, std::memory_order::acq_rel); + if (old_value != this) + { + // If FIFO has been requsted then reverse the order upon resuming. + if (policy == resume_order_policy::fifo) + { + old_value = reverse(static_cast(old_value)); + } + // else lifo nothing to do + + auto* waiters = static_cast(old_value); + while (waiters) + { + auto* next = waiters->m_next; + waiters->m_awaiting_coroutine.resume(); + waiters = next; + } + } + } + + /** + * Sets this event and resumes all awaiters onto the given executor. This will distribute + * the waiters across the executor's threads. + */ + template + auto set(executor_type& e, resume_order_policy policy = resume_order_policy::lifo) noexcept -> void + { + void* old_value = m_state.exchange(this, std::memory_order::acq_rel); + if (old_value != this) { + // If FIFO has been requsted then reverse the order upon resuming. + if (policy == resume_order_policy::fifo) { + old_value = reverse(static_cast(old_value)); + } + // else lifo nothing to do + + auto* waiters = static_cast(old_value); + while (waiters) { + auto* next = waiters->m_next; + e.resume(waiters->m_awaiting_coroutine); + waiters = next; + } + } + } + + /** + * @return An awaiter struct to suspend and resume this coroutine for when the event is set. + */ + auto operator co_await() const noexcept -> awaiter { return awaiter(*this); } + + /** + * Resets the event from set to not set so it can be re-used. If the event is not currently + * set then this function has no effect. + */ + void reset() noexcept + { + void* old_value = this; + m_state.compare_exchange_strong(old_value, nullptr, std::memory_order::acquire); + } + + protected: + /// For access to m_state. 
+ friend struct awaiter; + /// The state of the event, nullptr is not set with zero awaiters. Set to an awaiter* there are + /// coroutines awaiting the event to be set, and set to this the event has triggered. + /// 1) nullptr == not set + /// 2) awaiter* == linked list of awaiters waiting for the event to trigger. + /// 3) this == The event is triggered and all awaiters are resumed. + mutable std::atomic m_state; + + private: + /** + * Reverses the set of waiters from LIFO->FIFO and returns the new head. + */ + auto reverse(awaiter* curr) -> awaiter* + { + if (curr == nullptr || curr->m_next == nullptr) + { + return curr; + } + + awaiter* prev = nullptr; + awaiter* next = nullptr; + while (curr) + { + next = curr->m_next; + curr->m_next = prev; + prev = curr; + curr = next; + } + + return prev; + } + }; +} diff --git a/include/glaze/coroutine/generator.hpp b/include/glaze/coroutine/generator.hpp new file mode 100644 index 0000000000..38c5664f94 --- /dev/null +++ b/include/glaze/coroutine/generator.hpp @@ -0,0 +1,185 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include +#include +#include + +namespace glz +{ + template + struct generator; + + namespace detail + { + template + struct generator_promise + { + using value_type = std::remove_reference_t; + using reference_type = std::conditional_t, T, T&>; + using pointer_type = value_type*; + + generator_promise() = default; + + auto get_return_object() noexcept -> generator; + + auto initial_suspend() const { return std::suspend_always{}; } + + auto final_suspend() const noexcept(true) { return std::suspend_always{}; } + + template ::value, int> = 0> + auto yield_value(std::remove_reference_t& value) noexcept + { + m_value = std::addressof(value); + return std::suspend_always{}; + } + + auto yield_value(std::remove_reference_t&& value) noexcept + { + m_value = std::addressof(value); + return std::suspend_always{}; + } + + auto unhandled_exception() -> void { m_exception = std::current_exception(); } + + auto return_void() noexcept -> void {} + + auto value() const noexcept -> reference_type { return static_cast(*m_value); } + + template + auto await_transform(U&& value) -> std::suspend_never = delete; + + auto rethrow_if_exception() -> void + { + if (m_exception) { + std::rethrow_exception(m_exception); + } + } + + private: + pointer_type m_value{nullptr}; + std::exception_ptr m_exception; + }; + + struct generator_sentinel + {}; + + template + struct generator_iterator + { + using coroutine_handle = std::coroutine_handle>; + + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = typename generator_promise::value_type; + using reference = typename generator_promise::reference_type; + using pointer = typename generator_promise::pointer_type; + + generator_iterator() noexcept {} + + explicit generator_iterator(coroutine_handle coroutine) noexcept : m_coroutine(coroutine) {} + + friend auto operator==(const generator_iterator& it, generator_sentinel) noexcept -> bool + { + return it.m_coroutine == nullptr || it.m_coroutine.done(); + } + + friend auto operator!=(const generator_iterator& it, generator_sentinel s) noexcept -> bool + { + return !(it == s); + } + + friend auto operator==(generator_sentinel s, const generator_iterator& it) noexcept -> bool + { + return (it == s); + } + + friend auto operator!=(generator_sentinel s, const generator_iterator& it) 
noexcept -> bool { return it != s; } + + generator_iterator& operator++() + { + m_coroutine.resume(); + if (m_coroutine.done()) { + m_coroutine.promise().rethrow_if_exception(); + } + + return *this; + } + + auto operator++(int) -> void { (void)operator++(); } + + reference operator*() const noexcept { return m_coroutine.promise().value(); } + + pointer operator->() const noexcept { return std::addressof(operator*()); } + + private: + coroutine_handle m_coroutine{nullptr}; + }; + + } // namespace detail + + template + struct generator : public std::ranges::view_base + { + using promise_type = detail::generator_promise; + using iterator = detail::generator_iterator; + using sentinel = detail::generator_sentinel; + + generator() noexcept : m_coroutine(nullptr) {} + + generator(const generator&) = delete; + generator(generator&& other) noexcept : m_coroutine(other.m_coroutine) { other.m_coroutine = nullptr; } + + auto operator=(const generator&) = delete; + auto operator=(generator&& other) noexcept -> generator& + { + m_coroutine = other.m_coroutine; + other.m_coroutine = nullptr; + + return *this; + } + + ~generator() + { + if (m_coroutine) { + m_coroutine.destroy(); + } + } + + auto begin() -> iterator + { + if (m_coroutine) { + m_coroutine.resume(); + if (m_coroutine.done()) { + m_coroutine.promise().rethrow_if_exception(); + } + } + + return iterator{m_coroutine}; + } + + auto end() noexcept -> sentinel { return sentinel{}; } + + private: + friend struct detail::generator_promise; + + explicit generator(std::coroutine_handle coroutine) noexcept : m_coroutine(coroutine) {} + + std::coroutine_handle m_coroutine; + }; + + namespace detail + { + template + auto generator_promise::get_return_object() noexcept -> generator + { + return generator{std::coroutine_handle>::from_promise(*this)}; + } + + } +} diff --git a/include/glaze/coroutine/latch.hpp b/include/glaze/coroutine/latch.hpp new file mode 100644 index 0000000000..53b1f8829e --- /dev/null +++ b/include/glaze/coroutine/latch.hpp @@ -0,0 +1,81 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include + +#include "glaze/coroutine/event.hpp" +#include "glaze/coroutine/thread_pool.hpp" + +namespace glz +{ + /** + * The latch is thread safe counter to wait for 1 or more other tasks to complete, they signal their + * completion by calling `count_down()` on the latch and upon the latch counter reaching zero the + * coroutine `co_await`ing the latch then resumes execution. + * + * This is useful for spawning many worker tasks to complete either a computationally complex task + * across a thread pool of workers, or waiting for many asynchronous results like http requests + * to complete. + */ + struct latch + { + /** + * Creates a latch with the given count of tasks to wait to complete. + * @param count The number of tasks to wait to complete, if this is zero or negative then the + * latch starts 'completed' immediately and execution is resumed with no suspension. + */ + latch(int64_t count) noexcept : m_count(count), m_event(count <= 0) {} + + latch(const latch&) = delete; + latch(latch&&) = delete; + auto operator=(const latch&) -> latch& = delete; + auto operator=(latch&&) -> latch& = delete; + + /** + * @return True if the latch has been counted down to zero. + */ + auto is_ready() const noexcept -> bool { return m_event.is_set(); } + + /** + * @return The number of tasks this latch is still waiting to complete. 
+ */ + size_t remaining() const noexcept { return m_count.load(std::memory_order::acquire); } + + /** + * If the latch counter goes to zero then the task awaiting the latch is resumed. + * @param n The number of tasks to complete towards the latch, defaults to 1. + */ + auto count_down(std::int64_t n = 1) noexcept -> void + { + if (m_count.fetch_sub(n, std::memory_order::acq_rel) <= n) { + m_event.set(); + } + } + + /** + * If the latch counter goes to zero then the task awaiting the latch is resumed on the given + * thread pool. + * @param tp The thread pool to schedule the task that is waiting on the latch on. + * @param n The number of tasks to complete towards the latch, defaults to 1. + */ + auto count_down(glz::thread_pool& tp, std::int64_t n = 1) noexcept -> void + { + if (m_count.fetch_sub(n, std::memory_order::acq_rel) <= n) { + m_event.set(tp); + } + } + + auto operator co_await() const noexcept -> event::awaiter { return m_event.operator co_await(); } + + private: + /// The number of tasks to wait for completion before triggering the event to resume. + std::atomic m_count; + /// The event to trigger when the latch counter reaches zero, this resumes the coroutine that + /// is co_await'ing on the latch. + event m_event; + }; +} diff --git a/include/glaze/coroutine/mutex.hpp b/include/glaze/coroutine/mutex.hpp new file mode 100644 index 0000000000..7353a43c5a --- /dev/null +++ b/include/glaze/coroutine/mutex.hpp @@ -0,0 +1,223 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include +#include +#include + +namespace glz +{ + struct mutex; + + /** + * A scoped RAII lock holder, just like std::lock_guard or std::scoped_lock in that the coro::mutex + * is always unlocked unpon this coro::scoped_lock going out of scope. It is possible to unlock the + * coro::mutex prior to the end of its current scope by manually calling the unlock() function. + */ + struct scoped_lock + { + friend struct mutex; + + public: + enum struct lock_strategy { + /// The lock is already acquired, adopt it as the new owner. + adopt + }; + + explicit scoped_lock(mutex& m, lock_strategy strategy = lock_strategy::adopt) : m_mutex(&m) + { + // Future -> support acquiring the lock? Not sure how to do that without being able to + // co_await in the constructor. + (void)strategy; + } + + /** + * Unlocks the mutex upon this shared lock destructing. + */ + ~scoped_lock() + { + unlock(); + } + + scoped_lock(const scoped_lock&) = delete; + scoped_lock(scoped_lock&& other) : m_mutex(std::exchange(other.m_mutex, nullptr)) {} + auto operator=(const scoped_lock&) -> scoped_lock& = delete; + auto operator=(scoped_lock&& other) noexcept -> scoped_lock& + { + if (std::addressof(other) != this) { + m_mutex = std::exchange(other.m_mutex, nullptr); + } + return *this; + } + + /** + * Unlocks the scoped lock prior to it going out of scope. Calling this multiple times has no + * additional affect after the first call. 
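+       * The destructor also calls unlock(), so unlocking manually simply releases the mutex early.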
+ */ + void unlock(); + + private: + mutex* m_mutex{}; + }; + + struct mutex + { + explicit mutex() noexcept : m_state(const_cast(unlocked_value())) {} + ~mutex() = default; + + mutex(const mutex&) = delete; + mutex(mutex&&) = delete; + auto operator=(const mutex&) -> mutex& = delete; + auto operator=(mutex&&) -> mutex& = delete; + + struct lock_operation + { + explicit lock_operation(mutex& m) : m_mutex(m) {} + + bool await_ready() const noexcept + { + if (m_mutex.try_lock()) + { + // Since there is no mutex acquired, insert a memory fence to act like it. + std::atomic_thread_fence(std::memory_order::acquire); + return true; + } + return false; + } + + bool await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept + { + m_awaiting_coroutine = awaiting_coroutine; + void* current = m_mutex.m_state.load(std::memory_order::acquire); + void* new_value; + + const void* unlocked_value = m_mutex.unlocked_value(); + do + { + if (current == unlocked_value) + { + // If the current value is 'unlocked' then attempt to lock it. + new_value = nullptr; + } + else + { + // If the current value is a waiting lock operation, or nullptr, set our next to that + // lock op and attempt to set ourself as the head of the waiter list. + m_next = static_cast(current); + new_value = static_cast(this); + } + } while (!m_mutex.m_state.compare_exchange_weak(current, new_value, std::memory_order::acq_rel)); + + // Don't suspend if the state went from unlocked -> locked with zero waiters. + if (current == unlocked_value) + { + std::atomic_thread_fence(std::memory_order::acquire); + m_awaiting_coroutine = nullptr; // nothing to await later since this doesn't suspend + return false; + } + + return true; + } + + scoped_lock await_resume() noexcept { return scoped_lock{m_mutex}; } + + private: + friend struct mutex; + + mutex& m_mutex; + std::coroutine_handle<> m_awaiting_coroutine; + lock_operation* m_next{nullptr}; + }; + + /** + * To acquire the mutex's lock co_await this function. Upon acquiring the lock it returns + * a coro::scoped_lock which will hold the mutex until the coro::scoped_lock destructs. + * @return A co_await'able operation to acquire the mutex. + */ + [[nodiscard]] auto lock() -> lock_operation { return lock_operation{*this}; }; + + /** + * Attempts to lock the mutex. + * @return True if the mutex lock was acquired, otherwise false. + */ + bool try_lock() + { + void* expected = const_cast(unlocked_value()); + return m_state.compare_exchange_strong(expected, nullptr, std::memory_order::acq_rel, std::memory_order::relaxed); + } + + /** + * Releases the mutex's lock. + */ + void unlock() + { + if (m_internal_waiters == nullptr) + { + void* current = m_state.load(std::memory_order::relaxed); + if (current == nullptr) + { + // If there are no internal waiters and there are no atomic waiters, attempt to set the + // mutex as unlocked. + if (m_state.compare_exchange_strong( + current, + const_cast(unlocked_value()), + std::memory_order::release, + std::memory_order::relaxed)) + { + return; // The mutex is now unlocked with zero waiters. + } + // else we failed to unlock, someone added themself as a waiter. + } + + // There are waiters on the atomic list, acquire them and update the state for all others. + m_internal_waiters = static_cast(m_state.exchange(nullptr, std::memory_order::acq_rel)); + + // Should internal waiters be reversed to allow for true FIFO, or should they be resumed + // in this reverse order to maximum throuhgput? 
If this list ever gets 'long' the reversal + // will take some time, but it might guarantee better latency across waiters. This LIFO + // middle ground on the atomic waiters means the best throughput at the cost of the first + // waiter possibly having added latency based on the queue length of waiters. Either way + // incurs a cost but this way for short lists will most likely be faster even though it + // isn't completely fair. + } + + // assert m_internal_waiters != nullptr + + lock_operation* to_resume = m_internal_waiters; + m_internal_waiters = m_internal_waiters->m_next; + to_resume->m_awaiting_coroutine.resume(); + } + + private: + friend struct lock_operation; + + /// unlocked -> state == unlocked_value() + /// locked but empty waiter list == nullptr + /// locked with waiters == lock_operation* + std::atomic m_state; + + /// A list of grabbed internal waiters that are only accessed by the unlock()'er. + lock_operation* m_internal_waiters{nullptr}; + + /// Inactive value, this cannot be nullptr since we want nullptr to signify that the mutex + /// is locked but there are zero waiters, this makes it easy to CAS new waiters into the + /// m_state linked list. + auto unlocked_value() const noexcept -> const void* { return &m_state; } + }; + + void scoped_lock::unlock() + { + if (m_mutex) + { + std::atomic_thread_fence(std::memory_order::release); + m_mutex->unlock(); + // Only allow a scoped lock to unlock the mutex a single time. + m_mutex = nullptr; + } + } +} diff --git a/include/glaze/coroutine/poll.hpp b/include/glaze/coroutine/poll.hpp new file mode 100644 index 0000000000..58b3026506 --- /dev/null +++ b/include/glaze/coroutine/poll.hpp @@ -0,0 +1,18 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include + +#include "glaze/network/core.hpp" +#include "glaze/reflection/enum_macro.hpp" + +namespace glz +{ + GLZ_ENUM(poll_op, read, write, read_write); + + GLZ_ENUM(poll_status, event, timeout, error, closed); +} diff --git a/include/glaze/coroutine/poll_info.hpp b/include/glaze/coroutine/poll_info.hpp new file mode 100644 index 0000000000..ccc1226dfa --- /dev/null +++ b/include/glaze/coroutine/poll_info.hpp @@ -0,0 +1,79 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#pragma once + +#include +#include +#include +#include +#include + +#include "glaze/coroutine/poll.hpp" +#include "glaze/network/core.hpp" + +namespace glz +{ + /** + * Poll Info encapsulates everything about a poll operation for the event as well as its paired + * timeout. This is important since coroutines that are waiting on an event or timeout do not + * immediately execute, they are re-scheduled onto the thread pool, so its possible its pair + * event or timeout also triggers while the coroutine is still waiting to resume. This means that + * the first one to happen, the event itself or its timeout, needs to disable the other pair item + * prior to resuming the coroutine. + * + * Finally, its also important to note that the event and its paired timeout could happen during + * the same epoll_wait and possibly trigger the coroutine to start twice. Only one can win, so the + * first one processed sets m_processed to true and any subsequent events in the same epoll batch + * are effectively discarded. 
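+    *
+    * Rough shape of how the scheduler pairs the event with its timeout (a sketch of internal
+    * usage, not public API; 'deadline' is a placeholder name):
+    *
+    *   poll_info pi{};
+    *   pi.m_fd        = fd;                              // descriptor registered with epoll/kqueue
+    *   pi.m_timer_pos = add_timer_token(deadline, pi);   // paired timeout entry
+    *   auto status    = co_await pi;                     // resumed by whichever fires first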
+ */ + struct poll_info final + { + using timed_events = std::multimap; + + poll_info() = default; + ~poll_info() = default; + + poll_info(const poll_info&) = delete; + poll_info(poll_info&&) = delete; + auto operator=(const poll_info&) -> poll_info& = delete; + auto operator=(poll_info&&) -> poll_info& = delete; + + struct poll_awaiter final + { + glz::poll_info& poll_info; + + bool await_ready() const noexcept { return false; } + void await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept + { + poll_info.m_awaiting_coroutine = awaiting_coroutine; + std::atomic_thread_fence(std::memory_order::release); + } + poll_status await_resume() noexcept { return poll_info.m_poll_status; } + }; + + poll_awaiter operator co_await() noexcept { return {*this}; } + + /// The file descriptor being polled on. This is needed so that if the timeout occurs first then + /// the event loop can immediately disable the event within epoll. + net::event_handle_t m_fd{net::invalid_event_handle}; + /// The timeout's position in the timeout map. A poll() with no timeout or yield() this is empty. + /// This is needed so that if the event occurs first then the event loop can immediately disable + /// the timeout within epoll. + std::optional m_timer_pos{}; + /// The awaiting coroutine for this poll info to resume upon event or timeout. + std::coroutine_handle<> m_awaiting_coroutine{}; + /// The status of the poll operation. + poll_status m_poll_status{glz::poll_status::error}; + /// Did the timeout and event trigger at the same time on the same epoll_wait call? + /// Once this is set to true all future events on this poll info are null and void. + bool m_processed{false}; + + /// The operation for deleting on Mac + glz::poll_op op{}; + }; +} diff --git a/include/glaze/coroutine/ring_buffer.hpp b/include/glaze/coroutine/ring_buffer.hpp new file mode 100644 index 0000000000..f700092904 --- /dev/null +++ b/include/glaze/coroutine/ring_buffer.hpp @@ -0,0 +1,301 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include +#include +#include +#include + +#include "glaze/util/expected.hpp" + +namespace glz +{ + namespace rb + { + enum struct produce_result { produced, ring_buffer_stopped }; + + enum struct consume_result { ring_buffer_stopped }; + } + + /** + * @tparam element The type of element the ring buffer will store. Note that this type should be + * cheap to move if possible as it is moved into and out of the buffer upon produce and + * consume operations. + * @tparam num_elements The maximum number of elements the ring buffer can store, must be >= 1. + */ + template + class ring_buffer + { + public: + /** + * static_assert If `num_elements` == 0. + */ + ring_buffer() { static_assert(num_elements != 0, "num_elements cannot be zero"); } + + ~ring_buffer() + { + // Wake up anyone still using the ring buffer. 
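+         // Pending producers and consumers are resumed and observe ring_buffer_stopped.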
+ notify_waiters(); + } + + ring_buffer(const ring_buffer&) = delete; + ring_buffer(ring_buffer&&) = delete; + + auto operator=(const ring_buffer&) noexcept + -> ring_buffer& = delete; + auto operator=(ring_buffer&&) noexcept -> ring_buffer& = delete; + + struct produce_operation + { + produce_operation(ring_buffer& rb, element e) : m_rb(rb), m_e(std::move(e)) {} + + auto await_ready() noexcept -> bool + { + std::unique_lock lk{m_rb.m_mutex}; + return m_rb.try_produce_locked(lk, m_e); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + std::unique_lock lk{m_rb.m_mutex}; + // Its possible a consumer on another thread consumed an item between await_ready() and await_suspend() + // so we must check to see if there is space again. + if (m_rb.try_produce_locked(lk, m_e)) { + return false; + } + + // Don't suspend if the stop signal has been set. + if (m_rb.m_stopped.load(std::memory_order::acquire)) { + m_stopped = true; + return false; + } + + m_awaiting_coroutine = awaiting_coroutine; + m_next = m_rb.m_produce_waiters; + m_rb.m_produce_waiters = this; + return true; + } + + /** + * @return produce_result + */ + auto await_resume() -> rb::produce_result + { + return !m_stopped ? rb::produce_result::produced : rb::produce_result::ring_buffer_stopped; + } + + private: + template + friend class ring_buffer; + + /// The ring buffer the element is being produced into. + ring_buffer& m_rb; + /// If the operation needs to suspend, the coroutine to resume when the element can be produced. + std::coroutine_handle<> m_awaiting_coroutine; + /// Linked list of produce operations that are awaiting to produce their element. + produce_operation* m_next{nullptr}; + /// The element this produce operation is producing into the ring buffer. + element m_e; + /// Was the operation stopped? + bool m_stopped{false}; + }; + + struct consume_operation + { + explicit consume_operation(ring_buffer& rb) : m_rb(rb) {} + + auto await_ready() noexcept -> bool + { + std::unique_lock lk{m_rb.m_mutex}; + return m_rb.try_consume_locked(lk, this); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + std::unique_lock lk{m_rb.m_mutex}; + // We have to check again as there is a race condition between await_ready() and now on the mutex acquire. + // It is possible that a producer added items between await_ready() and await_suspend(). + if (m_rb.try_consume_locked(lk, this)) { + return false; + } + + // Don't suspend if the stop signal has been set. + if (m_rb.m_stopped.load(std::memory_order::acquire)) { + m_stopped = true; + return false; + } + m_awaiting_coroutine = awaiting_coroutine; + m_next = m_rb.m_consume_waiters; + m_rb.m_consume_waiters = this; + return true; + } + + /** + * @return The consumed element or std::nullopt if the consume has failed. + */ + auto await_resume() -> expected + { + if (m_stopped) { + return unexpected(rb::consume_result::ring_buffer_stopped); + } + + return std::move(m_e); + } + + private: + template + friend class ring_buffer; + + /// The ring buffer to consume an element from. + ring_buffer& m_rb; + /// If the operation needs to suspend, the coroutine to resume when the element can be consumed. + std::coroutine_handle<> m_awaiting_coroutine; + /// Linked list of consume operations that are awaiting to consume an element. + consume_operation* m_next{nullptr}; + /// The element this consume operation will consume. + element m_e; + /// Was the operation stopped? 
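+         /// Set by notify_waiters() so await_resume() reports ring_buffer_stopped instead of an element.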
+ bool m_stopped{false}; + }; + + /** + * Produces the given element into the ring buffer. This operation will suspend until a slot + * in the ring buffer becomes available. + * @param e The element to produce. + */ + [[nodiscard]] auto produce(element e) -> produce_operation { return produce_operation{*this, std::move(e)}; } + + /** + * Consumes an element from the ring buffer. This operation will suspend until an element in + * the ring buffer becomes available. + */ + [[nodiscard]] auto consume() -> consume_operation { return consume_operation{*this}; } + + /** + * @return The current number of elements contained in the ring buffer. + */ + auto size() const -> size_t + { + std::atomic_thread_fence(std::memory_order::acquire); + return m_used; + } + + /** + * @return True if the ring buffer contains zero elements. + */ + auto empty() const -> bool { return size() == 0; } + + /** + * Wakes up all currently awaiting producers and consumers. Their await_resume() function + * will return an expected consume result that the ring buffer has stopped. + */ + auto notify_waiters() -> void + { + std::unique_lock lk{m_mutex}; + // Only wake up waiters once. + if (m_stopped.load(std::memory_order::acquire)) { + return; + } + + m_stopped.exchange(true, std::memory_order::release); + + while (m_produce_waiters != nullptr) { + auto* to_resume = m_produce_waiters; + to_resume->m_stopped = true; + m_produce_waiters = m_produce_waiters->m_next; + + lk.unlock(); + to_resume->m_awaiting_coroutine.resume(); + lk.lock(); + } + + while (m_consume_waiters != nullptr) { + auto* to_resume = m_consume_waiters; + to_resume->m_stopped = true; + m_consume_waiters = m_consume_waiters->m_next; + + lk.unlock(); + to_resume->m_awaiting_coroutine.resume(); + lk.lock(); + } + } + + private: + friend produce_operation; + friend consume_operation; + + std::mutex m_mutex{}; + + std::array m_elements{}; + /// The current front pointer to an open slot if not full. + size_t m_front{0}; + /// The current back pointer to the oldest item in the buffer if not empty. + size_t m_back{0}; + /// The number of items in the ring buffer. + size_t m_used{0}; + + /// The LIFO list of produce waiters. + produce_operation* m_produce_waiters{nullptr}; + /// The LIFO list of consume watier. + consume_operation* m_consume_waiters{nullptr}; + + std::atomic m_stopped{false}; + + auto try_produce_locked(std::unique_lock& lk, element& e) -> bool + { + if (m_used == num_elements) { + return false; + } + + m_elements[m_front] = std::move(e); + m_front = (m_front + 1) % num_elements; + ++m_used; + + if (m_consume_waiters != nullptr) { + consume_operation* to_resume = m_consume_waiters; + m_consume_waiters = m_consume_waiters->m_next; + + // Since the consume operation suspended it needs to be provided an element to consume. + to_resume->m_e = std::move(m_elements[m_back]); + m_back = (m_back + 1) % num_elements; + --m_used; // And we just consumed up another item. + + lk.unlock(); + to_resume->m_awaiting_coroutine.resume(); + } + + return true; + } + + auto try_consume_locked(std::unique_lock& lk, consume_operation* op) -> bool + { + if (m_used == 0) { + return false; + } + + op->m_e = std::move(m_elements[m_back]); + m_back = (m_back + 1) % num_elements; + --m_used; + + if (m_produce_waiters != nullptr) { + produce_operation* to_resume = m_produce_waiters; + m_produce_waiters = m_produce_waiters->m_next; + + // Since the produce operation suspended it needs to be provided a slot to place its element. 
+ m_elements[m_front] = std::move(to_resume->m_e); + m_front = (m_front + 1) % num_elements; + ++m_used; // And we just produced another item. + + lk.unlock(); + to_resume->m_awaiting_coroutine.resume(); + } + + return true; + } + }; +} diff --git a/include/glaze/coroutine/scheduler.hpp b/include/glaze/coroutine/scheduler.hpp new file mode 100644 index 0000000000..0e1ad90a18 --- /dev/null +++ b/include/glaze/coroutine/scheduler.hpp @@ -0,0 +1,978 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#ifndef GLZ_THROW_OR_ABORT +#if __cpp_exceptions +#define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) +#include +#else +#define GLZ_THROW_OR_ABORT(EXC) (std::abort()) +#endif +#endif + +#include "glaze/coroutine/poll.hpp" +#include "glaze/coroutine/poll_info.hpp" +#include "glaze/coroutine/task_container.hpp" +#include "glaze/coroutine/thread_pool.hpp" +#include "glaze/network/core.hpp" +#include "glaze/network/socket.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace glz +{ + enum struct thread_strategy { + /// Spawns a dedicated background thread for the scheduler to run on. + spawn, + /// Requires the user to call process_events() to drive the scheduler. + manual + }; + + struct scheduler final + { + using clock = std::chrono::steady_clock; + using time_point = clock::time_point; + using timed_events = poll_info::timed_events; + + enum struct execution_strategy { + /// Tasks will be FIFO queued to be executed on a thread pool. This is better for tasks that + /// are long lived and will use lots of CPU because long lived tasks will block other i/o + /// operations while they complete. This strategy is generally better for lower latency + /// requirements at the cost of throughput. + process_tasks_on_thread_pool, + /// Tasks will be executed inline on the io scheduler thread. This is better for short tasks + /// that can be quickly processed and not block other i/o operations for very long. This + /// strategy is generally better for higher throughput at the cost of latency. + process_tasks_inline + }; + + struct options + { + /// Should the io scheduler spawn a dedicated event processor? + glz::thread_strategy thread_strategy{glz::thread_strategy::spawn}; + /// If spawning a dedicated event processor a functor to call upon that thread starting. + std::function on_io_thread_start_functor{}; + /// If spawning a dedicated event processor a functor to call upon that thread stopping. + std::function on_io_thread_stop_functor{}; + /// Thread pool options for the task processor threads. See thread pool for more details. + thread_pool::options pool{ + .thread_count = ((std::thread::hardware_concurrency() > 1) ? (std::thread::hardware_concurrency() - 1) : 1), + .on_thread_start_functor = nullptr, + .on_thread_stop_functor = nullptr}; + + /// If inline task processing is enabled then the io worker will resume tasks on its thread + /// rather than scheduling them to be picked up by the thread pool. 
+ const glz::scheduler::execution_strategy execution_strategy{glz::scheduler::execution_strategy::process_tasks_on_thread_pool}; + }; + + scheduler() { init(); } + + scheduler(options opts) : opts(std::move(opts)) { init(); } + + scheduler(const scheduler&) = delete; + scheduler(scheduler&&) = delete; + scheduler& operator=(const scheduler&) = delete; + scheduler& operator=(scheduler&&) = delete; + + ~scheduler() + { + shutdown(); + + if (m_io_thread.joinable()) { + m_io_thread.join(); + } + +#if defined(__linux__) + glz::net::close_socket(event_fd); + glz::net::close_socket(timer_fd); + glz::net::close_socket(schedule_fd); + +#elif defined(_WIN32) + glz::net::close_event(event_fd); + glz::net::close_event(timer_fd); + glz::net::close_event(schedule_fd); +#endif + } + + /** + * Given a thread_strategy_t::manual this function should be called at regular intervals to + * process events that are ready. If a using thread_strategy_t::spawn this is run continously + * on a dedicated background thread and does not need to be manually invoked. + * @param timeout If no events are ready how long should the function wait for events to be ready? + * Passing zero (default) for the timeout will check for any events that are + * ready now, and then return. This could be zero events. Passing -1 means block + * indefinitely until an event happens. + * @param return The number of tasks currently executing or waiting to execute. + */ + size_t process_events(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + { + process_events_manual(timeout); + return size(); + } + + struct schedule_operation + { + /// The thread pool that this operation will execute on. + glz::scheduler& scheduler; + + /** + * Operations always pause so the executing thread can be switched. + */ + bool await_ready() noexcept { return false; } + + /** + * Suspending always returns to the caller (using void return of await_suspend()) and + * stores the coroutine internally for the executing thread to resume from. + */ + void await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept + { + if (scheduler.opts.execution_strategy == execution_strategy::process_tasks_inline) { + scheduler.n_active_tasks.fetch_add(1, std::memory_order::release); + { + std::scoped_lock lk{scheduler.m_scheduled_tasks_mutex}; + scheduler.m_scheduled_tasks.emplace_back(awaiting_coroutine); + } + + // Trigger the event to wake-up the scheduler if this event isn't currently triggered. + bool expected{false}; + if (scheduler.m_schedule_fd_triggered.compare_exchange_strong( + expected, true, std::memory_order::release, std::memory_order::relaxed)) { +#if defined(__linux__) + eventfd_t value{1}; + eventfd_write(scheduler.schedule_fd, value); +#elif defined(__APPLE__) + net::poll_event_t e{ + .filter = EVFILT_USER, .fflags = NOTE_TRIGGER, .udata = const_cast(m_schedule_ptr)}; + if (::kevent(scheduler.event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + GLZ_THROW_OR_ABORT(std::runtime_error("Failed to trigger wke up")); + } +#endif + } + } + else { + scheduler.m_thread_pool->resume(awaiting_coroutine); + } + } + + /** + * no-op as this is the function called first by the thread pool's executing thread. + */ + void await_resume() noexcept {} + }; + + /** + * Schedules the current task onto this scheduler for execution. + */ + schedule_operation schedule() { return {*this}; } + + /** + * Schedules the current task to run after the given amount of time has elapsed. + * @param amount The amount of time to wait before resuming execution of this task. 
+ * Given zero or negative amount of time this behaves identical to schedule(). + */ + [[nodiscard]] auto schedule_after(std::chrono::milliseconds amount) -> glz::task + { + return yield_for(amount); + } + + /** + * Schedules the current task to run at a given time point in the future. + * @param time The time point to resume execution of this task. Given 'now' or a time point + * in the past this behaves identical to schedule(). + */ + [[nodiscard]] auto schedule_at(time_point time) -> glz::task { return yield_until(time); } + + /** + * Yields the current task to the end of the queue of waiting tasks. + */ + [[nodiscard]] auto yield() -> schedule_operation { return schedule_operation{*this}; }; + + /** + * Yields the current task for the given amount of time. + * @param amount The amount of time to yield for before resuming executino of this task. + * Given zero or negative amount of time this behaves identical to yield(). + */ + [[nodiscard]] auto yield_for(std::chrono::milliseconds amount) -> glz::task + { + if (amount <= std::chrono::milliseconds(0)) { + co_await schedule(); + } + else { + // Yield/timeout tasks are considered live in the scheduler and must be accounted for. Note + // that if the user gives an invalid amount and schedule() is directly called it will account + // for the scheduled task there. + n_active_tasks.fetch_add(1, std::memory_order::release); + + // Yielding does not requiring setting the timer position on the poll info since + // it doesn't have a corresponding 'event' that can trigger, it always waits for + // the timeout to occur before resuming. + + poll_info pi{}; + add_timer_token(clock::now() + amount, pi); + co_await pi; + + n_active_tasks.fetch_sub(1, std::memory_order::release); + } + co_return; + } + + /** + * Yields the current task until the given time point in the future. + * @param time The time point to resume execution of this task. Given 'now' or a time point in the + * in the past this behaves identical to yield(). + */ + [[nodiscard]] auto yield_until(time_point time) -> glz::task + { + auto now = clock::now(); + + // If the requested time is in the past (or now!) bail out! + if (time <= now) { + co_await schedule(); + } + else { + n_active_tasks.fetch_add(1, std::memory_order::release); + + auto amount = std::chrono::duration_cast(time - now); + + poll_info pi{}; + add_timer_token(now + amount, pi); + co_await pi; + + n_active_tasks.fetch_sub(1, std::memory_order::release); + } + co_return; + } + + /** + * Polls the given file descriptor for the given operations. + * @param fd The file descriptor to poll for events. + * @param op The operations to poll for. + * @param timeout The amount of time to wait for the events to trigger. A timeout of zero will + * block indefinitely until the event triggers. + * @return The result of the poll operation. + */ + [[nodiscard]] task poll(net::event_handle_t fd, glz::poll_op op, + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + { + // Because the size will drop when this coroutine suspends every poll needs to undo the subtraction + // on the number of active tasks in the scheduler. When this task is resumed by the event loop. + n_active_tasks.fetch_add(1, std::memory_order::release); + + // Setup two events, a timeout event and the actual poll for op event. + // Whichever triggers first will delete the other to guarantee only one wins. + // The resume token will be set by the scheduler to what the event turned out to be. 
+ + bool timeout_requested = (timeout > std::chrono::milliseconds(0)); + + glz::poll_info poll_info{}; + poll_info.m_fd = fd; + poll_info.op = op; + + if (timeout_requested) { + poll_info.m_timer_pos = add_timer_token(clock::now() + timeout, poll_info); + } + +#if defined(__linux__) + net::poll_event_t e{}; + + switch (op) + { + case poll_op::read: { + e.events = EPOLLIN | EPOLLONESHOT | EPOLLRDHUP; + break; + } + case poll_op::write: { + e.events = EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP; + break; + } + case poll_op::read_write: { + e.events = EPOLLIN | EPOLLOUT | EPOLLONESHOT | EPOLLRDHUP; + break; + } + default: { + break; + } + } + + e.data.ptr = &poll_info; + if (epoll_ctl(event_fd, EPOLL_CTL_ADD, fd, &e) == -1) { + std::cerr << "epoll ctl error on fd " << fd << "\n"; + } +#elif defined(__APPLE__) + if (op == poll_op::read_write) { + struct kevent events[2]; + + EV_SET(&events[0], fd, EVFILT_READ, EV_ADD | EV_EOF | EV_ONESHOT, 0, 0, &poll_info); + EV_SET(&events[1], fd, EVFILT_WRITE, EV_ADD | EV_EOF | EV_ONESHOT, 0, 0, &poll_info); + + if (kevent(event_fd, events, 2, nullptr, 0, nullptr) == -1) { + std::cerr << "kqueue failed to register read/write for file_descriptor: " << fd << "\n"; + } + } + else { + net::poll_event_t e{.flags = EV_ADD | EV_EOF | EV_ONESHOT, .udata = &poll_info}; + + switch (op) + { + case poll_op::read: { + e.filter = EVFILT_READ; + break; + } + case poll_op::write: { + e.filter = EVFILT_WRITE; + break; + } + default: { + break; + } + } + + if (::kevent(event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + std::cerr << "kqueue failed to register for file_descriptor: " << fd << "\n"; + } + } +#elif defined(_WIN32) + +#endif + + // The event loop will 'clean-up' whichever event didn't win since the coroutine is scheduled + // onto the thread poll its possible the other type of event could trigger while its waiting + // to execute again, thus restarting the coroutine twice, that would be quite bad. + auto result = co_await poll_info; + n_active_tasks.fetch_sub(1, std::memory_order::release); + co_return result; + } + + /** + * Polls the given coro::net::socket for the given operations. + * @param sock The socket to poll for events on. + * @param op The operations to poll for. + * @param timeout The amount of time to wait for the events to trigger. A timeout of zero will + * block indefinitely until the event triggers. + * @return THe result of the poll operation. + */ + [[nodiscard]] task poll(const socket& sock, poll_op op, + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + { + return poll(sock.socket_fd, op, timeout); + } + + /** + * Resumes execution of a direct coroutine handle on this io scheduler. + * @param handle The coroutine handle to resume execution. 
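+       * @return True if the coroutine was queued for resumption; false if the handle is null or
+       *         shutdown has been requested.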
+ */ + auto resume(std::coroutine_handle<> handle) -> bool + { + if (not handle) { + return false; + } + + if (m_shutdown_requested.load(std::memory_order::acquire)) { + return false; + } + + if (opts.execution_strategy == execution_strategy::process_tasks_inline) { + { + std::scoped_lock lk{m_scheduled_tasks_mutex}; + m_scheduled_tasks.emplace_back(handle); + } + + bool expected{false}; + if (m_schedule_fd_triggered.compare_exchange_strong(expected, true, std::memory_order::release, + std::memory_order::relaxed)) { +#if defined(__linux__) + eventfd_t value{1}; + eventfd_write(schedule_fd, value); +#elif defined(__APPLE__) + net::poll_event_t e{ + .filter = EVFILT_USER, .fflags = NOTE_TRIGGER, .udata = const_cast(m_schedule_ptr)}; + if (::kevent(event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + GLZ_THROW_OR_ABORT(std::runtime_error("Failed to trigger wake up")); + } +#endif + } + + return true; + } + else { + return m_thread_pool->resume(handle); + } + } + + /** + * @return The number of tasks waiting in the task queue + the executing tasks. + */ + size_t size() const noexcept + { + if (opts.execution_strategy == execution_strategy::process_tasks_inline) { + return n_active_tasks.load(std::memory_order::acquire); + } + else { + return n_active_tasks.load(std::memory_order::acquire) + m_thread_pool->size(); + } + } + + /** + * @return True if the task queue is empty and zero tasks are currently executing. + */ + bool empty() const noexcept { return size() == 0; } + + /** + * Starts the shutdown of the io scheduler. All currently executing and pending tasks will complete + * prior to shutting down. This call is blocking and will not return until all tasks complete. + */ + void shutdown() noexcept + { + // Only allow shutdown to occur once. + if (m_shutdown_requested.exchange(true, std::memory_order::acq_rel) == false) { + if (m_thread_pool) { + m_thread_pool->shutdown(); + } + + // Signal the event loop to stop asap, triggering the event fd is safe. +#if defined(__linux__) + uint64_t value{1}; + auto written = ::write(shutdown_fd, &value, sizeof(value)); + (void)written; +#elif defined(__APPLE__) + net::poll_event_t e{ + .filter = EVFILT_USER, .fflags = NOTE_TRIGGER, .udata = const_cast(m_shutdown_ptr)}; + if (::kevent(event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + GLZ_THROW_OR_ABORT(std::runtime_error("Failed to signal shutdown event")); + } +#elif defined(_WIN32) + +#endif + + if (m_io_thread.joinable()) { + m_io_thread.join(); + } + } + } + + private: + /// The configuration options. + options opts{}; + + /// The event loop epoll file descriptor. + net::event_handle_t event_fd{net::create_event_poll()}; + /// The event loop fd to trigger a shutdown. + net::event_handle_t shutdown_fd{net::create_shutdown_handle()}; + /// The event loop timer fd for timed events, e.g. yield_for() or scheduler_after(). + net::event_handle_t timer_fd{net::create_timer_handle()}; + /// The schedule file descriptor if the scheduler is in inline processing mode. + net::event_handle_t schedule_fd{net::create_schedule_handle()}; + std::atomic m_schedule_fd_triggered{false}; + + /// The number of tasks executing or awaiting events in this io scheduler. + std::atomic n_active_tasks{0}; + + /// The background io worker threads. + std::thread m_io_thread; + /// Thread pool for executing tasks when not in inline mode. 
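+      /// Only created when execution_strategy == process_tasks_on_thread_pool (see init()).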
+ std::unique_ptr m_thread_pool{nullptr}; + + std::mutex m_timed_events_mutex{}; + /// The map of time point's to poll infos for tasks that are yielding for a period of time + /// or for tasks that are polling with timeouts. + timed_events m_timed_events{}; + + /// Has the scheduler been requested to shut down? + std::atomic m_shutdown_requested{false}; + + std::atomic m_io_processing{false}; + void process_events_manual(std::chrono::milliseconds timeout) + { + bool expected{false}; + if (m_io_processing.compare_exchange_strong(expected, true, std::memory_order::release, + std::memory_order::relaxed)) { + process_events_execute(timeout); + m_io_processing.exchange(false, std::memory_order::release); + } + } + + void init() + { + if (opts.execution_strategy == execution_strategy::process_tasks_on_thread_pool) { + m_thread_pool = std::make_unique(std::move(opts.pool)); + } + + [[maybe_unused]] net::poll_event_t e{}; + + [[maybe_unused]] bool event_setup_failed{}; +#if defined(__linux__) + e.events = EPOLLIN; + + e.data.ptr = const_cast(m_shutdown_ptr); + epoll_ctl(event_fd, EPOLL_CTL_ADD, shutdown_fd, &e); + + e.data.ptr = const_cast(m_timer_ptr); + epoll_ctl(event_fd, EPOLL_CTL_ADD, timer_fd, &e); + + e.data.ptr = const_cast(m_schedule_ptr); + epoll_ctl(event_fd, EPOLL_CTL_ADD, schedule_fd, &e); +#elif defined(__APPLE__) + net::poll_event_t e_timer{.filter = EVFILT_TIMER, .flags = EV_ADD, .udata = const_cast(m_timer_ptr)}; + net::poll_event_t e_shutdown{ + .filter = EVFILT_USER, .flags = EV_ADD | EV_CLEAR, .udata = const_cast(m_shutdown_ptr)}; + net::poll_event_t e_schedule{ + .filter = EVFILT_USER, .flags = EV_ADD, .udata = const_cast(m_schedule_ptr)}; + + ::kevent(event_fd, &e_schedule, 1, nullptr, 0, nullptr); + ::kevent(event_fd, &e_shutdown, 1, nullptr, 0, nullptr); + ::kevent(event_fd, &e_timer, 1, nullptr, 0, nullptr); +#endif + + if (opts.thread_strategy == glz::thread_strategy::spawn) { + m_io_thread = std::thread([this] { process_events_dedicated_thread(); }); + } + // else manual mode, the user must call process_events. + } + + void process_events_dedicated_thread() + { + if (opts.on_io_thread_start_functor) { + opts.on_io_thread_start_functor(); + } + + m_io_processing.exchange(true, std::memory_order::release); + // Execute tasks until stopped or there are no more tasks to complete. 
+ while (!m_shutdown_requested.load(std::memory_order::acquire) || size() > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); // prevent pegging 100% + process_events_execute(m_default_timeout); + } + m_io_processing.exchange(false, std::memory_order::release); + + if (opts.on_io_thread_stop_functor) { + opts.on_io_thread_stop_functor(); + } + } + + void process_events_execute(std::chrono::milliseconds timeout) + { +#if defined(__APPLE__) + struct timespec tlimit + { + 0, timeout.count() * 1'000'000 + }; + const auto event_count = ::kevent(event_fd, nullptr, 0, m_events.data(), int(m_events.size()), &tlimit); +#elif defined(__linux__) + const auto event_count = ::epoll_wait(event_fd, m_events.data(), max_events, timeout.count()); +#elif defined(_WIN32) + + m_events = {shutdown_fd, timer_fd, schedule_fd}; + + enum struct event_id : DWORD { + shutdown, + timer, + schedual + }; + + const auto event_obj = + WaitForMultipleObjects(3 /*number of object handles*/, m_events.data(), FALSE, DWORD(timeout.count())); + + if (WAIT_FAILED == event_obj) { + GLZ_THROW_OR_ABORT(std::runtime_error{"WaitForMultipleObjects for event failed"}); + } + using enum event_id; + + switch (event_id(event_obj)) { + case shutdown: + // TODO: + break; + + case timer: + process_timeout_execute(); + break; + + case schedual: + process_scheduled_execute_inline(); + break; + + default: + break; + // GLZ_THROW_OR_ABORT(std::runtime_error{"Unhandled event id!"}); + } +#endif + +#if defined(__linux__) || defined(__APPLE__) + + //if (event_count == -1) { + // net::close_event(event_fd); + // GLZ_THROW_OR_ABORT(std::runtime_error{"wait for event failed"}); + //} + + if (event_count > 0) { + for (size_t i = 0; i < size_t(event_count); ++i) { + auto& event = m_events[i]; +#if defined(__linux__) + void* handle_ptr = event.data.ptr; +#elif defined(__APPLE__) + void* handle_ptr = event.udata; + + if (event.flags & EV_ERROR) { + GLZ_THROW_OR_ABORT(std::runtime_error{"event error"}); + } +#endif + + if (not handle_ptr) { + GLZ_THROW_OR_ABORT(std::runtime_error{"handle_ptr is null"}); + } + + if (handle_ptr == m_timer_ptr) { + process_timeout_execute(); + } + else if (handle_ptr == m_schedule_ptr) { + process_scheduled_execute_inline(); + } + else if (handle_ptr == m_shutdown_ptr) [[unlikely]] { + + } + else { + // Individual poll task wake-up. +#if defined(__linux__) + process_event_execute(static_cast(handle_ptr), event_to_poll_status(event.events)); +#elif defined(__APPLE__) + process_event_execute(static_cast(handle_ptr), event_to_poll_status(event.flags)); +#endif + } + } + } +#endif + // Its important to not resume any handles until the full set is accounted for. If a timeout + // and an event for the same handle happen in the same epoll_wait() call then inline processing + // will destruct the poll_info object before the second event is handled. This is also possible + // with thread pool processing, but probably has an extremely low chance of occuring due to + // the thread switch required. If max_events == 1 this would be unnecessary. 
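+         // Resume the collected batch: inline on this io thread, or hand the whole batch to the thread pool.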
+ + if (!m_handles_to_resume.empty()) { + if (opts.execution_strategy == execution_strategy::process_tasks_inline) { + for (auto& handle : m_handles_to_resume) { + handle.resume(); + } + } + else { + m_thread_pool->resume(m_handles_to_resume); + } + + m_handles_to_resume.clear(); + } + } + + static poll_status event_to_poll_status(uint32_t events) + { + if (events & net::poll_in || events & net::poll_out) { + return poll_status::event; + } + else if (net::poll_error(events)) { + return poll_status::error; + } + else if (net::event_closed(events)) { + return poll_status::closed; + } + + GLZ_THROW_OR_ABORT(std::runtime_error{"invalid epoll state"}); + } + + void process_scheduled_execute_inline() + { + std::vector> tasks{}; + { + // Acquire the entire list, and then reset it. + std::scoped_lock lk{m_scheduled_tasks_mutex}; + tasks.swap(m_scheduled_tasks); + + // Clear the schedule eventfd if this is a scheduled task. +#if defined(__linux__) + eventfd_t value{0}; + eventfd_read(schedule_fd, &value); +#elif defined(__APPLE__) +#endif + + // Clear the in memory flag to reduce eventfd_* calls on scheduling. + m_schedule_fd_triggered.exchange(false, std::memory_order::release); + } + + // This set of handles can be safely resumed now since they do not have a corresponding timeout event. + for (auto& task : tasks) { + task.resume(); + } + n_active_tasks.fetch_sub(tasks.size(), std::memory_order::release); + } + + std::mutex m_scheduled_tasks_mutex{}; + std::vector> m_scheduled_tasks{}; + + static constexpr int m_shutdown_object{0}; + static constexpr const void* m_shutdown_ptr = &m_shutdown_object; + + static constexpr int m_timer_object{0}; + static constexpr const void* m_timer_ptr = &m_timer_object; + + static constexpr int m_schedule_object{0}; + static constexpr const void* m_schedule_ptr = &m_schedule_object; + + static constexpr std::chrono::milliseconds m_default_timeout{1000}; + static constexpr std::chrono::milliseconds m_no_timeout{0}; + static constexpr size_t max_events = 16; + std::array m_events{}; + std::vector> m_handles_to_resume{}; + + void apple_delete_poll_event(glz::poll_info* poll_info) + { +#if defined(__APPLE__) + if (poll_info->op == poll_op::read_write) { + struct kevent event; + EV_SET(&event, poll_info->m_fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + + if (kevent(event_fd, &event, 1, nullptr, 0, nullptr) == -1) { + // It's often okay if this fails, as the descriptor might not have been registered + // But you might want to log it for debugging purposes + // std::cerr << "kqueue failed to delete read event for fd: " << poll_info->m_fd << "\n"; + } + + EV_SET(&event, poll_info->m_fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + + if (kevent(event_fd, &event, 1, nullptr, 0, nullptr) == -1) { + // Again, failure here is often okay + // std::cerr << "kqueue failed to delete write event for fd: " << poll_info->m_fd << "\n"; + } + } + else { + switch (poll_info->op) + { + case poll_op::read: { + struct kevent e; + EV_SET(&e, poll_info->m_fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + + if (kevent(event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + // It's often okay if this fails, as the descriptor might not have been registered + // But you might want to log it for debugging purposes + // std::cerr << "kqueue failed to delete read event for fd: " << poll_info->m_fd << "\n"; + } + + break; + } + case poll_op::write: { + struct kevent e; + EV_SET(&e, poll_info->m_fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + + if (kevent(event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + // Failure 
here is often okay + // std::cerr << "kqueue failed to delete write event for fd: " << poll_info->m_fd << "\n"; + } + break; + } + default: { + break; + } + } + } +#else + (void)poll_info; +#endif + } + + void process_event_execute(glz::poll_info* poll_info, poll_status status) + { + if (not poll_info) { + GLZ_THROW_OR_ABORT(std::runtime_error{"invalid poll_info"}); + } + + if (not poll_info->m_processed) { + std::atomic_thread_fence(std::memory_order::acquire); + // Its possible the event and the timeout occurred in the same epoll, make sure only one + // is ever processed, the other is discarded. + poll_info->m_processed = true; + + // Given a valid fd always remove it from epoll so the next poll can blindly EPOLL_CTL_ADD. + if (poll_info->m_fd != net::invalid_event_handle) { +#if defined(__APPLE__) + apple_delete_poll_event(poll_info); +#elif defined(__linux__) + epoll_ctl(event_fd, EPOLL_CTL_DEL, poll_info->m_fd, nullptr); +#endif + } + + // Since this event triggered, remove its corresponding timeout if it has one. + if (poll_info->m_timer_pos.has_value()) { + remove_timer_token(poll_info->m_timer_pos.value()); + } + + poll_info->m_poll_status = status; + + while (not poll_info->m_awaiting_coroutine) { + std::atomic_thread_fence(std::memory_order::acquire); + } + + m_handles_to_resume.emplace_back(poll_info->m_awaiting_coroutine); + } + } + + void process_timeout_execute() + { + std::vector poll_infos{}; + auto now = clock::now(); + + { + std::scoped_lock lk{m_timed_events_mutex}; + while (!m_timed_events.empty()) { + auto first = m_timed_events.begin(); + auto [tp, pi] = *first; + + if (tp <= now) { + m_timed_events.erase(first); + poll_infos.emplace_back(pi); + } + else { + break; + } + } + } + + for (auto pi : poll_infos) { + if (!pi->m_processed) { + // Its possible the event and the timeout occurred in the same epoll, make sure only one + // is ever processed, the other is discarded. + pi->m_processed = true; + + // Since this timed out, remove its corresponding event if it has one. + if (pi->m_fd != net::invalid_event_handle) { +#if defined(__linux__) + epoll_ctl(event_fd, EPOLL_CTL_DEL, pi->m_fd, nullptr); +#elif defined(__APPLE__) + apple_delete_poll_event(pi); +#endif + } + + while (not pi->m_awaiting_coroutine) { + std::atomic_thread_fence(std::memory_order::acquire); + std::cerr << "process_event_execute() has a null event\n"; + } + + m_handles_to_resume.emplace_back(pi->m_awaiting_coroutine); + pi->m_poll_status = poll_status::timeout; + } + } + + // Update the time to the next smallest time point, re-take the current now time + // since updating and resuming tasks could shift the time. + update_timeout(clock::now()); + } + + auto add_timer_token(time_point tp, poll_info& pi) -> timed_events::iterator + { + std::scoped_lock lk{m_timed_events_mutex}; + auto pos = m_timed_events.emplace(tp, &pi); + + // If this item was inserted as the smallest time point, update the timeout. + if (pos == m_timed_events.begin()) { + update_timeout(clock::now()); + } + + return pos; + } + + void remove_timer_token(timed_events::iterator pos) + { + std::scoped_lock lk{m_timed_events_mutex}; + auto is_first = (m_timed_events.begin() == pos); + + m_timed_events.erase(pos); + + // If this was the first item, update the timeout. It would be acceptable to just let it + // also fire the timeout as the event loop will ignore it since nothing will have timed + // out but it feels like the right thing to do to update it to the correct timeout value. 
+ if (is_first) { + update_timeout(clock::now()); + } + } + + void update_timeout(time_point now) + { + if (!m_timed_events.empty()) { + auto& [tp, pi] = *m_timed_events.begin(); + +#if defined(__linux__) + size_t seconds{}; + size_t nanoseconds{1}; + if (tp > now) { + const auto time_left = tp - now; + const auto s = std::chrono::duration_cast(time_left); + seconds = s.count(); + nanoseconds = std::chrono::duration_cast(time_left - s).count(); + } + + itimerspec ts{}; + ts.it_value.tv_sec = seconds; + ts.it_value.tv_nsec = nanoseconds; + + if (timerfd_settime(timer_fd, 0, &ts, nullptr) == -1) { + std::cerr << "Failed to set timerfd errorno=[" << std::string{std::strerror(errno)} << "]."; + } +#elif defined(__APPLE__) + size_t milliseconds{}; + if (tp > now) { + milliseconds = std::chrono::duration_cast(tp - now).count(); + } + + net::poll_event_t e{.filter = EVFILT_TIMER, + .fflags = NOTE_TRIGGER, + .data = int64_t(milliseconds), + .udata = const_cast(m_timer_ptr)}; + if (::kevent(event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + std::cerr << "Error: kevent (update timer).\n"; + } +#elif defined(_WIN32) + size_t seconds{}; + size_t nanoseconds{1}; + if (tp > now) { + const auto time_left = tp - now; + const auto s = std::chrono::duration_cast(time_left); + seconds = s.count(); + nanoseconds = std::chrono::duration_cast(time_left - s).count(); + } + LARGE_INTEGER signal_time{}; + signal_time.QuadPart = -std::chrono::duration_cast(tp - now).count() / 100; + + if (!SetWaitableTimer(timer_fd, &signal_time, 0, nullptr, nullptr, FALSE)) { + std::cerr << "SetWaitableTimer failed (" << GetLastError() << ")\n"; + } +#endif + } + else { +#if defined(__linux__) + // Setting these values to zero disables the timer. + itimerspec ts{}; + ts.it_value.tv_sec = 0; + ts.it_value.tv_nsec = 0; + if (timerfd_settime(timer_fd, 0, &ts, nullptr) == -1) { + std::cerr << "Failed to set timerfd errorno=[" << std::string{strerror(errno)} << "]."; + } +#elif defined(__APPLE__) + net::poll_event_t e{ + .filter = EVFILT_TIMER, .fflags = NOTE_TRIGGER, .data = 0, .udata = const_cast(m_timer_ptr)}; + if (::kevent(event_fd, &e, 1, nullptr, 0, nullptr) == -1) { + std::cerr << "Error: kevent (update timer).\n"; + } +#elif defined(_WIN32) + LARGE_INTEGER signal_time{}; + if (!SetWaitableTimer(timer_fd, &signal_time, 0, nullptr, nullptr, FALSE)) { + std::cerr << "Error: SetWaitableTimer (disable timer) failed (" << GetLastError() << ")\n"; + } +#endif + } + } + }; +} diff --git a/include/glaze/coroutine/semaphore.hpp b/include/glaze/coroutine/semaphore.hpp new file mode 100644 index 0000000000..de6cbf7ed0 --- /dev/null +++ b/include/glaze/coroutine/semaphore.hpp @@ -0,0 +1,208 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include "glaze/util/expected.hpp" +#include +#include +#include + +namespace glz +{ + struct semaphore + { + enum struct acquire_result { acquired, semaphore_stopped }; + + static std::string to_string(acquire_result ar) + { + switch (ar) { + case acquire_result::acquired: + return "acquired"; + case acquire_result::semaphore_stopped: + return "semaphore_stopped"; + } + + return "unknown"; + } + + explicit semaphore(std::ptrdiff_t least_max_value_and_starting_value) + : semaphore(least_max_value_and_starting_value, least_max_value_and_starting_value) + {} + + explicit semaphore(std::ptrdiff_t least_max_value, std::ptrdiff_t starting_value) + : m_least_max_value(least_max_value), + 
m_counter(starting_value <= least_max_value ? starting_value : least_max_value) + {} + + + ~semaphore() { + notify_waiters(); + } + + semaphore(const semaphore&) = delete; + semaphore(semaphore&&) = delete; + + auto operator=(const semaphore&) noexcept -> semaphore& = delete; + auto operator=(semaphore&&) noexcept -> semaphore& = delete; + + struct acquire_operation + { + explicit acquire_operation(semaphore& s) : m_semaphore(s) + {} + + bool await_ready() const noexcept + { + if (m_semaphore.m_notify_all_set.load(std::memory_order::relaxed)) + { + return true; + } + return m_semaphore.try_acquire(); + } + + bool await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept + { + std::unique_lock lk{m_semaphore.m_waiter_mutex}; + if (m_semaphore.m_notify_all_set.load(std::memory_order::relaxed)) + { + return false; + } + + if (m_semaphore.try_acquire()) + { + return false; + } + + if (m_semaphore.m_acquire_waiters == nullptr) + { + m_semaphore.m_acquire_waiters = this; + } + else + { + // This is LIFO, but semaphores are not meant to be fair. + + // Set our next to the current head. + m_next = m_semaphore.m_acquire_waiters; + // Set the semaphore head to this. + m_semaphore.m_acquire_waiters = this; + } + + m_awaiting_coroutine = awaiting_coroutine; + return true; + } + + acquire_result await_resume() const + { + if (m_semaphore.m_notify_all_set.load(std::memory_order::relaxed)) + { + return acquire_result::semaphore_stopped; + } + return acquire_result::acquired; + } + + private: + friend semaphore; + + semaphore& m_semaphore; + std::coroutine_handle<> m_awaiting_coroutine; + acquire_operation* m_next{nullptr}; + }; + + void release() + { + // It seems like the atomic counter could be incremented, but then resuming a waiter could have + // a race between a new acquirer grabbing the just incremented resource value from us. So its + // best to check if there are any waiters first, and transfer owernship of the resource thats + // being released directly to the waiter to avoid this problem. + + std::unique_lock lk{m_waiter_mutex}; + if (m_acquire_waiters != nullptr) + { + acquire_operation* to_resume = m_acquire_waiters; + m_acquire_waiters = m_acquire_waiters->m_next; + lk.unlock(); + + // This will transfer ownership of the resource to the resumed waiter. + to_resume->m_awaiting_coroutine.resume(); + } + else + { + // Normally would be release but within a lock use releaxed. + m_counter.fetch_add(1, std::memory_order::relaxed); + } + } + + /** + * Acquires a resource from the semaphore, if the semaphore has no resources available then + * this will wait until a resource becomes available. + */ + [[nodiscard]] acquire_operation acquire() { return acquire_operation{*this}; } + + /** + * Attemtps to acquire a resource if there is any resources available. + * @return True if the acquire operation was able to acquire a resource. + */ + bool try_acquire() + { + // Optimistically grab the resource. + auto previous = m_counter.fetch_sub(1, std::memory_order::acq_rel); + if (previous <= 0) + { + // If it wasn't available undo the acquisition. + m_counter.fetch_add(1, std::memory_order::release); + return false; + } + return true; + } + + /** + * @return The maximum number of resources the semaphore can contain. + */ + std::ptrdiff_t max_resources() const noexcept { return m_least_max_value; } + + /** + * The current number of resources available in this semaphore. 
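+       * This is a relaxed snapshot; it can be transiently negative while a failed try_acquire()
+       * rolls back its optimistic decrement.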
+ */ + std::ptrdiff_t value() const noexcept { return m_counter.load(std::memory_order::relaxed); } + + /** + * Stops the semaphore and will notify all release/acquire waiters to wake up in a failed state. + * Once this is set it cannot be un-done and all future oprations on the semaphore will fail. + */ + void notify_waiters() noexcept + { + m_notify_all_set.exchange(true, std::memory_order::release); + while (true) + { + std::unique_lock lk{m_waiter_mutex}; + if (m_acquire_waiters != nullptr) + { + acquire_operation* to_resume = m_acquire_waiters; + m_acquire_waiters = m_acquire_waiters->m_next; + lk.unlock(); + + to_resume->m_awaiting_coroutine.resume(); + } + else + { + break; + } + } + } + + private: + friend struct release_operation; + friend struct acquire_operation; + + const std::ptrdiff_t m_least_max_value; + std::atomic m_counter; + + std::mutex m_waiter_mutex{}; + acquire_operation* m_acquire_waiters{nullptr}; + + std::atomic m_notify_all_set{false}; + }; +} diff --git a/include/glaze/coroutine/shared_mutex.hpp b/include/glaze/coroutine/shared_mutex.hpp new file mode 100644 index 0000000000..3e58e41559 --- /dev/null +++ b/include/glaze/coroutine/shared_mutex.hpp @@ -0,0 +1,340 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#ifndef GLZ_THROW_OR_ABORT +#if __cpp_exceptions +#define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) +#include +#else +#define GLZ_THROW_OR_ABORT(EXC) (std::abort()) +#endif +#endif + +#include +#include +#include + +#include "glaze/coroutine/concepts.hpp" + +namespace glz +{ + template + struct shared_mutex; + + /** + * A scoped RAII lock holder for a coro::shared_mutex. It will call the appropriate unlock() or + * unlock_shared() based on how the coro::shared_mutex was originally acquired, either shared or + * exclusive modes. + */ + template + struct shared_scoped_lock + { + shared_scoped_lock(shared_mutex& sm, bool exclusive) : m_shared_mutex(&sm), m_exclusive(exclusive) + {} + + /** + * Unlocks the mutex upon this shared scoped lock destructing. + */ + ~shared_scoped_lock() { unlock(); } + + shared_scoped_lock(const shared_scoped_lock&) = delete; + shared_scoped_lock(shared_scoped_lock&& other) + : m_shared_mutex(std::exchange(other.m_shared_mutex, nullptr)), m_exclusive(other.m_exclusive) + {} + + auto operator=(const shared_scoped_lock&) -> shared_scoped_lock& = delete; + auto operator=(shared_scoped_lock&& other) noexcept -> shared_scoped_lock& + { + if (std::addressof(other) != this) { + m_shared_mutex = std::exchange(other.m_shared_mutex, nullptr); + m_exclusive = other.m_exclusive; + } + return *this; + } + + /** + * Unlocks the shared mutex prior to this lock going out of scope. + */ + auto unlock() -> void + { + if (m_shared_mutex != nullptr) { + if (m_exclusive) { + m_shared_mutex->unlock(); + } + else { + m_shared_mutex->unlock_shared(); + } + + m_shared_mutex = nullptr; + } + } + + private: + shared_mutex* m_shared_mutex{nullptr}; + bool m_exclusive{false}; + }; + + template + struct shared_mutex + { + /** + * @param e The executor for when multiple shared waiters can be woken up at the same time, + * each shared waiter will be scheduled to immediately run on this executor in + * parallel. 
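+       *
+       * Sketch of typical usage (assumes glz::thread_pool satisfies the executor concept and that
+       * the caller is inside a coroutine):
+       *
+       *   auto tp = std::make_shared<glz::thread_pool>();
+       *   glz::shared_mutex<glz::thread_pool> mtx{tp};
+       *   // reader: auto lock = co_await mtx.lock_shared();
+       *   // writer: auto lock = co_await mtx.lock();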
+ */ + explicit shared_mutex(std::shared_ptr e) : m_executor(std::move(e)) + { + if (m_executor == nullptr) { + GLZ_THROW_OR_ABORT(std::runtime_error{"shared_mutex cannot have a nullptr executor"}); + } + } + ~shared_mutex() = default; + + shared_mutex(const shared_mutex&) = delete; + shared_mutex(shared_mutex&&) = delete; + auto operator=(const shared_mutex&) -> shared_mutex& = delete; + auto operator=(shared_mutex&&) -> shared_mutex& = delete; + + struct lock_operation + { + lock_operation(shared_mutex& sm, bool exclusive) : m_shared_mutex(sm), m_exclusive(exclusive) {} + + auto await_ready() const noexcept -> bool + { + if (m_exclusive) { + return m_shared_mutex.try_lock(); + } + else { + return m_shared_mutex.try_lock_shared(); + } + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + std::unique_lock lk{m_shared_mutex.m_mutex}; + // It's possible the lock has been released between await_ready() and await_suspend(); double + // check and make sure we are not going to suspend when nobody holds the lock. + if (m_exclusive) { + if (m_shared_mutex.try_lock_locked(lk)) { + return false; + } + } + else { + if (m_shared_mutex.try_lock_shared_locked(lk)) { + return false; + } + } + + // The lock is definitely held in a manner that prevents acquisition, so suspend ourselves + // at the end of the waiter list. + + if (m_shared_mutex.m_tail_waiter == nullptr) { + m_shared_mutex.m_head_waiter = this; + m_shared_mutex.m_tail_waiter = this; + } + else { + m_shared_mutex.m_tail_waiter->m_next = this; + m_shared_mutex.m_tail_waiter = this; + } + + // If this is an exclusive lock acquire then mark it as such, so that shared locks arriving after this + // exclusive one will also suspend and this exclusive lock doesn't get starved. + if (m_exclusive) { + ++m_shared_mutex.m_exclusive_waiters; + } + + m_awaiting_coroutine = awaiting_coroutine; + return true; + } + auto await_resume() noexcept -> shared_scoped_lock + { + return shared_scoped_lock{m_shared_mutex, m_exclusive}; + } + + private: + friend struct shared_mutex; + + shared_mutex& m_shared_mutex; + bool m_exclusive{false}; + std::coroutine_handle<> m_awaiting_coroutine; + lock_operation* m_next{nullptr}; + }; + + /** + * Locks the mutex in a shared state. If there are any exclusive waiters then the shared waiters + * will also wait so the exclusive waiters are not starved. + */ + [[nodiscard]] auto lock_shared() -> lock_operation { return lock_operation{*this, false}; } + + /** + * Locks the mutex in an exclusive state. + */ + [[nodiscard]] auto lock() -> lock_operation { return lock_operation{*this, true}; } + + /** + * @return True if the lock could immediately be acquired in a shared state. + */ + auto try_lock_shared() -> bool + { + // To acquire the shared lock the state must be one of two states: + // 1) unlocked + // 2) shared locked with zero exclusive waiters + // Requiring zero exclusive waiters prevents exclusive starvation when shared locks are + // continuously being taken. + + std::unique_lock lk{m_mutex}; + return try_lock_shared_locked(lk); + } + + /** + * @return True if the lock could immediately be acquired in an exclusive state. + */ + auto try_lock() -> bool + { + // To acquire the exclusive lock the state must be unlocked. + std::unique_lock lk{m_mutex}; + return try_lock_locked(lk); + } + + /** + * Unlocks a single shared state user. *REQUIRES* that the lock was first acquired exactly once + * via `lock_shared()` or `try_lock_shared() -> True` before being called, otherwise undefined + * behavior.
+ * + * If the shared user count drops to zero and this lock has an exclusive waiter then the exclusive + * waiter acquires the lock. + */ + auto unlock_shared() -> void + { + std::unique_lock lk{m_mutex}; + --m_shared_users; + + // Only wake waiters from shared state if all shared users have completed. + if (m_shared_users == 0) { + if (m_head_waiter != nullptr) { + wake_waiters(lk); + } + else { + m_state = state::unlocked; + } + } + } + + /** + * Unlocks the mutex from its exclusive state. If there is a following exclusive waiter then + * that exclusive waiter acquires the lock. If there are one or more shared waiters then all the + * shared waiters acquire the lock in a shared state in parallel and are resumed on the original + * thread pool this shared mutex was created with. + */ + auto unlock() -> void + { + std::unique_lock lk{m_mutex}; + if (m_head_waiter != nullptr) { + wake_waiters(lk); + } + else { + m_state = state::unlocked; + } + } + + private: + friend struct lock_operation; + + enum struct state { unlocked, locked_shared, locked_exclusive }; + + /// This executor is for resuming multiple shared waiters. + std::shared_ptr m_executor{nullptr}; + + std::mutex m_mutex; + + state m_state{state::unlocked}; + + /// The current number of shared users that have acquired the lock. + uint64_t m_shared_users{0}; + /// The current number of exclusive waiters waiting to acquire the lock. This is used to block + /// new incoming shared lock attempts so the exclusive waiter is not starved. + uint64_t m_exclusive_waiters{0}; + + lock_operation* m_head_waiter{nullptr}; + lock_operation* m_tail_waiter{nullptr}; + + auto try_lock_shared_locked(std::unique_lock& lk) -> bool + { + if (m_state == state::unlocked) { + // If the shared mutex is unlocked put it into shared mode and add ourselves as a user of the lock. + m_state = state::locked_shared; + ++m_shared_users; + lk.unlock(); + return true; + } + else if (m_state == state::locked_shared && m_exclusive_waiters == 0) { + // If the shared mutex is in a shared locked state and there are no exclusive waiters + // then add ourselves as a user of the lock. + ++m_shared_users; + lk.unlock(); + return true; + } + + // If the lock is in shared mode but there are exclusive waiters then we will also wait so + // the writers are not starved. + + // If the lock is in exclusive mode already then we need to wait. + + return false; + } + + auto try_lock_locked(std::unique_lock& lk) -> bool + { + if (m_state == state::unlocked) { + m_state = state::locked_exclusive; + lk.unlock(); + return true; + } + return false; + } + + auto wake_waiters(std::unique_lock& lk) -> void + { + // First determine what the next lock state will be based on the first waiter. + if (m_head_waiter->m_exclusive) { + // If it's exclusive then only this waiter can be woken up. + m_state = state::locked_exclusive; + lock_operation* to_resume = m_head_waiter; + m_head_waiter = m_head_waiter->m_next; + --m_exclusive_waiters; + if (m_head_waiter == nullptr) { + m_tail_waiter = nullptr; + } + + // Since this is an exclusive lock waiter we can resume it directly. + lk.unlock(); + to_resume->m_awaiting_coroutine.resume(); + } + else { + // If it's shared then we will scan forward and wake all shared waiters onto the given + // thread pool so they can run in parallel.
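+ // m_shared_users is incremented for each waiter before it is handed to the executor,
+ // so unlock_shared() accounting stays correct even if a resumed waiter finishes immediately.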
+ m_state = state::locked_shared; + do { + lock_operation* to_resume = m_head_waiter; + m_head_waiter = m_head_waiter->m_next; + if (m_head_waiter == nullptr) { + m_tail_waiter = nullptr; + } + ++m_shared_users; + + m_executor->resume(to_resume->m_awaiting_coroutine); + } while (m_head_waiter != nullptr && !m_head_waiter->m_exclusive); + + // Cannot unlock until the entire set of shared waiters has been traversed. This seems preferable + // to allocating space for all the shared waiters, unlocking, and then + // resuming them in a batch. + lk.unlock(); + } + } + }; + +} // namespace glz diff --git a/include/glaze/coroutine/sync_wait.hpp b/include/glaze/coroutine/sync_wait.hpp new file mode 100644 index 0000000000..eba3813a55 --- /dev/null +++ b/include/glaze/coroutine/sync_wait.hpp @@ -0,0 +1,321 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#ifndef GLZ_THROW_OR_ABORT +#if __cpp_exceptions +#define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) +#include +#else +#define GLZ_THROW_OR_ABORT(EXC) (std::abort()) +#endif +#endif + +#include +#include +#include +#include + +#include "glaze/coroutine/awaitable.hpp" + +namespace glz +{ + namespace detail + { + struct sync_wait_event + { + sync_wait_event(bool initially_set = false) : m_set(initially_set) {} + sync_wait_event(const sync_wait_event&) = delete; + sync_wait_event(sync_wait_event&&) = delete; + auto operator=(const sync_wait_event&) -> sync_wait_event& = delete; + auto operator=(sync_wait_event&&) -> sync_wait_event& = delete; + ~sync_wait_event() = default; + + void set() noexcept { + m_set.exchange(true, std::memory_order::release); + m_cv.notify_all(); + } + + void reset() noexcept { + m_set.exchange(false, std::memory_order::release); + } + + void wait() noexcept { + std::unique_lock lk{m_mutex}; + m_cv.wait(lk, [this] { return m_set.load(std::memory_order::acquire); }); + } + + private: + std::mutex m_mutex; + std::condition_variable m_cv; + std::atomic m_set{false}; + }; + + struct sync_wait_task_promise_base + { + sync_wait_task_promise_base() noexcept = default; + + auto initial_suspend() noexcept -> std::suspend_always { return {}; } + + auto unhandled_exception() -> void { m_exception = std::current_exception(); } + + protected: + sync_wait_event* m_event{nullptr}; + std::exception_ptr m_exception; + + ~sync_wait_task_promise_base() = default; + }; + + template + struct sync_wait_task_promise : public sync_wait_task_promise_base + { + using coroutine_type = std::coroutine_handle>; + + static constexpr bool return_type_is_reference = std::is_reference_v; + using stored_type = std::conditional_t*, + std::remove_const_t>; + using variant_type = std::variant; + + sync_wait_task_promise() noexcept = default; + sync_wait_task_promise(const sync_wait_task_promise&) = delete; + sync_wait_task_promise(sync_wait_task_promise&&) = delete; + auto operator=(const sync_wait_task_promise&) -> sync_wait_task_promise& = delete; + auto operator=(sync_wait_task_promise&&) -> sync_wait_task_promise& = delete; + ~sync_wait_task_promise() = default; + + auto start(sync_wait_event& event) + { + m_event = &event; + coroutine_type::from_promise(*this).resume(); + } + + auto get_return_object() noexcept { return coroutine_type::from_promise(*this); } + + template + requires(return_type_is_reference and std::is_constructible_v) or + (not return_type_is_reference and std::is_constructible_v) + auto return_value(value_type&& value) -> void + { +
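+ // For reference return types store a pointer to the referenced object; otherwise
+ // forward the value into the variant storage.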
if constexpr (return_type_is_reference) { + return_type ref = static_cast(value); + m_storage.template emplace(std::addressof(ref)); + } + else { + m_storage.template emplace(std::forward(value)); + } + } + + auto return_value(stored_type value) -> void + requires(not return_type_is_reference) + { + if constexpr (std::is_move_constructible_v) { + m_storage.template emplace(std::move(value)); + } + else { + m_storage.template emplace(value); + } + } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept { return false; } + auto await_suspend(coroutine_type coroutine) const noexcept { coroutine.promise().m_event->set(); } + auto await_resume() noexcept {}; + }; + + return completion_notifier{}; + } + + auto result() & -> decltype(auto) + { + if (std::holds_alternative(m_storage)) { + if constexpr (return_type_is_reference) { + return static_cast(*std::get(m_storage)); + } + else { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) { + std::rethrow_exception(std::get(m_storage)); + } + else { + GLZ_THROW_OR_ABORT(std::runtime_error{"The return value was never set, did you execute the coroutine?"}); + } + } + + auto result() const& -> decltype(auto) + { + if (std::holds_alternative(m_storage)) { + if constexpr (return_type_is_reference) { + return static_cast>(*std::get(m_storage)); + } + else { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) { + std::rethrow_exception(std::get(m_storage)); + } + else { + GLZ_THROW_OR_ABORT(std::runtime_error{"The return value was never set, did you execute the coroutine?"}); + } + } + + auto result() && -> decltype(auto) + { + if (std::holds_alternative(m_storage)) { + if constexpr (return_type_is_reference) { + return static_cast(*std::get(m_storage)); + } + else if constexpr (std::is_assignable_v) { + return static_cast(std::get(m_storage)); + } + else { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) { + std::rethrow_exception(std::get(m_storage)); + } + else { + GLZ_THROW_OR_ABORT(std::runtime_error{"The return value was never set, did you execute the coroutine?"}); + } + } + + private: + variant_type m_storage{}; + }; + + template <> + struct sync_wait_task_promise : public sync_wait_task_promise_base + { + using coroutine_type = std::coroutine_handle>; + + sync_wait_task_promise() noexcept = default; + ~sync_wait_task_promise() = default; + + auto start(sync_wait_event& event) + { + m_event = &event; + coroutine_type::from_promise(*this).resume(); + } + + auto get_return_object() noexcept { return coroutine_type::from_promise(*this); } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept { return false; } + auto await_suspend(coroutine_type coroutine) const noexcept { coroutine.promise().m_event->set(); } + auto await_resume() noexcept {}; + }; + + return completion_notifier{}; + } + + auto return_void() noexcept -> void {} + + auto result() -> void + { + if (m_exception) { + std::rethrow_exception(m_exception); + } + } + }; + + template + struct sync_wait_task + { + using promise_type = sync_wait_task_promise; + using coroutine_type = std::coroutine_handle; + + sync_wait_task(coroutine_type coroutine) noexcept : m_coroutine(coroutine) {} + + sync_wait_task(const sync_wait_task&) = delete; + sync_wait_task(sync_wait_task&& other) noexcept + : m_coroutine(std::exchange(other.m_coroutine, coroutine_type{})) + 
{} + auto operator=(const sync_wait_task&) -> sync_wait_task& = delete; + auto operator=(sync_wait_task&& other) -> sync_wait_task& + { + if (std::addressof(other) != this) { + m_coroutine = std::exchange(other.m_coroutine, coroutine_type{}); + } + + return *this; + } + + ~sync_wait_task() + { + if (m_coroutine) { + m_coroutine.destroy(); + } + } + + auto promise() & -> promise_type& { return m_coroutine.promise(); } + auto promise() const& -> const promise_type& { return m_coroutine.promise(); } + auto promise() && -> promise_type&& { return std::move(m_coroutine.promise()); } + + private: + coroutine_type m_coroutine; + }; + + template ::return_type> + static auto make_sync_wait_task(awaitable_type&& a) -> sync_wait_task; + + template + static auto make_sync_wait_task(awaitable_type&& a) -> sync_wait_task + { + if constexpr (std::is_void_v) { + co_await std::forward(a); + co_return; + } + else { + co_return co_await std::forward(a); + } + } + } + + template ::return_type> + auto sync_wait(awaitable_type&& a) -> decltype(auto) + { + detail::sync_wait_event e{}; + auto task = detail::make_sync_wait_task(std::forward(a)); + task.promise().start(e); + e.wait(); + + if constexpr (std::is_void_v) { + task.promise().result(); + return; + } + else if constexpr (std::is_reference_v) { + return task.promise().result(); + } + else if constexpr (std::is_move_assignable_v) { + // issue-242 + // For non-trivial types (or possibly types that don't fit in a register) + // the compiler will end up calling the ~return_type() when the promise + // is destructed at the end of sync_wait(). This causes the return_type + // object to also be destructed causingn the final return/move from + // sync_wait() to be a 'use after free' bug. To work around this the result + // must be moved off the promise object before the promise is destructed. + // Other solutions could be heap allocating the return_type but that has + // other downsides, for now it is determined that a double move is an + // acceptable solution to work around this bug. + auto result = std::move(task).promise().result(); + return result; + } + else { + return task.promise().result(); + } + } +} diff --git a/include/glaze/coroutine/task.hpp b/include/glaze/coroutine/task.hpp new file mode 100644 index 0000000000..726fa4719c --- /dev/null +++ b/include/glaze/coroutine/task.hpp @@ -0,0 +1,324 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include +#include +#include +#include + +#ifndef GLZ_THROW_OR_ABORT +#if __cpp_exceptions +#define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) +#else +#define GLZ_THROW_OR_ABORT(EXC) (std::abort()) +#endif +#endif + +namespace glz +{ + template + struct task; + + namespace detail + { + struct promise_base + { + friend struct final_awaitable; + struct final_awaitable + { + bool await_ready() const noexcept { return false; } + + template + auto await_suspend(std::coroutine_handle coroutine) noexcept -> std::coroutine_handle<> + { + // If there is a continuation call it, otherwise this is the end of the line. 
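+ // Returning a coroutine handle from await_suspend() performs a symmetric transfer: the
+ // continuation (or std::noop_coroutine()) is resumed without growing the call stack.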
+ auto& promise = coroutine.promise(); + if (promise.m_continuation) { + return promise.m_continuation; + } + else { + return std::noop_coroutine(); + } + } + + void await_resume() noexcept + { + // no-op + } + }; + + promise_base() noexcept = default; + ~promise_base() = default; + + auto initial_suspend() noexcept { return std::suspend_always{}; } + + auto final_suspend() noexcept { return final_awaitable{}; } + + auto continuation(std::coroutine_handle<> continuation) noexcept -> void { m_continuation = continuation; } + + protected: + std::coroutine_handle<> m_continuation{}; + }; + + template + struct promise final : public promise_base + { + using task_type = task; + using coroutine_handle = std::coroutine_handle>; + static constexpr bool return_type_is_reference = std::is_reference_v; + using stored_type = std::conditional_t*, + std::remove_const_t>; + using variant_type = std::variant; + + promise() noexcept {} + promise(const promise&) = delete; + promise(promise&& other) = delete; + promise& operator=(const promise&) = delete; + promise& operator=(promise&& other) = delete; + ~promise() = default; + + auto get_return_object() noexcept -> task_type; + + template + requires(return_type_is_reference and std::is_constructible_v) or + (not return_type_is_reference and std::is_constructible_v) + void return_value(T&& value) + { + if constexpr (return_type_is_reference) { + Return ref = static_cast(value); + m_storage.template emplace(std::addressof(ref)); + } + else { + m_storage.template emplace(std::forward(value)); + } + } + + void return_value(stored_type value) + requires(not return_type_is_reference) + { + if constexpr (std::is_move_constructible_v) { + m_storage.template emplace(std::move(value)); + } + else { + m_storage.template emplace(value); + } + } + + void unhandled_exception() noexcept { new (&m_storage) variant_type(std::current_exception()); } + + decltype(auto) result() & + { + if (std::holds_alternative(m_storage)) { + if constexpr (return_type_is_reference) { + return static_cast(*std::get(m_storage)); + } + else { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) { + std::rethrow_exception(std::get(m_storage)); + } + else { + GLZ_THROW_OR_ABORT(std::runtime_error{"The return value was never set, did you execute the coroutine?"}); + } + } + + decltype(auto) result() const& + { + if (std::holds_alternative(m_storage)) { + if constexpr (return_type_is_reference) { + return static_cast>(*std::get(m_storage)); + } + else { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) { + std::rethrow_exception(std::get(m_storage)); + } + else { + GLZ_THROW_OR_ABORT(std::runtime_error{"The return value was never set, did you execute the coroutine?"}); + } + } + + decltype(auto) result() && + { + if (std::holds_alternative(m_storage)) { + if constexpr (return_type_is_reference) { + return static_cast(*std::get(m_storage)); + } + else if constexpr (std::is_move_constructible_v) { + return static_cast(std::get(m_storage)); + } + else { + return static_cast(std::get(m_storage)); + } + } + else if (std::holds_alternative(m_storage)) { + std::rethrow_exception(std::get(m_storage)); + } + else { + GLZ_THROW_OR_ABORT(std::runtime_error{"The return value was never set, did you execute the coroutine?"}); + } + } + + private: + variant_type m_storage{}; + }; + + template <> + struct promise : public promise_base + { + using task_type = task; + using coroutine_handle = std::coroutine_handle>; + + 
promise() noexcept = default; + promise(const promise&) = delete; + promise(promise&& other) = delete; + promise& operator=(const promise&) = delete; + promise& operator=(promise&& other) = delete; + ~promise() = default; + + auto get_return_object() noexcept -> task_type; + + auto return_void() noexcept -> void {} + + auto unhandled_exception() noexcept -> void { m_exception_ptr = std::current_exception(); } + + auto result() -> void + { + if (m_exception_ptr) { + std::rethrow_exception(m_exception_ptr); + } + } + + private: + std::exception_ptr m_exception_ptr{}; + }; + + } // namespace detail + + template + struct [[nodiscard]] task + { + using task_type = task; + using promise_type = detail::promise; + using coroutine_handle = std::coroutine_handle; + + struct awaitable_base + { + awaitable_base(coroutine_handle coroutine) noexcept : m_coroutine(coroutine) {} + + bool await_ready() const noexcept { return !m_coroutine || m_coroutine.done(); } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> std::coroutine_handle<> + { + m_coroutine.promise().continuation(awaiting_coroutine); + return m_coroutine; + } + + std::coroutine_handle m_coroutine{}; + }; + + task() noexcept : m_coroutine(nullptr) {} + + explicit task(coroutine_handle handle) : m_coroutine(handle) {} + task(const task&) = delete; + task(task&& other) noexcept : m_coroutine(std::exchange(other.m_coroutine, nullptr)) {} + + ~task() + { + if (m_coroutine) { + m_coroutine.destroy(); + } + } + + auto operator=(const task&) -> task& = delete; + + auto operator=(task&& other) noexcept -> task& + { + if (std::addressof(other) != this) { + if (m_coroutine) { + m_coroutine.destroy(); + } + + m_coroutine = std::exchange(other.m_coroutine, nullptr); + } + + return *this; + } + + /** + * @return True if the task is in its final suspend or if the task has been destroyed. 
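+ *
+ * Note that tasks are lazy: a freshly created task has not started executing, because
+ * initial_suspend() returns std::suspend_always; it only runs once it is co_await'ed or resumed.
+ * A minimal sketch (illustrative, using glz::sync_wait to drive the task):
+ * @code
+ * glz::task<int> answer() { co_return 42; }
+ * int v = glz::sync_wait(answer()); // runs the task to completion and returns 42
+ * @endcode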
+ */ + bool is_ready() const noexcept { return m_coroutine == nullptr || m_coroutine.done(); } + + bool resume() + { + if (!m_coroutine.done()) { + m_coroutine.resume(); + } + return !m_coroutine.done(); + } + + bool destroy() + { + if (m_coroutine) { + m_coroutine.destroy(); + m_coroutine = nullptr; + return true; + } + + return false; + } + + auto operator co_await() const& noexcept + { + struct awaitable : public awaitable_base + { + auto await_resume() -> decltype(auto) { return this->m_coroutine.promise().result(); } + }; + + return awaitable{m_coroutine}; + } + + auto operator co_await() const&& noexcept + { + struct awaitable : public awaitable_base + { + auto await_resume() -> decltype(auto) { return std::move(this->m_coroutine.promise()).result(); } + }; + + return awaitable{m_coroutine}; + } + + auto promise() & -> promise_type& { return m_coroutine.promise(); } + auto promise() const& -> const promise_type& { return m_coroutine.promise(); } + auto promise() && -> promise_type&& { return std::move(m_coroutine.promise()); } + + auto handle() -> coroutine_handle { return m_coroutine; } + + private: + coroutine_handle m_coroutine{}; + }; + + namespace detail + { + template + task promise::get_return_object() noexcept + { + return task{coroutine_handle::from_promise(*this)}; + } + + inline task<> promise::get_return_object() noexcept + { + return task<>{coroutine_handle::from_promise(*this)}; + } + } +} diff --git a/include/glaze/coroutine/task_container.hpp b/include/glaze/coroutine/task_container.hpp new file mode 100644 index 0000000000..64459bd60d --- /dev/null +++ b/include/glaze/coroutine/task_container.hpp @@ -0,0 +1,283 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#ifndef GLZ_THROW_OR_ABORT +#if __cpp_exceptions +#define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) +#include +#else +#define GLZ_THROW_OR_ABORT(EXC) (std::abort()) +#endif +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "glaze/coroutine/concepts.hpp" +#include "glaze/coroutine/task.hpp" + +namespace glz +{ + struct scheduler; + + template + struct task_container + { + struct options + { + /// The number of task spots to reserve space for upon creating the container. + size_t reserve_size{8}; + /// The growth factor for task space in the container when capacity is full. + double growth_factor{2}; + }; + + /** + * @param e Tasks started in the container are scheduled onto this executor. For tasks created + * from a coro::scheduler, this would usually be that coro::scheduler instance. + * @param opts Task container options. + */ + task_container(std::shared_ptr e, + const options opts = options{.reserve_size = 8, .growth_factor = 2}) + : m_growth_factor(opts.growth_factor), m_executor(std::move(e)), m_executor_ptr(m_executor.get()) + { + if (m_executor == nullptr) { + GLZ_THROW_OR_ABORT(std::runtime_error{"task_container cannot have a nullptr executor"}); + } + + init(opts.reserve_size); + } + task_container(const task_container&) = delete; + task_container(task_container&&) = delete; + auto operator=(const task_container&) -> task_container& = delete; + auto operator=(task_container&&) -> task_container& = delete; + ~task_container() + { + // This will hang the current thread.. but if tasks are not complete thats also pretty bad. + while (!empty()) { + garbage_collect(); + } + } + + enum struct garbage_collect_t { + /// Execute garbage collection. 
+ yes, + /// Do not execute garbage collection. + no + }; + + /** + * Stores a user task and starts its execution on the container's thread pool. + * @param user_task The scheduled user's task to store in this task container and start its execution. + * @param cleanup Should the task container run garbage collect at the beginning of this store + * call? Calling at regular intervals will reduce memory usage of completed + * tasks and allow for the task container to re-use allocated space. + */ + auto start(glz::task&& user_task, garbage_collect_t cleanup = garbage_collect_t::yes) -> void + { + m_size.fetch_add(1, std::memory_order::relaxed); + + std::unique_lock lk{m_mutex}; + + if (cleanup == garbage_collect_t::yes) { + gc_internal(); + } + + // Only grow if completely full and attempting to add more. + if (m_free_task_indices.empty()) { + grow(); + } + + // Reserve a free task index + size_t index = m_free_task_indices.front(); + m_free_task_indices.pop(); + + // We've reserved the slot, we can release the lock. + lk.unlock(); + + // Store the task inside a cleanup task for self deletion. + m_tasks[index] = make_cleanup_task(std::move(user_task), index); + + // Start executing from the cleanup task to schedule the user's task onto the thread pool. + m_tasks[index].resume(); + } + + /** + * Garbage collects any tasks that are marked as deleted. This frees up space to be re-used by + * the task container for newly stored tasks. + * @return The number of tasks that were deleted. + */ + auto garbage_collect() -> size_t + { + std::scoped_lock lk{m_mutex}; + return gc_internal(); + } + + /** + * @return The number of active tasks in the container. + */ + auto size() const -> size_t { return m_size.load(std::memory_order::relaxed); } + + /** + * @return True if there are no active tasks in the container. + */ + auto empty() const -> bool { return size() == 0; } + + /** + * @return The capacity of this task manager before it will need to grow in size. + */ + auto capacity() const -> size_t + { + std::atomic_thread_fence(std::memory_order::acquire); + return m_tasks.size(); + } + + /** + * Will continue to garbage collect and yield until all tasks are complete. This method can be + * co_await'ed to make it easier to wait for the task container to have all its tasks complete. + * + * This does not shut down the task container, but can be used when shutting down, or if your + * logic requires all the tasks contained within to complete, it is similar to coro::latch. + */ + auto garbage_collect_and_yield_until_empty() -> glz::task + { + while (!empty()) { + garbage_collect(); + co_await m_executor_ptr->yield(); + } + } + + private: + /** + * Grows each task container by the growth factor. + * @return The position of the free index after growing. + */ + auto grow() -> void + { + // Save an index at the current last item. + size_t new_size = m_tasks.size() * m_growth_factor; + for (size_t i = m_tasks.size(); i < new_size; ++i) { + m_free_task_indices.emplace(i); + } + m_tasks.resize(new_size); + } + + /** + * Internal GC call, expects the public function to lock. + */ + auto gc_internal() -> size_t + { + size_t deleted{0}; + auto pos = std::begin(m_tasks_to_delete); + while (pos != std::end(m_tasks_to_delete)) { + // Skip tasks that are still running or have yet to start. + if (!m_tasks[*pos].is_ready()) { + pos++; + continue; + } + // Destroy the cleanup task and the user task. + m_tasks[*pos].destroy(); + // Put the deleted position at the end of the free indexes list. 
+ m_free_task_indices.emplace(*pos); + // Remove index from tasks to delete + m_tasks_to_delete.erase(pos++); + // Indicate a task was deleted. + ++deleted; + } + m_size.fetch_sub(deleted, std::memory_order::relaxed); + return deleted; + } + + /** + * Encapsulates the user's task in a cleanup task which marks itself for deletion upon + * completion. Simply co_await the user's task until it has completed and then mark the given + * position within the task manager as being deletable. The scheduler's next iteration + * in its event loop will then free that position up to be re-used. + * + * This function will also unconditionally catch all unhandled exceptions from the user's + * task to prevent the scheduler from throwing exceptions. + * @param user_task The user's task. + * @param index The index where the task data will be stored in the task manager. + * @return The user's task wrapped in a self cleanup task. + */ + auto make_cleanup_task(task user_task, size_t index) -> glz::task + { + // Immediately move the task onto the executor. + co_await m_executor_ptr->schedule(); + +#if __cpp_exceptions + try { + // Await the user's task to complete. + co_await user_task; + } + catch (const std::exception& e) { + // TODO: what would be a good way to report this to the user...? Catching here is required + // since the co_await will unwrap the unhandled exception on the task. + // The user's task should ideally be wrapped in a catch all and handle it themselves, but + // that cannot be guaranteed. + std::cerr << "glz::task_container user_task had an unhandled exception e.what()= " << e.what() << "\n"; + } + catch (...) { + // don't crash if they throw something that isn't derived from std::exception + std::cerr << "glz::task_container user_task had an unhandled exception not derived from std::exception.\n"; + } +#else + co_await user_task; +#endif + + { + // This scope is required around the lock, otherwise if this task schedules a new task on destruction it + // can cause a deadlock; notably tls::client schedules a task to clean up tls resources. + std::scoped_lock lk{m_mutex}; + m_tasks_to_delete.emplace_back(index); + } + + co_return; + } + + /// Mutex for safely mutating the task containers across threads, expected usage is within + /// thread pools for indeterminate lifetime requests. + std::mutex m_mutex{}; + /// The number of alive tasks. + std::atomic m_size{}; + /// Maintains the lifetime of the tasks until they are completed. + std::vector> m_tasks{}; + /// The full set of free indices into `m_tasks`. + std::queue m_free_task_indices{}; + /// The set of tasks that have completed and need to be deleted. + std::list m_tasks_to_delete{}; + /// The amount to grow the containers by when all spaces are taken. + double m_growth_factor{}; + /// The executor to schedule tasks that have just started. + std::shared_ptr m_executor{nullptr}; + /// This is used internally since scheduler cannot pass itself in as a shared_ptr. + executor_type* m_executor_ptr{nullptr}; + + /** + * Special constructor for internal types to create their embedded task containers.
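+ * (The scheduler passes its executor by reference here because it cannot provide a
+ * shared_ptr to itself.)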
+ */ + + friend scheduler; + task_container(executor_type& e, const options opts = options{.reserve_size = 8, .growth_factor = 2}) + : m_growth_factor(opts.growth_factor), m_executor_ptr(&e) + { + init(opts.reserve_size); + } + + auto init(size_t reserve_size) -> void + { + m_tasks.resize(reserve_size); + for (size_t i = 0; i < reserve_size; ++i) { + m_free_task_indices.emplace(i); + } + } + }; +} diff --git a/include/glaze/coroutine/thread_pool.hpp b/include/glaze/coroutine/thread_pool.hpp new file mode 100644 index 0000000000..ba183084b2 --- /dev/null +++ b/include/glaze/coroutine/thread_pool.hpp @@ -0,0 +1,368 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#ifndef GLZ_THROW_OR_ABORT +#if __cpp_exceptions +#define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) +#include +#else +#define GLZ_THROW_OR_ABORT(EXC) (std::abort()) +#endif +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "glaze/coroutine/concepts.hpp" +#include "glaze/coroutine/event.hpp" +#include "glaze/coroutine/task.hpp" + +namespace glz +{ + /** + * Creates a thread pool that executes arbitrary coroutine tasks in a FIFO scheduler policy. + * The thread pool by default will create an execution thread per available core on the system. + * + * When shutting down, either by the thread pool destructing or by manually calling shutdown() + * the thread pool will stop accepting new tasks but will complete all tasks that were scheduled + * prior to the shutdown request. + */ + struct thread_pool + { + /** + * An operation is an awaitable type with a coroutine to resume the task scheduled on one of + * the executor threads. + */ + class operation + { + friend struct thread_pool; + /** + * Only thread_pools can create operations when a task is being scheduled. + * @param tp The thread pool that created this operation. + */ + explicit operation(thread_pool& tp) noexcept : m_thread_pool(tp) {} + + public: + /** + * Operations always pause so the executing thread can be switched. + */ + auto await_ready() noexcept -> bool { return false; } + + /** + * Suspending always returns to the caller (using void return of await_suspend()) and + * stores the coroutine internally for the executing thread to resume from. + */ + void await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept + { + m_awaiting_coroutine = awaiting_coroutine; + m_thread_pool.schedule_impl(m_awaiting_coroutine); + + // void return on await_suspend suspends the _this_ coroutine, which is now scheduled on the + // thread pool and returns control to the caller. They could be sync_wait'ing or go do + // something else while this coroutine gets picked up by the thread pool. + } + + /** + * no-op as this is the function called first by the thread pool's executing thread. + */ + auto await_resume() noexcept -> void {} + + private: + /// The thread pool that this operation will execute on. + thread_pool& m_thread_pool; + /// The coroutine awaiting execution. + std::coroutine_handle<> m_awaiting_coroutine{nullptr}; + }; + + struct options + { + /// The number of executor threads for this thread pool. Uses the hardware concurrency + /// value by default. + uint32_t thread_count = std::thread::hardware_concurrency(); + /// Functor to call on each executor thread upon starting execution. The parameter is the + /// thread's ID assigned to it by the thread pool. 
+ std::function on_thread_start_functor = nullptr; + /// Functor to call on each executor thread upon stopping execution. The parameter is the + /// thread's ID assigned to it by the thread pool. + std::function on_thread_stop_functor = nullptr; + }; + + /** + * @param opts Thread pool configuration options. + */ + explicit thread_pool(options opts = options{.thread_count = std::thread::hardware_concurrency(), + .on_thread_start_functor = nullptr, + .on_thread_stop_functor = nullptr}) + : m_opts(std::move(opts)) + { + m_threads.reserve(m_opts.thread_count); + + for (uint32_t i = 0; i < m_opts.thread_count; ++i) { + m_threads.emplace_back([this, i]() { executor(i); }); + } + } + + thread_pool(const thread_pool&) = delete; + thread_pool(thread_pool&&) = delete; + auto operator=(const thread_pool&) -> thread_pool& = delete; + auto operator=(thread_pool&&) -> thread_pool& = delete; + + virtual ~thread_pool() { shutdown(); } + + /** + * @return The number of executor threads for processing tasks. + */ + auto thread_count() const noexcept -> size_t { return m_threads.size(); } + + /** + * Schedules the currently executing coroutine to be run on this thread pool. This must be + * called from within the coroutine's function body to schedule the coroutine on the thread pool. + * @throw std::runtime_error If the thread pool has been `shutdown()`, scheduling new tasks is not permitted. + * @return The operation to switch from the calling scheduling thread to the executor thread + * pool thread. + */ + [[nodiscard]] auto schedule() -> operation + { + if (!m_shutdown_requested.load(std::memory_order::acquire)) { + m_size.fetch_add(1, std::memory_order::release); + return operation{*this}; + } + + GLZ_THROW_OR_ABORT(std::runtime_error("glz::thread_pool is shutting down, unable to schedule new tasks.")); + } + + /** + * @throw std::runtime_error If the thread pool has been `shutdown()`, scheduling new tasks is not permitted. + * @param f The function to execute on the thread pool. + * @param args The arguments to call the functor with. + * @return A task that wraps the given functor to be executed on the thread pool. + */ + template + [[nodiscard]] auto schedule(functor&& f, arguments... args) -> task(args)...))> + { + co_await schedule(); + + if constexpr (std::is_same_v(args)...))>) { + f(std::forward(args)...); + co_return; + } + else { + co_return f(std::forward(args)...); + } + } + + /** + * Schedules any coroutine handle that is ready to be resumed. + * @param handle The coroutine handle to schedule. + * @return True if the coroutine is resumed, false if it is a nullptr. + */ + bool resume(std::coroutine_handle<> handle) noexcept + { + if (handle == nullptr) { + return false; + } + + if (m_shutdown_requested.load(std::memory_order::acquire)) { + return false; + } + + m_size.fetch_add(1, std::memory_order::release); + schedule_impl(handle); + return true; + } + + /** + * Schedules the set of coroutine handles that are ready to be resumed. + * @param handles The coroutine handles to schedule. + * @return The number of tasks resumed; any null handles are discarded.
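+ * Handles are enqueued under a single lock acquisition; the condition variable is then
+ * notified once per resumed handle, or broadcast when the batch is at least as large as
+ * the number of executor threads.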
+ */ + template > range_type> + uint64_t resume(const range_type& handles) noexcept + { + m_size.fetch_add(std::size(handles), std::memory_order::release); + + size_t null_handles{0}; + + { + std::scoped_lock lk{m_wait_mutex}; + for (const auto& handle : handles) { + if (handle) [[likely]] { + m_queue.emplace_back(handle); + } + else { + ++null_handles; + } + } + } + + if (null_handles > 0) { + m_size.fetch_sub(null_handles, std::memory_order::release); + } + + uint64_t total = std::size(handles) - null_handles; + if (total >= m_threads.size()) { + m_wait_cv.notify_all(); + } + else { + for (uint64_t i = 0; i < total; ++i) { + m_wait_cv.notify_one(); + } + } + + return total; + } + + /** + * Immediately yields the current task and places it at the end of the queue of tasks waiting + * to be processed. This will immediately be picked up again once it naturally goes through the + * FIFO task queue. This function is useful for yielding within long-running tasks to let other tasks + * get processing time. + */ + [[nodiscard]] operation yield() { return schedule(); } + + /** + * Shuts down the thread pool. This will finish any tasks scheduled prior to calling this + * function but will prevent the thread pool from scheduling any new tasks. This call is + * blocking and will wait until all inflight tasks are completed before returning. + */ + void shutdown() noexcept + { + // Only allow shutdown to occur once. + if (m_shutdown_requested.exchange(true, std::memory_order::acq_rel) == false) { + { + // There is a race condition if we are not holding the lock: the executors might miss + // this notification. (A std::jthread stop token would work properly without it.) + std::unique_lock lk{m_wait_mutex}; + m_wait_cv.notify_all(); + } + + for (auto& thread : m_threads) { + if (thread.joinable()) { + thread.join(); + } + } + } + } + + /** + * @return The number of tasks waiting in the task queue + the executing tasks. + */ + size_t size() const noexcept { return m_size.load(std::memory_order::acquire); } + + /** + * @return True if the task queue is empty and zero tasks are currently executing. + */ + bool empty() const noexcept { return size() == 0; } + + /** + * @return The number of tasks waiting in the task queue to be executed. + */ + size_t queue_size() const noexcept + { + std::atomic_thread_fence(std::memory_order::acquire); + return m_queue.size(); + } + + /** + * @return True if the task queue is currently empty. + */ + bool queue_empty() const noexcept { return queue_size() == 0; } + + private: + /// The configuration options. + options m_opts; + /// The background executor threads. + std::vector m_threads; + /// Mutex for executor threads to sleep on the condition variable. + std::mutex m_wait_mutex; + /// Condition variable for each executor thread to wait on when no tasks are available. + std::condition_variable_any m_wait_cv; + /// FIFO queue of tasks waiting to be executed. + std::deque> m_queue; + /** + * Each background thread runs from this function. + * @param idx The executor's idx for internal data structure accesses. + */ + void executor(size_t idx) + { + if (m_opts.on_thread_start_functor) { + m_opts.on_thread_start_functor(idx); + } + + // Process until shutdown is requested.
+ while (!m_shutdown_requested.load(std::memory_order::acquire)) { + std::unique_lock lk{m_wait_mutex}; + m_wait_cv.wait(lk, + [&]() { return !m_queue.empty() || m_shutdown_requested.load(std::memory_order::acquire); }); + + if (m_queue.empty()) { + continue; + } + + auto handle = m_queue.front(); + m_queue.pop_front(); + lk.unlock(); + + // Release the lock while executing the coroutine. + handle.resume(); + m_size.fetch_sub(1, std::memory_order::release); + } + + // Process until there are no ready tasks left. + while (m_size.load(std::memory_order::acquire) > 0) { + std::unique_lock lk{m_wait_mutex}; + // m_size will only drop to zero once all executing coroutines are finished + // but the queue could be empty for threads that finished early. + if (m_queue.empty()) { + break; + } + + auto handle = m_queue.front(); + m_queue.pop_front(); + lk.unlock(); + + // Release the lock while executing the coroutine. + handle.resume(); + m_size.fetch_sub(1, std::memory_order::release); + } + + if (m_opts.on_thread_stop_functor) { + m_opts.on_thread_stop_functor(idx); + } + } + /** + * @param handle Schedules the given coroutine to be executed upon the first available thread. + */ + void schedule_impl(std::coroutine_handle<> handle) noexcept + { + if (handle == nullptr) { + return; + } + + { + std::scoped_lock lk{m_wait_mutex}; + m_queue.emplace_back(handle); + m_wait_cv.notify_one(); + } + } + + /// The number of tasks in the queue + currently executing. + std::atomic m_size{0}; + /// Has the thread pool been requested to shut down? + std::atomic m_shutdown_requested{false}; + }; + +} diff --git a/include/glaze/coroutine/when_all.hpp b/include/glaze/coroutine/when_all.hpp new file mode 100644 index 0000000000..fd3e14405e --- /dev/null +++ b/include/glaze/coroutine/when_all.hpp @@ -0,0 +1,465 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "glaze/coroutine/awaitable.hpp" + +namespace glz +{ + namespace detail + { + struct when_all_latch + { + when_all_latch(size_t count) noexcept : m_count(count + 1) {} + + when_all_latch(const when_all_latch&) = delete; + when_all_latch(when_all_latch&& other) + : m_count(other.m_count.load(std::memory_order::acquire)), + m_awaiting_coroutine(std::exchange(other.m_awaiting_coroutine, nullptr)) + {} + + auto operator=(const when_all_latch&) -> when_all_latch& = delete; + auto operator=(when_all_latch&& other) -> when_all_latch& + { + if (std::addressof(other) != this) { + m_count.store(other.m_count.load(std::memory_order::acquire), std::memory_order::relaxed); + m_awaiting_coroutine = std::exchange(other.m_awaiting_coroutine, nullptr); + } + + return *this; + } + + bool is_ready() const noexcept + { + return m_awaiting_coroutine && m_awaiting_coroutine.done(); + } + + bool try_await(std::coroutine_handle<> awaiting_coroutine) noexcept + { + m_awaiting_coroutine = awaiting_coroutine; + return m_count.fetch_sub(1, std::memory_order::acq_rel) > 1; + } + + void notify_awaitable_completed() noexcept + { + if (m_count.fetch_sub(1, std::memory_order::acq_rel) == 1) { + m_awaiting_coroutine.resume(); + } + } + + private: + /// The number of tasks that are being waited on. + std::atomic m_count; + /// The when_all_task awaiting to be resumed upon all task completions. 
+ std::coroutine_handle<> m_awaiting_coroutine{}; + }; + + template + struct when_all_ready_awaitable; + + template + struct when_all_task; + + /// Empty tuple<> implementation. + template <> + struct when_all_ready_awaitable> + { + constexpr when_all_ready_awaitable() noexcept {} + explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {} + + constexpr auto await_ready() const noexcept -> bool { return true; } + auto await_suspend(std::coroutine_handle<>) noexcept -> void {} + auto await_resume() const noexcept -> std::tuple<> { return {}; } + }; + + template + struct when_all_ready_awaitable> + { + explicit when_all_ready_awaitable(Tasks&&... tasks) noexcept( + std::conjunction...>::value) + : m_latch(sizeof...(Tasks)), m_tasks(std::move(tasks)...) + {} + + explicit when_all_ready_awaitable(std::tuple&& tasks) noexcept( + std::is_nothrow_move_constructible_v>) + : m_latch(sizeof...(Tasks)), m_tasks(std::move(tasks)) + {} + + when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; + when_all_ready_awaitable(when_all_ready_awaitable&& other) + : m_latch(std::move(other.m_latch)), m_tasks(std::move(other.m_tasks)) + {} + + auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; + auto operator=(when_all_ready_awaitable&&) -> when_all_ready_awaitable& = delete; + + auto operator co_await() & noexcept + { + struct awaiter + { + explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {} + + bool await_ready() const noexcept { return m_awaitable.is_ready(); } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> std::tuple& { return m_awaitable.m_tasks; } + + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{*this}; + } + + auto operator co_await() && noexcept + { + struct awaiter + { + explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {} + + bool await_ready() const noexcept { return m_awaitable.is_ready(); } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> std::tuple&& { return std::move(m_awaitable.m_tasks); } + + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{*this}; + } + + private: + bool is_ready() const noexcept { return m_latch.is_ready(); } + + bool try_await(std::coroutine_handle<> awaiting_coroutine) noexcept + { + std::apply([this](auto&&... 
tasks) { ((tasks.start(m_latch)), ...); }, m_tasks); + return m_latch.try_await(awaiting_coroutine); + } + + when_all_latch m_latch; + std::tuple m_tasks; + }; + + template + struct when_all_ready_awaitable + { + explicit when_all_ready_awaitable(TaskContainer&& tasks) noexcept + : m_latch(std::size(tasks)), m_tasks(std::forward(tasks)) + {} + + when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; + when_all_ready_awaitable(when_all_ready_awaitable&& other) noexcept( + std::is_nothrow_move_constructible_v) + : m_latch(std::move(other.m_latch)), m_tasks(std::move(m_tasks)) + {} + + auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; + auto operator=(when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; + + auto operator co_await() & noexcept + { + struct awaiter + { + awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {} + + auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> TaskContainer& { return m_awaitable.m_tasks; } + + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{*this}; + } + + auto operator co_await() && noexcept + { + struct awaiter + { + awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {} + + auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> TaskContainer&& { return std::move(m_awaitable.m_tasks); } + + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{*this}; + } + + private: + auto is_ready() const noexcept -> bool { return m_latch.is_ready(); } + + auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + for (auto& task : m_tasks) { + task.start(m_latch); + } + + return m_latch.try_await(awaiting_coroutine); + } + + when_all_latch m_latch; + TaskContainer m_tasks; + }; + + template + struct when_all_task_promise + { + using coroutine_handle_type = std::coroutine_handle>; + + when_all_task_promise() noexcept {} + + auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); } + + auto initial_suspend() noexcept -> std::suspend_always { return {}; } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept -> bool { return false; } + auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void + { + coroutine.promise().m_latch->notify_awaitable_completed(); + } + auto await_resume() const noexcept {} + }; + + return completion_notifier{}; + } + + auto unhandled_exception() noexcept { m_exception_ptr = std::current_exception(); } + + auto yield_value(return_type&& value) noexcept + { + m_return_value = std::addressof(value); + return final_suspend(); + } + + auto start(when_all_latch& latch) noexcept -> void + { + m_latch = &latch; + coroutine_handle_type::from_promise(*this).resume(); + } + + auto result() & -> return_type& + { + if (m_exception_ptr) { + std::rethrow_exception(m_exception_ptr); + } + return *m_return_value; + } + + auto result() && -> return_type&& + { + if (m_exception_ptr) { + std::rethrow_exception(m_exception_ptr); + } + return std::forward(*m_return_value); + } + + auto return_void() 
noexcept -> void + { + // We should have either suspended at co_yield point or + // an exception was thrown before running off the end of + // the coroutine. + assert(false); + } + + private: + when_all_latch* m_latch{nullptr}; + std::exception_ptr m_exception_ptr; + std::add_pointer_t m_return_value; + }; + + template <> + struct when_all_task_promise + { + using coroutine_handle_type = std::coroutine_handle>; + + when_all_task_promise() noexcept {} + + auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); } + + auto initial_suspend() noexcept -> std::suspend_always { return {}; } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept -> bool { return false; } + auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void + { + coroutine.promise().m_latch->notify_awaitable_completed(); + } + auto await_resume() const noexcept -> void {} + }; + + return completion_notifier{}; + } + + auto unhandled_exception() noexcept -> void { m_exception_ptr = std::current_exception(); } + + auto return_void() noexcept -> void {} + + auto result() -> void + { + if (m_exception_ptr) { + std::rethrow_exception(m_exception_ptr); + } + } + + auto start(when_all_latch& latch) -> void + { + m_latch = &latch; + coroutine_handle_type::from_promise(*this).resume(); + } + + private: + when_all_latch* m_latch{nullptr}; + std::exception_ptr m_exception_ptr; + }; + + template + struct when_all_task + { + using promise_type = when_all_task_promise; + using coroutine_handle_type = typename promise_type::coroutine_handle_type; + + when_all_task(coroutine_handle_type coroutine) noexcept : m_coroutine(coroutine) {} + + when_all_task(const when_all_task&) = delete; + when_all_task(when_all_task&& other) noexcept + : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_type{})) + {} + + auto operator=(const when_all_task&) -> when_all_task& = delete; + auto operator=(when_all_task&&) -> when_all_task& = delete; + + ~when_all_task() + { + if (m_coroutine) { + m_coroutine.destroy(); + } + } + + decltype(auto) return_value() & + { + if constexpr (std::is_void_v) { + m_coroutine.promise().result(); + return std::nullptr_t{}; + } + else { + return m_coroutine.promise().result(); + } + } + + decltype(auto) return_value() const& + { + if constexpr (std::is_void_v) { + m_coroutine.promise().result(); + return std::nullptr_t{}; + } + else { + return m_coroutine.promise().result(); + } + } + + decltype(auto) return_value() && + { + if constexpr (std::is_void_v) { + m_coroutine.promise().result(); + return std::nullptr_t{}; + } + else { + return m_coroutine.promise().result(); + } + } + + void start(when_all_latch& latch) noexcept { m_coroutine.promise().start(latch); } + + private: + coroutine_handle_type m_coroutine; + }; + + template ::return_type> + static auto make_when_all_task(Awaitable a) -> when_all_task; + + template + static auto make_when_all_task(Awaitable a) -> when_all_task + { + if constexpr (std::is_void_v) { + co_await static_cast(a); + co_return; + } + else { + co_yield co_await static_cast(a); + } + } + + } // namespace detail + + template + [[nodiscard]] auto when_all(Awaitables... 
awaitables) + { + return detail::when_all_ready_awaitable::return_type>...>>( + std::make_tuple(detail::make_when_all_task(std::move(awaitables))...)); + } + + template , + class Return = typename awaitable_traits::return_type> + [[nodiscard]] auto when_all(Range awaitables) + -> detail::when_all_ready_awaitable>> + { + std::vector> output_tasks; + + // If the size is known in constant time reserve the output tasks size. + if constexpr (std::ranges::sized_range) { + output_tasks.reserve(std::size(awaitables)); + } + + // Wrap each task into a when_all_task. + for (auto&& a : awaitables) { + output_tasks.emplace_back(detail::make_when_all_task(std::move(a))); + } + + // Return the single awaitable that drives all the user's tasks. + return detail::when_all_ready_awaitable(std::move(output_tasks)); + } +} diff --git a/include/glaze/ext/glaze_asio.hpp b/include/glaze/ext/glaze_asio.hpp deleted file mode 100644 index 887e3ea32f..0000000000 --- a/include/glaze/ext/glaze_asio.hpp +++ /dev/null @@ -1,368 +0,0 @@ -// Glaze Library -// For the license information refer to glaze.hpp - -#pragma once - -#if __has_include() && !defined(GLZ_USE_BOOST_ASIO) -#include -#elif __has_include() -#ifndef GLZ_USING_BOOST_ASIO -#define GLZ_USING_BOOST_ASIO -#endif -#include -#else -static_assert(false, "standalone asio must be included to use glaze/ext/glaze_asio.hpp"); -#endif - -#include -#include -#include - -#include "glaze/rpc/repe.hpp" - -namespace glz -{ -#if defined(GLZ_USING_BOOST_ASIO) - namespace asio = boost::asio; -#endif - inline void send_buffer(asio::ip::tcp::socket& socket, const std::string_view str) - { - const uint64_t size = str.size(); - std::array buffers{asio::buffer(&size, sizeof(uint64_t)), asio::buffer(str)}; - - asio::write(socket, buffers, asio::transfer_exactly(sizeof(uint64_t) + size)); - } - - inline void receive_buffer(asio::ip::tcp::socket& socket, std::string& str) - { - uint64_t size; - asio::read(socket, asio::buffer(&size, sizeof(size)), asio::transfer_exactly(sizeof(uint64_t))); - str.resize(size); - asio::read(socket, asio::buffer(str), asio::transfer_exactly(size)); - } - - inline asio::awaitable co_send_buffer(asio::ip::tcp::socket& socket, const std::string_view str) - { - const uint64_t size = str.size(); - std::array buffers{asio::buffer(&size, sizeof(uint64_t)), asio::buffer(str)}; - - co_await asio::async_write(socket, buffers, asio::transfer_exactly(sizeof(uint64_t) + size), asio::use_awaitable); - } - - inline asio::awaitable co_receive_buffer(asio::ip::tcp::socket& socket, std::string& str) - { - uint64_t size; - co_await asio::async_read(socket, asio::buffer(&size, sizeof(size)), asio::transfer_exactly(sizeof(uint64_t)), - asio::use_awaitable); - str.resize(size); - co_await asio::async_read(socket, asio::buffer(str), asio::transfer_exactly(size), asio::use_awaitable); - } - - inline asio::awaitable call_rpc(asio::ip::tcp::socket& socket, std::string& buffer) - { - co_await co_send_buffer(socket, buffer); - co_await co_receive_buffer(socket, buffer); - } - - template - struct func_traits; - - template - struct func_traits - { - using result_type = Result; - using params_type = void; - using std_func_sig = std::function; - }; - - template - struct func_traits - { - using result_type = Result; - using params_type = Params; - using std_func_sig = std::function; - }; - - template - using func_result_t = typename func_traits::result_type; - - template - using func_params_t = typename func_traits::params_type; - - template - using std_func_sig_t = typename 
func_traits::std_func_sig; - - template - struct asio_client - { - std::string host{"localhost"}; // host name - std::string service{""}; // often the port - uint32_t concurrency{1}; // how many threads to use - - struct glaze - { - using T = asio_client; - static constexpr auto value = glz::object(&T::host, &T::service, &T::concurrency); - }; - - std::shared_ptr ctx{}; - std::shared_ptr socket{}; - - std::shared_ptr buffer_pool = std::make_shared(); - - [[nodiscard]] std::error_code init() - { - ctx = std::make_shared(concurrency); - socket = std::make_shared(*ctx); - asio::ip::tcp::resolver resolver{*ctx}; - const auto endpoint = resolver.resolve(host, service); -#if !defined(GLZ_USING_BOOST_ASIO) - std::error_code ec{}; -#else - boost::system::error_code ec{}; -#endif - asio::connect(*socket, endpoint, ec); - if (ec) { - return ec; - } - return socket->set_option(asio::ip::tcp::no_delay(true), ec); - } - - template - [[nodiscard]] repe::error_t notify(repe::header&& header, Params&& params) - { - repe::unique_buffer ubuffer{buffer_pool.get()}; - auto& buffer = ubuffer.value(); - - header.notify = true; - const auto ec = repe::request(std::move(header), std::forward(params), buffer); - if (bool(ec)) [[unlikely]] { - return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; - } - - send_buffer(*socket, buffer); - return {}; - } - - template - [[nodiscard]] repe::error_t get(repe::header&& header, Result&& result) - { - repe::unique_buffer ubuffer{buffer_pool.get()}; - auto& buffer = ubuffer.value(); - - header.notify = false; - header.empty = true; // no params - const auto ec = repe::request(std::move(header), nullptr, buffer); - if (bool(ec)) [[unlikely]] { - return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; - } - - send_buffer(*socket, buffer); - receive_buffer(*socket, buffer); - - return repe::decode_response(std::forward(result), buffer); - } - - template - [[nodiscard]] glz::expected get(repe::header&& header) - { - std::decay_t result{}; - const auto error = get(std::move(header), result); - if (error) { - return glz::unexpected(error); - } - else { - return {result}; - } - } - - template - [[nodiscard]] repe::error_t set(repe::header&& header, Params&& params) - { - repe::unique_buffer ubuffer{buffer_pool.get()}; - auto& buffer = ubuffer.value(); - - header.notify = false; - const auto ec = repe::request(std::move(header), std::forward(params), buffer); - if (bool(ec)) [[unlikely]] { - return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; - } - - send_buffer(*socket, buffer); - receive_buffer(*socket, buffer); - - return repe::decode_response(buffer); - } - - template - [[nodiscard]] repe::error_t call(repe::header&& header, Params&& params, Result&& result) - { - repe::unique_buffer ubuffer{buffer_pool.get()}; - auto& buffer = ubuffer.value(); - - header.notify = false; - const auto ec = repe::request(std::move(header), std::forward(params), buffer); - if (bool(ec)) [[unlikely]] { - return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; - } - - send_buffer(*socket, buffer); - receive_buffer(*socket, buffer); - - return repe::decode_response(std::forward(result), buffer); - } - - [[nodiscard]] repe::error_t call(repe::header&& header) - { - repe::unique_buffer ubuffer{buffer_pool.get()}; - auto& buffer = ubuffer.value(); - - header.notify = false; - header.empty = true; // because no value provided - const auto ec = glz::write_json(std::forward_as_tuple(std::move(header), nullptr), buffer); - if (bool(ec)) 
[[unlikely]] { - return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; - } - - send_buffer(*socket, buffer); - receive_buffer(*socket, buffer); - - return repe::decode_response(buffer); - } - - template - [[nodiscard]] std_func_sig_t callable(repe::header&& header) - { - using Params = func_params_t; - using Result = func_result_t; - if constexpr (std::same_as) { - header.empty = true; - return [this, h = std::move(header)]() mutable -> Result { - std::decay_t result{}; - const auto e = call(repe::header{h}, result); - if (e) { - throw std::runtime_error(glz::write_json(e)); - } - return result; - }; - } - else { - header.empty = false; - return [this, h = std::move(header)](Params params) mutable -> Result { - std::decay_t result{}; - const auto e = call(repe::header{h}, params, result); - if (e) { - throw std::runtime_error(e.message); - } - return result; - }; - } - } - - template - [[deprecated("We use a buffer pool now, so this would cause allocations")]] [[nodiscard]] std::string call_raw( - repe::header&& header, Params&& params, repe::error_t& error) - { - std::string buffer{}; - - header.notify = false; - const auto ec = repe::request(std::move(header), std::forward(params), buffer); - if (bool(ec)) [[unlikely]] { - error = {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; - return buffer; - } - - send_buffer(*socket, buffer); - receive_buffer(*socket, buffer); - return buffer; - } - }; - - template - struct asio_server - { - uint16_t port{}; - uint32_t concurrency{1}; // how many threads to use - - struct glaze - { - using T = asio_server; - static constexpr auto value = glz::object(&T::port, &T::concurrency); - }; - - std::shared_ptr ctx{}; - std::shared_ptr signals{}; - - repe::registry registry{}; - - void clear_registry() { registry.clear(); } - - template - requires(glz::detail::glaze_object_t || glz::detail::reflectable) - void on(T& value) - { - registry.template on(value); - } - - bool initialized = false; - - void init() - { - if (!initialized) { - ctx = std::make_shared(concurrency); - signals = std::make_shared(*ctx, SIGINT, SIGTERM); - } - initialized = true; - } - - void run() - { - if (!initialized) { - init(); - } - - signals->async_wait([&](auto, auto) { ctx->stop(); }); - - asio::co_spawn(*ctx, listener(), asio::detached); - - ctx->run(); - } - - // stop the server - void stop() - { - if (ctx) { - ctx->stop(); - } - } - - asio::awaitable run_instance(asio::ip::tcp::socket socket) - { - socket.set_option(asio::ip::tcp::no_delay(true)); - std::string buffer{}; - - try { - while (true) { - co_await co_receive_buffer(socket, buffer); - auto response = registry.call(buffer); - if (response) { - co_await co_send_buffer(socket, response->value()); - } - } - } - catch (const std::exception& e) { - std::fprintf(stderr, "%s\n", e.what()); - } - } - - asio::awaitable listener() - { - auto executor = co_await asio::this_coro::executor; - asio::ip::tcp::acceptor acceptor(executor, {asio::ip::tcp::v6(), port}); - while (true) { - auto socket = co_await acceptor.async_accept(asio::use_awaitable); - asio::co_spawn(executor, run_instance(std::move(socket)), asio::detached); - } - } - }; -} diff --git a/include/glaze/network.hpp b/include/glaze/network.hpp new file mode 100644 index 0000000000..b44a975b11 --- /dev/null +++ b/include/glaze/network.hpp @@ -0,0 +1,8 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/network/client.hpp" +#include "glaze/network/server.hpp" +#include 
"glaze/network/socket.hpp" diff --git a/include/glaze/network/client.hpp b/include/glaze/network/client.hpp new file mode 100644 index 0000000000..90ad2e39ca --- /dev/null +++ b/include/glaze/network/client.hpp @@ -0,0 +1,184 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/coroutine/scheduler.hpp" +#include "glaze/network/ip.hpp" +#include "glaze/network/socket.hpp" + +namespace glz +{ + /** + * By default the socket + * created will be in non-blocking mode, meaning that any sending or receiving of data should + * poll for event readiness prior. + */ + struct client + { + std::shared_ptr scheduler{}; + std::string address{"127.0.0.1"}; + uint16_t port{8080}; + ip_version ipv{}; + std::shared_ptr socket = make_async_socket(); + /// Cache the status of the connect in the event the user calls connect() again. + glz::ip_status connect_status{}; + + /** + * Connects to the address+port with the given timeout. Once connected calling this function + * only returns the connected status, it will not reconnect. + * @param timeout How long to wait for the connection to establish? Timeout of zero is indefinite. + * @return The result status of trying to connect. + */ + task connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + { + // Only allow the user to connect per tcp client once, if they need to re-connect they should + // make a new tcp::client. + if (connect_status != glz::ip_status::unset) { + co_return connect_status; + } + + // This enforces the connection status is aways set on the client object upon returning. + auto return_value = [this](glz::ip_status s) -> glz::ip_status { + connect_status = s; + return s; + }; + + if (socket->socket_fd == net::invalid_socket) { + co_return return_value(ip_status::invalid_socket); + } + + sockaddr_in server_addr{}; + server_addr.sin_family = AF_INET; // Use AF_INET for IPv4 + server_addr.sin_port = htons(port); + + if (::inet_pton(AF_INET, address.c_str(), &server_addr.sin_addr) <= 0) { + std::cerr << "Invalid address/ Address not supported: " << inet_ntoa(server_addr.sin_addr) << ":" << ntohs(server_addr.sin_port) << '\n'; + co_return return_value(ip_status::invalid_ip_address); + } + + std::cout << "Attempting client connection to: " << inet_ntoa(server_addr.sin_addr) << ":" << ntohs(server_addr.sin_port) << '\n'; + auto result = ::connect(socket->socket_fd, (sockaddr*)&server_addr, sizeof(server_addr)); + if (result == 0) { + std::cout << "Client connected to: " << inet_ntoa(server_addr.sin_addr) << ":" << ntohs(server_addr.sin_port) << '\n'; + co_return return_value(ip_status::connected); + } + else if (result == -1) { + // + // If the connect is happening in the background, poll for write on the socket to trigger + // when the connection is established. + // + // TODO: Handle cross-platform... 
+            //
+            std::cout << "Connection failed, polling for connection: " << inet_ntoa(server_addr.sin_addr) << ":" << ntohs(server_addr.sin_port) << "\nDetails: " << strerror(errno) << '\n';
+            if (errno == EAGAIN || errno == EINPROGRESS) {
+               auto pstatus = co_await scheduler->poll(socket->socket_fd, poll_op::write, timeout);
+               if (pstatus == poll_status::event) {
+                  result = 0;
+                  socklen_t result_length{sizeof(result)};
+
+#if defined(_WIN32)
+                  if (getsockopt(socket->socket_fd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&result),
+#else
+                  if (getsockopt(socket->socket_fd, SOL_SOCKET, SO_ERROR, &result,
+#endif
+
+                                 &result_length) < 0) {
+                     std::cerr << "connect failed to getsockopt after write poll event\n";
+                  }
+
+                  if (result == 0) {
+                     co_return return_value(ip_status::connected);
+                  }
+               }
+               else if (pstatus == poll_status::timeout) {
+                  co_return return_value(ip_status::timeout);
+               }
+            }
+         }
+
+         std::cerr << "connect: " << get_socket_error_message(errno) << '\n';
+         co_return return_value(ip_status::error);
+      }
+
+      /**
+       * Polls for the given operation on this client's tcp socket. This should be done prior to
+       * calling recv and after a send that doesn't send the entire buffer.
+       * @param op The poll operation to perform, use read for incoming data and write for outgoing.
+       * @param timeout The amount of time to wait for the poll event to be ready. Use zero for an infinite timeout.
+       * @return The status result of the poll operation. When poll_status::event is returned then the
+       * event operation is ready.
+       */
+      task poll(glz::poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
+      {
+         return scheduler->poll(socket->socket_fd, op, timeout);
+      }
+
+      /**
+       * Receives incoming data into the given buffer. By default all tcp client sockets are set
+       * to non-blocking, so use co_await poll() to determine when data is ready to be received.
+       * @param buffer Received bytes are written into this buffer up to the buffer's size.
+       * @return The status of the recv call and a span of the bytes received (if any). The span of
+       * bytes will be a subspan or full span of the given input buffer.
+       */
+      template
+      std::pair> recv(Buffer&& buffer)
+      {
+         // If the user requested zero bytes, just return.
+         if (buffer.empty()) {
+            return {ip_status::ok, std::span{}};
+         }
+
+         auto bytes_recv = ::recv(socket->socket_fd, buffer.data(), glz::net::ssize_t(buffer.size()), 0);
+         if (bytes_recv > 0) {
+            // Ok, we've received some data.
+            return {ip_status::ok, std::span{buffer.data(), size_t(bytes_recv)}};
+         }
+         else if (bytes_recv == 0) {
+            // On TCP stream sockets 0 indicates the connection has been closed by the peer.
+            return {ip_status::closed, std::span{}};
+         }
+         else {
+            // Report the error to the user.
+            return {errno_to_ip_status(), std::span{}};
+         }
+      }
+
+      /**
+       * Sends outgoing data from the given buffer. If a partial write occurs then use co_await poll()
+       * to determine when the tcp client socket is ready to be written to again. On partial writes
+       * the status will be 'ok' and the span returned will be non-empty; it will contain the buffer
+       * span data that was not written to the client's socket.
+       * @param buffer The data to write on the tcp socket.
+       * @return The status of the send call and a span of any remaining bytes not sent. If all bytes
+       * were successfully sent the status will be 'ok' and the remaining span will be empty.
+       */
+      template
+      std::pair send(const Buffer& buffer)
+      {
+         // If the user requested zero bytes, just return.
+ if (buffer.empty()) { + return {ip_status::ok, std::string_view{buffer.data(), buffer.size()}}; + } + + int error = 0; + socklen_t len = sizeof(error); + int result = getsockopt(socket->socket_fd, SOL_SOCKET, SO_ERROR, (char*)&error, &len); + if (result == int(ip_status::error)) { + auto err = get_socket_error_message(errno); + std::cerr << err; + return {ip_status::invalid_socket, err}; + } + + auto bytes_sent = ::send(socket->socket_fd, buffer.data(), glz::net::ssize_t(buffer.size()), 0); + if (bytes_sent >= 0) { + // Some or all of the bytes were written. + return {ip_status::ok, std::string_view{buffer.data() + bytes_sent, buffer.size() - bytes_sent}}; + } + else { + // Due to the error none of the bytes were written. + return {errno_to_ip_status(), std::string_view{buffer.data(), buffer.size()}}; + } + } + }; +} diff --git a/include/glaze/network/core.hpp b/include/glaze/network/core.hpp new file mode 100644 index 0000000000..62d4eb6cfa --- /dev/null +++ b/include/glaze/network/core.hpp @@ -0,0 +1,167 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include + +#ifndef GLZ_THROW_OR_ABORT +#if __cpp_exceptions +#define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) +#include +#else +#define GLZ_THROW_OR_ABORT(EXC) (std::abort()) +#endif +#endif + +#if defined(_WIN32) +#include +#include +#pragma comment(lib, "Ws2_32.lib") +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#if __has_include() +#include +#endif + +#if defined(__APPLE__) +#include // for kqueue on macOS +#elif defined(__linux__) +#include // for epoll on Linux +#include +#include +#endif + +namespace glz::net +{ +#ifdef _WIN32 + using event_handle_t = HANDLE; + using ssize_t = int32_t; + using asize_t = uint8_t; +#else + using event_handle_t = int; + using ssize_t = ::ssize_t; + using asize_t = int; +#endif + +#if defined(__APPLE__) + constexpr auto invalid_socket = -1; + using poll_event_t = struct kevent; + constexpr int invalid_event_handle = -1; + using ident_t = uintptr_t; + constexpr uintptr_t invalid_ident = ~uintptr_t(0); // set all bits +#elif defined(__linux__) + constexpr auto invalid_socket = -1; + using poll_event_t = struct epoll_event; + constexpr int invalid_event_handle = -1; + using ident_t = int; + constexpr int invalid_ident = -1; +#elif defined(_WIN32) + constexpr auto invalid_socket = INVALID_SOCKET; + using ident_t = HANDLE; + using poll_event_t = HANDLE; + inline const HANDLE invalid_event_handle = INVALID_HANDLE_VALUE; + inline const HANDLE invalid_ident = INVALID_HANDLE_VALUE; +#endif + +#if defined(__APPLE__) + constexpr auto poll_in = EVFILT_READ; + constexpr auto poll_out = EVFILT_WRITE; +#elif defined(__linux__) + constexpr auto poll_in = EPOLLIN; + constexpr auto poll_out = EPOLLOUT; +#elif defined(_WIN32) + constexpr auto poll_in = 0; + constexpr auto poll_out = 1; +#endif + + + inline auto close_socket(auto& fd) { + if (fd != invalid_socket) { +#ifdef _WIN32 + ::closesocket(fd); +#else + ::close(fd); +#endif + } + fd = invalid_socket; + } + + inline auto close_event(auto fd) { +#ifdef _WIN32 + ::CloseHandle(fd); +#else + ::close(fd); +#endif + } + + inline event_handle_t create_event_poll() { +#if defined(__APPLE__) + return ::kqueue(); +#elif defined(__linux__) + return ::epoll_create1(EPOLL_CLOEXEC); +#elif defined(_WIN32) + return INVALID_HANDLE_VALUE; +#endif + } + + inline event_handle_t create_shutdown_handle() { +#if defined(__APPLE__) + return invalid_event_handle; +#elif defined(__linux__) + 
return ::eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); +#elif defined(_WIN32) + return CreateEventA(nullptr, TRUE, FALSE, "create_shutdown_handle"); +#endif + } + + inline event_handle_t create_timer_handle() { +#if defined(__APPLE__) + return invalid_event_handle; +#elif defined(__linux__) + return ::timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC); +#elif defined(_WIN32) + return CreateWaitableTimerA(nullptr, TRUE, "create_timer_handle"); +#endif + } + + inline event_handle_t create_schedule_handle() { +#if defined(__APPLE__) + return invalid_event_handle; +#elif defined(__linux__) + return ::eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); +#elif defined(_WIN32) + return CreateEventA(nullptr, TRUE, FALSE, "create_schedule_handle"); +#endif + } + + inline bool poll_error([[maybe_unused]] uint32_t events) { +#if defined(__APPLE__) + return events & EV_ERROR; +#elif defined(__linux__) + return events & EPOLLERR; +#elif defined(_WIN32) + return true; +#endif + } + + inline bool event_closed([[maybe_unused]] uint32_t events) { +#if defined(__APPLE__) + return events & EV_EOF; +#elif defined(__linux__) + return events & EPOLLRDHUP || events & EPOLLHUP; +#elif defined(_WIN32) + return true; +#endif + } +} diff --git a/include/glaze/network/ip.hpp b/include/glaze/network/ip.hpp new file mode 100644 index 0000000000..ad4c9479a9 --- /dev/null +++ b/include/glaze/network/ip.hpp @@ -0,0 +1,174 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Modified from the awesome: https://github.com/jbaldwin/libcoro + +#pragma once + +#if defined(__linux__) || defined(__APPLE__) +#include +#elif defined(_WIN32) +#include +#include +#endif + +#include +#include +#include + +#include "glaze/reflection/enum_macro.hpp" + +namespace glz { + enum struct ip_version : int { ipv4 = AF_INET, ipv6 = AF_INET6 }; + + GLZ_ENUM(ip_status, + unset, + ok, + closed, + connected, + connection_refused, + invalid_ip_address, + timeout, + error, + try_again, + would_block, + bad_file_descriptor, + invalid_socket + ); + + inline ip_status errno_to_ip_status() noexcept { +#if defined(__linux__) || defined(__APPLE__) + const auto err = errno; + using enum ip_status; + switch (err) { + case 0: + return ok; + case -1: + return closed; + case EWOULDBLOCK: + return would_block; + case ECONNREFUSED: + return connection_refused; + default: + return error; + } +#endif + } + + // Get's human readable ip if binary; returns ip if full socket address (ip::port); handles IPV4 and + // IPV6 formats. + // + inline std::optional to_ip_string(std::string_view input, ip_version ipv = ip_version::ipv6) { + if (input.empty()) return std::nullopt; + + // Remove trailing whitespace...here because input contained a '\n'... + while (!input.empty() && std::isspace(static_cast(input.back()))) { + input.remove_suffix(1); + } + + size_t buffer_size = (ipv == ip_version::ipv4) ? 
INET_ADDRSTRLEN : INET6_ADDRSTRLEN; + std::string output(buffer_size, '\0'); + + // Handle human-readable IP addresses (with or without port) + std::string_view ip_part = input; + if (input.find_first_not_of("0123456789.:abcdefABCDEF[]") == std::string_view::npos) { + // Remove port if present + size_t colon_pos = input.rfind(':'); + size_t bracket_pos = input.rfind(']'); + if (colon_pos != std::string_view::npos && + (bracket_pos == std::string_view::npos || colon_pos > bracket_pos)) { + ip_part = input.substr(0, colon_pos); + } + + // Remove brackets for IPv6 addresses + if (ip_part.front() == '[' && ip_part.back() == ']') { + ip_part = ip_part.substr(1, ip_part.size() - 2); + } + + // Try to convert the string representation + int family = (ip_part.find(':') != std::string_view::npos) ? AF_INET6 : AF_INET; + unsigned char buf[sizeof(struct in6_addr)]; + if (inet_pton(family, std::string(ip_part).c_str(), buf) == 1) { + if (inet_ntop(family, buf, output.data(), net::asize_t(output.size()))) { + output.resize(std::strlen(output.c_str())); + return output; + } + } + } + + // Handle binary IP addresses here + const void* addr_ptr; + int family; + if (input.size() == sizeof(struct in_addr)) { + addr_ptr = input.data(); + family = AF_INET; + } else if (input.size() == sizeof(struct in6_addr)) { + addr_ptr = input.data(); + family = AF_INET6; + } else { + return std::nullopt; // Invalid input size for binary IP + } + + if (inet_ntop(family, addr_ptr, output.data(), net::asize_t(output.size()))) { + output.resize(std::strlen(output.c_str())); + return output; + } else { + // TODO: Implement Error handling for cross-platform portability. + int error_code = errno; + switch (error_code) { + case EAFNOSUPPORT: + std::cerr << "Address family not supported for socket address: " << input << '\n'; + break; + case ENOSPC: + std::cerr << "Insufficient space in output buffer for socket address: " << input << '\n'; + break; + default: + std::cerr << "Unexpected 'inet_ntop' error converting socket address: " << input << '\n'; + } + } + return std::nullopt; + } + + /* + // ip_status + ok = 0, + /// The peer closed the socket. + closed = -1, + /// The udp socket has not been bind()'ed to a local port. + udp_not_bound = -2, + try_again = EAGAIN, + // Note: that only the tcp::client will return this, a tls::client returns the specific ssl_would_block_* status. 
+ would_block = EWOULDBLOCK, + bad_file_descriptor = EBADF, + connection_refused = ECONNREFUSED, + memory_fault = EFAULT, + interrupted = EINTR, + invalid_argument = EINVAL, + no_memory = ENOMEM, + not_connected = ENOTCONN, + not_a_socket = ENOTSOCK, + */ + + /* + // ip_status + ok = 0, + closed = -1, + permission_denied = EACCES, + try_again = EAGAIN, + would_block = EWOULDBLOCK, + already_in_progress = EALREADY, + bad_file_descriptor = EBADF, + connection_reset = ECONNRESET, + no_peer_address = EDESTADDRREQ, + memory_fault = EFAULT, + interrupted = EINTR, + is_connection = EISCONN, + message_size = EMSGSIZE, + output_queue_full = ENOBUFS, + no_memory = ENOMEM, + not_connected = ENOTCONN, + not_a_socket = ENOTSOCK, + operation_not_supported = EOPNOTSUPP, + pipe_closed = EPIPE, + */ +} diff --git a/include/glaze/network/repe_client.hpp b/include/glaze/network/repe_client.hpp new file mode 100644 index 0000000000..9c1cf42875 --- /dev/null +++ b/include/glaze/network/repe_client.hpp @@ -0,0 +1,160 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/network/socket.hpp" +#include "glaze/network/socket_io.hpp" +#include "glaze/rpc/repe.hpp" + +namespace glz +{ + template + struct repe_client + { + std::string hostname{"127.0.0.1"}; + uint16_t port{}; + glz::socket socket{}; + + struct glaze + { + using T = repe_client; + static constexpr auto value = glz::object(&T::hostname, &T::port); + }; + + std::shared_ptr buffer_pool = std::make_shared(); + + [[nodiscard]] std::error_code init() + { + if (auto ec = socket.connect(hostname, port)) { + return ec; + } + + if (not socket.no_delay()) { + return {ip_error::socket_bind_failed, ip_error_category::instance()}; + } + + return {}; + } + + template + [[nodiscard]] repe::error_t notify(repe::header&& header, Params&& params) + { + repe::unique_buffer ubuffer{buffer_pool.get()}; + auto& buffer = ubuffer.value(); + + header.notify = true; + const auto ec = repe::request(std::move(header), std::forward(params), buffer); + if (bool(ec)) [[unlikely]] { + return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; + } + + if (auto ec = send(socket, buffer)) { + return {ec.value(), ec.message()}; + } + return {}; + } + + template + [[nodiscard]] repe::error_t get(repe::header&& header, Result&& result) + { + repe::unique_buffer ubuffer{buffer_pool.get()}; + auto& buffer = ubuffer.value(); + + header.notify = false; + header.empty = true; // no params + const auto ec = repe::request(std::move(header), nullptr, buffer); + if (bool(ec)) [[unlikely]] { + return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; + } + + if (auto ec = send(socket, buffer)) { + return {ec.value(), ec.message()}; + } + if (auto ec = receive(socket, buffer)) { + return {ec.value(), ec.message()}; + } + + return repe::decode_response(std::forward(result), buffer); + } + + template + [[nodiscard]] glz::expected get(repe::header&& header) + { + std::decay_t result{}; + const auto error = get(std::move(header), result); + if (error) { + return glz::unexpected(error); + } + else { + return {result}; + } + } + + template + [[nodiscard]] repe::error_t set(repe::header&& header, Params&& params) + { + repe::unique_buffer ubuffer{buffer_pool.get()}; + auto& buffer = ubuffer.value(); + + header.notify = false; + const auto ec = repe::request(std::move(header), std::forward(params), buffer); + if (bool(ec)) [[unlikely]] { + return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; + } + + if (auto ec = 
send(socket, buffer)) { + return {ec.value(), ec.message()}; + } + if (auto ec = receive(socket, buffer)) { + return {ec.value(), ec.message()}; + } + + return repe::decode_response(buffer); + } + + template + [[nodiscard]] repe::error_t call(repe::header&& header, Params&& params, Result&& result) + { + repe::unique_buffer ubuffer{buffer_pool.get()}; + auto& buffer = ubuffer.value(); + + header.notify = false; + const auto ec = repe::request(std::move(header), std::forward(params), buffer); + if (bool(ec)) [[unlikely]] { + return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; + } + + if (auto ec = send(socket, buffer)) { + return {ec.value(), ec.message()}; + } + if (auto ec = receive(socket, buffer)) { + return {ec.value(), ec.message()}; + } + + return repe::decode_response(std::forward(result), buffer); + } + + [[nodiscard]] repe::error_t call(repe::header&& header) + { + repe::unique_buffer ubuffer{buffer_pool.get()}; + auto& buffer = ubuffer.value(); + + header.notify = false; + header.empty = true; // because no value provided + const auto ec = glz::write_json(std::forward_as_tuple(std::move(header), nullptr), buffer); + if (bool(ec)) [[unlikely]] { + return {repe::error_e::invalid_params, glz::format_error(ec, buffer)}; + } + + if (auto ec = send(socket, buffer)) { + return {ec.value(), ec.message()}; + } + if (auto ec = receive(socket, buffer)) { + return {ec.value(), ec.message()}; + } + + return repe::decode_response(buffer); + } + }; +} diff --git a/include/glaze/network/repe_server.hpp b/include/glaze/network/repe_server.hpp new file mode 100644 index 0000000000..49b7c95554 --- /dev/null +++ b/include/glaze/network/repe_server.hpp @@ -0,0 +1,86 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/network/server.hpp" +#include "glaze/network/socket_io.hpp" +#include "glaze/rpc/repe.hpp" + +namespace glz +{ + template + struct repe_server + { + uint16_t port{}; + bool print_errors = false; + glz::server server{}; + + struct glaze + { + using T = repe_server; + static constexpr auto value = glz::object(&T::port); + }; + + repe::registry registry{}; + + void clear_registry() { registry.clear(); } + + template + requires(glz::detail::glaze_object_t || glz::detail::reflectable) + void on(T& value) + { + registry.template on(value); + } + + void run() + { + server.port = port; + + auto ec = server.accept([this](socket&& socket, auto& active) { + if (not socket.no_delay()) { + std::printf("%s", "no_delay failed"); + return; + } + + std::string buffer{}; + + try { + while (active) { + if (auto ec = receive(socket, buffer)) { + if (print_errors) { + std::fprintf(stderr, "%s\n", ec.message().c_str()); + } + if (ec.value() == ip_error::client_disconnected) { + return; + } + } + else { + auto response = registry.call(buffer); + if (response) { + if (auto ec = send(socket, response->value())) { + if (print_errors) { + std::fprintf(stderr, "%s\n", ec.message().c_str()); + } + } + } + } + } + } + catch (const std::exception& e) { + std::fprintf(stderr, "%s\n", e.what()); + } + }); + + if (ec) { + std::fprintf(stderr, "%s\n", ec.message().c_str()); + } + } + + // stop the server + void stop() + { + server.active = false; + } + }; +} diff --git a/include/glaze/network/server.hpp b/include/glaze/network/server.hpp new file mode 100644 index 0000000000..520fd314c1 --- /dev/null +++ b/include/glaze/network/server.hpp @@ -0,0 +1,80 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + 
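+// A coroutine-friendly TCP acceptor. Rough usage sketch (illustrative only; assumes a running
+// scheduler named `sched` and the default members below, used inside a coroutine):
+//
+//    glz::server svr{.scheduler = sched, .port = 8080};
+//    while (true) {
+//       if (co_await svr.poll() == glz::poll_status::event) {
+//          auto conn = svr.accept(); // returns a glz::client for the new connection
+//       }
+//    }
+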
+#include "glaze/coroutine/task.hpp" +#include "glaze/network/client.hpp" +#include "glaze/network/ip.hpp" +#include "glaze/network/socket.hpp" + +namespace glz +{ + struct server + { + std::shared_ptr scheduler{}; + std::string address{"127.0.0.1"}; + uint16_t port{8080}; + int32_t backlog{128}; // The kernel backlog of connections to buffer. + /// The socket for accepting new tcp connections on. + std::shared_ptr accept_socket = make_accept_socket(address, port); + + /** + * Polls for new incoming tcp connections. + * @param timeout How long to wait for a new connection before timing out, zero waits indefinitely. + * @return The result of the poll, 'event' means the poll was successful and there is at least 1 + * connection ready to be accepted. + */ + task poll(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + { + return scheduler->poll(accept_socket->socket_fd, poll_op::read, timeout); + } + + /** + * Accepts an incoming tcp client connection. On failure the tls clients socket will be set to + * and invalid state, use the socket.is_value() to verify the client was correctly accepted. + * @return The newly connected tcp client connection. + */ + client accept() + { + sockaddr_in client_addr{}; + client_addr.sin_family = AF_INET; // Use AF_INET for IPv4 + client_addr.sin_port = htons(port); + + if (::inet_pton(AF_INET, address.c_str(), &client_addr.sin_addr) <= 0) { + std::cerr << "Invalid address/ Address not supported: " << inet_ntoa(client_addr.sin_addr) << ":" + << ntohs(client_addr.sin_port) << '\n'; + return {}; + } + + std::ostringstream sockaddr; + sockaddr << inet_ntoa(client_addr.sin_addr) << ":" << ntohs(client_addr.sin_port) << '\n'; + std::cout << "Accepting incoming client connection to: " << sockaddr.str(); + + constexpr int len = sizeof(struct sockaddr_in); + + auto new_client_id = ::accept(accept_socket->socket_fd, (struct sockaddr*)(&client_addr), + const_cast((const socklen_t*)(&len))); + + if (new_client_id < 0) { + std::cerr << "Unable to accept client on socket address " << sockaddr.str() << '\n'; + // + // TODO: Handle Error + // + return {}; + } + socket sock{new_client_id}; + + std::cout << "New Client Id, " << new_client_id << ", " << "Accepted On " << sockaddr.str(); + + std::string_view ip_addr_view{sockaddr.str()}; + + // clang-format off + return { .scheduler = scheduler, + .address = to_ip_string(ip_addr_view).value(), + .port = ntohs(client_addr.sin_port), + .ipv = ip_version(client_addr.sin_family)}; + // clang-format on + } + }; +} diff --git a/include/glaze/network/server_old.hpp b/include/glaze/network/server_old.hpp new file mode 100644 index 0000000000..85afae85e6 --- /dev/null +++ b/include/glaze/network/server_old.hpp @@ -0,0 +1,215 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/network/socket.hpp" + +#ifdef _WIN32 +#define GLZ_CLOSE_SOCKET closesocket +#define GLZ_EVENT_CLOSE WSACloseEvent +#define GLZ_EWOULDBLOCK WSAEWOULDBLOCK +#define GLZ_INVALID_EVENT WSA_INVALID_EVENT +#define GLZ_INVALID_SOCKET INVALID_SOCKET +#define GLZ_SOCKET SOCKET +#define GLZ_SOCKET_ERROR SOCKET_ERROR +#define GLZ_SOCKET_ERROR_CODE WSAGetLastError() +#define GLZ_WAIT_FAILED WSA_WAIT_FAILED +#define GLZ_WAIT_RESULT_TYPE DWORD +#else +#include +#include +#if __has_include() +#include +#endif +#include +#include +#include +#include + +#include +#define GLZ_CLOSE_SOCKET ::close +#define GLZ_EVENT_CLOSE ::close +#define GLZ_EWOULDBLOCK EWOULDBLOCK +#define GLZ_INVALID_EVENT (-1) +#define 
GLZ_INVALID_SOCKET (-1) +#define GLZ_SOCKET int +#define GLZ_SOCKET_ERROR (-1) +#define GLZ_SOCKET_ERROR_CODE errno +#define GLZ_WAIT_FAILED (-1) +#define GLZ_WAIT_RESULT_TYPE int +#endif + +#if defined(__APPLE__) +#include // for kqueue on macOS +#elif defined(__linux__) +#include // for epoll on Linux +#endif + +namespace glz +{ + namespace detail + { + inline void server_thread_cleanup(auto& threads) + { + threads.erase(std::partition(threads.begin(), threads.end(), + [](auto& future) { + if (auto status = future.wait_for(std::chrono::milliseconds(0)); + status == std::future_status::ready) { + return false; + } + return true; + }), + threads.end()); + } + } + + struct server final + { + int port{}; + std::atomic active = true; + std::shared_future async_accept_thread{}; + std::vector> threads{}; + + ~server() { active = false; } + + template + std::shared_future async_accept(AcceptCallback&& callback) + { + async_accept_thread = { + std::async([this, callback = std::forward(callback)] { return accept(callback); })}; + return async_accept_thread; + } + + template + [[nodiscard]] std::error_code accept(AcceptCallback&& callback) + { + glz::socket accept_socket{}; + + const auto ec = bind_and_listen(accept_socket, port); + if (ec) { + return {int(ip_error::socket_bind_failed), ip_error_category::instance()}; + } + +#if defined(__APPLE__) + int event_fd = ::kqueue(); +#elif defined(__linux__) + int event_fd = ::epoll_create1(0); +#elif defined(_WIN32) + HANDLE event_fd = WSACreateEvent(); +#endif + + if (event_fd == GLZ_INVALID_EVENT) { + return {int(ip_error::queue_create_failed), ip_error_category::instance()}; + } + + bool event_setup_failed = false; +#if defined(__APPLE__) + struct kevent change; + EV_SET(&change, accept_socket.socket_fd, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, nullptr); + event_setup_failed = ::kevent(event_fd, &change, 1, nullptr, 0, nullptr) == -1; +#elif defined(__linux__) + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = accept_socket.socket_fd; + event_setup_failed = epoll_ctl(event_fd, EPOLL_CTL_ADD, accept_socket.socket_fd, &ev) == -1; +#elif defined(_WIN32) + event_setup_failed = WSAEventSelect(accept_socket.socket_fd, event_fd, FD_ACCEPT) == GLZ_SOCKET_ERROR; +#endif + + if (event_setup_failed) { + GLZ_EVENT_CLOSE(event_fd); + return {int(ip_error::event_ctl_failed), ip_error_category::instance()}; + } + +#if defined(__APPLE__) + std::vector events(16); +#elif defined(__linux__) + std::vector epoll_events(16); +#endif + + while (active) { + GLZ_WAIT_RESULT_TYPE n{}; + +#if defined(__APPLE__) + struct timespec timeout + { + 0, 10000000 + }; // 10ms + n = ::kevent(event_fd, nullptr, 0, events.data(), static_cast(events.size()), &timeout); +#elif defined(__linux__) + n = ::epoll_wait(event_fd, epoll_events.data(), static_cast(epoll_events.size()), 10); +#elif defined(_WIN32) + n = WSAWaitForMultipleEvents(1, &event_fd, FALSE, 10, FALSE); +#endif + + if (n == GLZ_WAIT_FAILED) { +#if defined(__APPLE__) || defined(__linux__) + if (errno == EINTR) continue; +#else + if (n == WSA_WAIT_TIMEOUT) continue; +#endif + GLZ_EVENT_CLOSE(event_fd); + return {int(ip_error::event_wait_failed), ip_error_category::instance()}; + } + + auto spawn_socket = [&] { + sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + auto client_fd = ::accept(accept_socket.socket_fd, (sockaddr*)&client_addr, &client_len); + if (client_fd != GLZ_INVALID_SOCKET) { + threads.emplace_back( + std::async([this, callback, client_fd] { callback(socket{client_fd}, active); 
})); + } + }; + +#if defined(__APPLE__) || defined(__linux__) + for (int i = 0; i < n; ++i) { +#if defined(__APPLE__) + if (events[i].ident == uintptr_t(accept_socket.socket_fd) && events[i].filter == EVFILT_READ) { +#elif defined(__linux__) + if (epoll_events[i].data.fd == accept_socket.socket_fd && epoll_events[i].events & EPOLLIN) { +#endif + spawn_socket(); + } + } + +#else // Windows + WSANETWORKEVENTS events; + if (WSAEnumNetworkEvents(accept_socket.socket_fd, event_fd, &events) == GLZ_SOCKET_ERROR) { + + WSACloseEvent(event_fd); + + // requires explicit 'std::error_code'...otherwise the following error with msvc... + // + // error C2440: 'return': cannot convert from 'initializer list' to 'std::error_code' + // + return {int(ip_error::event_enum_failed), ip_error_category::instance()}; + } + + if (events.lNetworkEvents & FD_ACCEPT) { + if (events.iErrorCode[FD_ACCEPT_BIT] == 0) { + spawn_socket(); + } + } +#endif + + detail::server_thread_cleanup(threads); + } + + GLZ_EVENT_CLOSE(event_fd); + return {}; + } + }; +} + +#undef GLZ_CLOSE_SOCKET +#undef GLZ_EVENT_CLOSE +#undef GLZ_EWOULDBLOCK +#undef GLZ_INVALID_EVENT +#undef GLZ_INVALID_SOCKET +#undef GLZ_SOCKET +#undef GLZ_SOCKET_ERROR +#undef GLZ_SOCKET_ERROR_CODE +#undef GLZ_WAIT_FAILED +#undef GLZ_WAIT_RESULT_TYPE diff --git a/include/glaze/network/socket.hpp b/include/glaze/network/socket.hpp new file mode 100644 index 0000000000..619176a791 --- /dev/null +++ b/include/glaze/network/socket.hpp @@ -0,0 +1,286 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include + +#include "glaze/network/core.hpp" +#include "glaze/network/ip.hpp" +#include "glaze/network/socket_core.hpp" + +namespace glz +{ +#ifdef _WIN32 + constexpr auto e_would_block = WSAEWOULDBLOCK; + using socket_t = SOCKET; + constexpr auto socket_error = SOCKET_ERROR; +#else + constexpr auto e_would_block = EWOULDBLOCK; + using socket_t = int; + constexpr auto socket_error = -1; +#endif +} + +#include +#include +#include +#include +#include +#include + +namespace glz +{ + struct socket final + { + socket_t socket_fd{net::invalid_socket}; + + void close() + { + if (socket_fd != net::invalid_socket) { + net::close_socket(socket_fd); + } + } + + bool valid() const { return socket_fd != net::invalid_socket; } + + ~socket() { close(); } + + [[nodiscard]] bool no_delay() + { + int flag = 1; + int result = setsockopt(socket_fd, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(int)); + return result == 0; + } + }; + + inline void set_non_blocking(socket& sock) noexcept + { +#ifdef _WIN32 + u_long mode = 1; + ioctlsocket(sock.socket_fd, FIONBIO, &mode); +#else + int flags = fcntl(sock.socket_fd, F_GETFL, 0); + fcntl(sock.socket_fd, F_SETFL, flags | O_NONBLOCK); +#endif + } + + [[nodiscard]] inline std::error_code connect(socket& sock, const std::string& address, const int port) + { + // TODO: Support ipv6 + sock.socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); + set_non_blocking(sock); + if (sock.socket_fd == net::invalid_socket) { + return {int(ip_error::socket_connect_failed), ip_error_category::instance()}; + } + + sockaddr_in server_addr; + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(uint16_t(port)); + ::inet_pton(AF_INET, address.c_str(), &server_addr.sin_addr); + + if (::connect(sock.socket_fd, (sockaddr*)&server_addr, sizeof(server_addr)) == -1) { + return {int(ip_error::socket_connect_failed), ip_error_category::instance()}; + } + + 
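+      // Note: the socket was already switched to non-blocking above, so a connect that is merely
+      // still in progress (errno == EINPROGRESS, or WSAEWOULDBLOCK on Windows) also lands in the
+      // error branch above rather than being completed later.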
set_non_blocking(sock); + + return {}; + } + + [[nodiscard]] inline std::error_code bind_and_listen(socket& sock, const std::string& address, int port) + { + set_non_blocking(sock); + if (sock.socket_fd == net::invalid_socket) { + return {int(ip_error::socket_bind_failed), ip_error_category::instance()}; + } + +#ifdef _WIN32 + char sock_opt{1}; +#else + int sock_opt{1}; +#endif + if (setsockopt(sock.socket_fd, SOL_SOCKET, SO_REUSEADDR, &sock_opt, sizeof(sock_opt)) < 0) { + std::cerr << "setsockopt SO_REUSEADDR: " << get_socket_error_message(errno) << '\n'; + return {int(ip_error::socket_bind_failed), ip_error_category::instance()}; + } + +#ifdef SO_REUSEPORT + if (setsockopt(sock.socket_fd, SOL_SOCKET, SO_REUSEPORT, &sock_opt, sizeof(sock_opt)) < 0) { + std::cerr << "setsockopt SO_REUSEPORT: " << get_socket_error_message(errno) << '\n'; + // You might want to handle this error differently, as it's not critical + } +#endif + + sockaddr_in server_addr; + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(uint16_t(port)); + ::inet_pton(glz::net::asize_t(AF_INET), address.c_str(), &server_addr.sin_addr); + + if (::bind(sock.socket_fd, (sockaddr*)&server_addr, sizeof(server_addr)) == -1) { + return {int(ip_error::socket_bind_failed), ip_error_category::instance()}; + } + + if (::listen(sock.socket_fd, SOMAXCONN) == -1) { + return {int(ip_error::socket_bind_failed), ip_error_category::instance()}; + } + + if (not sock.no_delay()) { + return {int(ip_error::socket_bind_failed), ip_error_category::instance()}; + } + + return {}; + } + + [[nodiscard]] inline std::shared_ptr make_async_socket() + { + auto sock = std::make_shared(); + sock->socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); + set_non_blocking(*sock); + return sock; + } + + [[nodiscard]] inline std::shared_ptr make_accept_socket(const std::string& address, int port) + { + auto sock = std::make_shared(); + sock->socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); + + if (auto ec = bind_and_listen(*sock, address, port)) { + return {}; + } + + return sock; + } + + GLZ_ENUM(socket_event, bytes_read, wait, client_disconnected, receive_failed); + + struct socket_state + { + size_t bytes_read{}; + socket_event event{}; + }; + + [[nodiscard]] inline socket_state async_recv(socket& sckt, char* buffer, size_t size) + { + auto bytes = ::recv(sckt.socket_fd, buffer, net::ssize_t(size), 0); + if (bytes == -1) { + if (GLZ_SOCKET_ERROR_CODE == e_would_block || GLZ_SOCKET_ERROR_CODE == EAGAIN) { + return {0, socket_event::wait}; + } + else { + return {0, socket_event::receive_failed}; + } + } + else if (bytes == 0) { + return {0, socket_event::client_disconnected}; + } + return {size_t(bytes), socket_event::bytes_read}; + } + + template + [[nodiscard]] std::error_code blocking_header_receive(socket& sckt, Header& header, std::string& buffer, + size_t timeout_ms) + { + // first receive the header + auto t0 = std::chrono::steady_clock::now(); + size_t total_bytes{}; + while (total_bytes < sizeof(Header)) { + auto t1 = std::chrono::steady_clock::now(); + if (size_t(std::chrono::duration_cast(t1 - t0).count()) >= timeout_ms) { + std::cout << std::chrono::duration_cast(t1 - t0).count() << '\n'; + return {int(ip_error::receive_timeout), ip_error_category::instance()}; + } + auto [bytes, event] = + async_recv(sckt, reinterpret_cast(&header) + total_bytes, sizeof(Header) - total_bytes); + using enum socket_event; + switch (event) { + case bytes_read: { + total_bytes += bytes; + break; + } + case wait: { + 
std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + case client_disconnected: { + return {int(ip_error::client_disconnected), ip_error_category::instance()}; + } + case receive_failed: { + [[fallthrough]]; + } + default: { + buffer.clear(); + return {int(ip_error::receive_failed), ip_error_category::instance()}; + } + } + } + + size_t size{}; + if constexpr (std::same_as) { + size = header; + } + else { + size = size_t(header.body_size); + } + + buffer.resize(size); + + t0 = std::chrono::steady_clock::now(); + total_bytes = 0; + while (total_bytes < size) { + auto t1 = std::chrono::steady_clock::now(); + if (size_t(std::chrono::duration_cast(t1 - t0).count()) >= timeout_ms) { + std::cout << std::chrono::duration_cast(t1 - t0).count() << '\n'; + return {int(ip_error::receive_timeout), ip_error_category::instance()}; + } + auto [bytes, event] = async_recv(sckt, buffer.data() + total_bytes, buffer.size() - total_bytes); + using enum socket_event; + switch (event) { + case bytes_read: { + total_bytes += bytes; + break; + } + case wait: { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + case client_disconnected: { + return {int(ip_error::client_disconnected), ip_error_category::instance()}; + } + case receive_failed: { + [[fallthrough]]; + } + default: { + buffer.clear(); + return {int(ip_error::receive_failed), ip_error_category::instance()}; + } + } + } + return {}; + } + + [[nodiscard]] inline std::error_code blocking_send(socket& sckt, const std::string_view buffer) + { + const size_t size = buffer.size(); + size_t total_bytes{}; + while (total_bytes < size) { + auto bytes = + ::send(sckt.socket_fd, buffer.data() + total_bytes, glz::net::ssize_t(buffer.size() - total_bytes), 0); + if (bytes == -1) { + if (GLZ_SOCKET_ERROR_CODE == e_would_block || GLZ_SOCKET_ERROR_CODE == EAGAIN) { + std::this_thread::yield(); + continue; + } + else { + return {int(ip_error::send_failed), ip_error_category::instance()}; + } + } + + total_bytes += bytes; + } + return {}; + } +} diff --git a/include/glaze/network/socket_core.hpp b/include/glaze/network/socket_core.hpp new file mode 100644 index 0000000000..403c036ea0 --- /dev/null +++ b/include/glaze/network/socket_core.hpp @@ -0,0 +1,228 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/network/core.hpp" +#include "glaze/network/ip.hpp" + +#ifdef _WIN32 +#pragma comment(lib, "Ws2_32.lib") +#define GLZ_SOCKET_ERROR_CODE WSAGetLastError() +#else +#define GLZ_SOCKET_ERROR_CODE errno +#include +#include +#if __has_include() +#include +#endif +#include +#include +#include +#include +#endif + +#include +#include + +namespace glz +{ + namespace detail + { + inline std::string format_ip_port(const sockaddr_in& server_addr) + { + char ip_str[INET_ADDRSTRLEN]{}; + + #ifdef _WIN32 + inet_ntop(AF_INET, &(server_addr.sin_addr), ip_str, INET_ADDRSTRLEN); + #else + inet_ntop(AF_INET, &(server_addr.sin_addr), ip_str, sizeof(ip_str)); + #endif + + return {std::format("{}:{}", ip_str, ntohs(server_addr.sin_port))}; + } + } + + inline std::string get_socket_error_message(int err) + { +#ifdef _WIN32 + + char* msg = nullptr; + FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, + err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&msg, 0, NULL); + std::string message(msg); + LocalFree(msg); + return {message}; + +#else + return 
strerror(err); +#endif + } + + struct socket_api_error_category_t final : std::error_category + { + std::string what{}; + const char* name() const noexcept override { return "socket error"; } + std::string message(int ev) const override + { + if (what.empty()) { + return {get_socket_error_message(ev)}; + } + else { + return {std::format("{}\nDetails: {}", what, get_socket_error_message(ev))}; + } + } + + void operator()(int ev, const std::string_view w) + { + what = w; + this->message(ev); + } + }; + + inline const socket_api_error_category_t& socket_api_error_category(const std::string_view what) + { + static socket_api_error_category_t singleton; + singleton.what = what; + return singleton; + } + + inline std::error_code get_socket_error(const std::string_view what = "") + { +#ifdef _WIN32 + int err = WSAGetLastError(); +#else + int err = errno; +#endif + + return {err, socket_api_error_category(what)}; + } + + inline std::error_code check_status(int ec, const std::string_view what = "") + { + if (ec >= 0) { + return {}; + } + + return {get_socket_error(what)}; + } + + // Example: + // + // std::error_code ec = check_status(result, "connect failed"); + // + // if (ec) { + // std::cerr << get_socket_error(std::format("Failed to connect to socket at address: {}.\nIs the server + // running?", ip_port)).message(); + // } + // else { + // std::cout << "Connected successfully!"; + // } + + // For Windows WSASocket Compatability + + inline constexpr uint16_t make_version(uint8_t low_byte, uint8_t high_byte) noexcept + { + return uint16_t(low_byte) | (uint16_t(high_byte) << 8); + } + + inline constexpr uint8_t major_version(uint16_t version) noexcept + { + return uint8_t(version & 0xFF); // Extract the low byte + } + + inline constexpr uint8_t minor_version(uint16_t version) noexcept + { + return uint8_t((version >> 8) & 0xFF); // Shift right by 8 bits and extract the low byte + } + + // Function to get Winsock version string on Windows, return "na" otherwise + inline std::string get_winsock_version_string(uint32_t version = make_version(2, 2)) + { +#if _WIN32 + BYTE major = major_version(uint16_t(version)); + BYTE minor = minor_version(uint16_t(version)); + return std::format("{}.{}", int(major), int(minor)); +#else + (void)version; + return ""; // Default behavior for non-Windows platforms +#endif + } + + // The 'wsa_startup_t' calls the windows WSAStartup function. This must be the first Windows + // Sockets function called by an application or DLL. It allows an application or DLL to + // specify the version of Windows Sockets required and retrieve details of the specific + // Windows Sockets implementation.The application or DLL can only issue further Windows Sockets + // functions after successfully calling WSAStartup. + // + // Important: WSAStartup and its corresponding WSACleanup must be called on the same thread. 
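+   // Rough usage sketch (illustrative only): construct a single instance before making any socket
+   // calls, e.g. near the top of main():
+   //
+   //    glz::windows_socket_startup_t wsa{}; // performs WSAStartup on Windows, a no-op elsewhere
+   //
+   // On Windows the destructor then calls WSACleanup from the same thread.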
+ // + template + struct windows_socket_startup_t final + { +#ifdef _WIN64 + WSADATA wsa_data{}; + + std::error_code error_code{}; + + std::error_code start(const WORD win_sock_version = make_version(2, 2)) // Request latest Winsock version 2.2 + { + static std::once_flag flag{}; + std::error_code startup_error{}; + std::call_once(flag, [this, win_sock_version, &startup_error]() { + int result = WSAStartup(win_sock_version, &wsa_data); + if (result != 0) { + error_code = get_socket_error( + std::format("Unable to initialize Winsock library version {}.", get_winsock_version_string())); + } + }); + return {error_code}; + } + + windows_socket_startup_t() + { + if constexpr (run_wsa_startup) { + error_code = start(); + } + } + + ~windows_socket_startup_t() { WSACleanup(); } + +#else + std::error_code start() { return std::error_code{}; } +#endif + }; + + GLZ_ENUM(ip_error, none, + queue_create_failed, + event_ctl_failed, + event_wait_failed, + event_enum_failed, + socket_connect_failed, + socket_bind_failed, + send_failed, + receive_failed, + receive_timeout, + client_disconnected); + + template + concept ip_header = std::same_as || requires { T::body_size; }; + + struct ip_error_category : std::error_category + { + static const ip_error_category& instance() { + static ip_error_category instance{}; + return instance; + } + + const char* name() const noexcept override { return "ip_error_category"; } + + std::string message(int ec) const override + { + return std::string{nameof(static_cast(ec))}; + } + }; +} diff --git a/include/glaze/network/socket_io.hpp b/include/glaze/network/socket_io.hpp new file mode 100644 index 0000000000..81c33af1d1 --- /dev/null +++ b/include/glaze/network/socket_io.hpp @@ -0,0 +1,42 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include "glaze/glaze.hpp" +#include "glaze/network/socket.hpp" +#include "glaze/core/error.hpp" + +namespace glz +{ + template + [[nodiscard]] std::error_code send(socket& sckt, T&& value, Buffer&& buffer) + { + if (auto ec = glz::write(std::forward(value), buffer)) { + return {int(ec.ec), error_category::instance()}; + } + + uint64_t header = uint64_t(buffer.size()); + + if (auto ec = blocking_send(sckt, sv{reinterpret_cast(&header), sizeof(header)})) { + return ec; + } + + return blocking_send(sckt, buffer); + } + + template + [[nodiscard]] std::error_code receive(socket& sckt, T&& value, Buffer&& buffer, size_t timeout_ms) + { + uint64_t header{}; + if (auto ec = blocking_header_receive(sckt, header, buffer, timeout_ms)) { + return ec; + } + + if (auto ec = glz::read(std::forward(value), buffer)) { + return {int(ec.ec), error_category::instance()}; + } + + return {}; + } +} diff --git a/include/glaze/thread/threadpool.hpp b/include/glaze/thread/threadpool.hpp index e4b70e25e1..2b0e93e408 100644 --- a/include/glaze/thread/threadpool.hpp +++ b/include/glaze/thread/threadpool.hpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include diff --git a/include/glaze/util/expected.hpp b/include/glaze/util/expected.hpp index a21ddcb6b4..2fc2e9714f 100644 --- a/include/glaze/util/expected.hpp +++ b/include/glaze/util/expected.hpp @@ -29,10 +29,8 @@ #ifndef GLZ_THROW_OR_ABORT #if __cpp_exceptions #define GLZ_THROW_OR_ABORT(EXC) (throw(EXC)) -#define GLZ_NOEXCEPT noexcept(false) #else #define GLZ_THROW_OR_ABORT(EXC) (std::abort()) -#define GLZ_NOEXCEPT noexcept(true) #endif #endif @@ -42,7 +40,7 @@ namespace glz { - inline void glaze_error([[maybe_unused]] const char* msg) GLZ_NOEXCEPT + 
inline void glaze_error([[maybe_unused]] const char* msg) { GLZ_THROW_OR_ABORT(std::runtime_error(msg)); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6e93dde410..f2f229873e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,5 +1,9 @@ include(FetchContent) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + set(BOOST_UT_ENABLE_RUN_AFTER_BUILD OFF CACHE INTERNAL "") set(BOOST_UT_DISABLE_MODULE ON CACHE INTERNAL "") @@ -52,11 +56,12 @@ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") target_compile_options(glz_test_exceptions INTERFACE /W4 /wd4459 /wd4805) endif() -add_subdirectory(asio_repe) +add_subdirectory(repe_server_client) add_subdirectory(api_test) add_subdirectory(binary_test) add_subdirectory(cli_menu_test) add_subdirectory(compare_test) +add_subdirectory(coroutine_test) add_subdirectory(csv_test) add_subdirectory(eigen_test) add_subdirectory(exceptions_test) @@ -69,6 +74,7 @@ add_subdirectory(mock_json_test) add_subdirectory(stencil_test) add_subdirectory(reflection_test) add_subdirectory(repe_test) +add_subdirectory(socket_test) # We don't run find_package_test or glaze-install_test with MSVC/Windows, because the Github action runner often chokes if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") diff --git a/tests/asio_repe/CMakeLists.txt b/tests/asio_repe/CMakeLists.txt deleted file mode 100644 index 70417b38f8..0000000000 --- a/tests/asio_repe/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -project(asio_repe) - -FetchContent_Declare( - asio - GIT_REPOSITORY https://github.com/chriskohlhoff/asio.git - GIT_TAG asio-1-30-1 - GIT_SHALLOW TRUE -) -FetchContent_GetProperties(asio) -if(NOT asio_POPULATED) - FetchContent_Populate(asio) -endif() - -add_subdirectory(server) -add_subdirectory(client) diff --git a/tests/coroutine_test/CMakeLists.txt b/tests/coroutine_test/CMakeLists.txt new file mode 100644 index 0000000000..7bd7b13606 --- /dev/null +++ b/tests/coroutine_test/CMakeLists.txt @@ -0,0 +1,11 @@ +project(coroutine_test) + +add_compile_definitions(CURRENT_DIRECTORY="${CMAKE_CURRENT_SOURCE_DIR}") + +add_executable(${PROJECT_NAME} ${PROJECT_NAME}.cpp) + +target_link_libraries(${PROJECT_NAME} PRIVATE glz_test_exceptions) + +add_test(NAME ${PROJECT_NAME} COMMAND ${PROJECT_NAME}) + +target_code_coverage(${PROJECT_NAME} AUTO ALL) diff --git a/tests/coroutine_test/coroutine_test.cpp b/tests/coroutine_test/coroutine_test.cpp new file mode 100644 index 0000000000..1014a83ed3 --- /dev/null +++ b/tests/coroutine_test/coroutine_test.cpp @@ -0,0 +1,711 @@ +#include "glaze/coroutine.hpp" + +#include +#include +#include +#include + +#include "exec/async_scope.hpp" +#include "exec/finally.hpp" +#include "exec/static_thread_pool.hpp" +#include "exec/task.hpp" +#include "exec/timed_thread_scheduler.hpp" +#include "exec/when_any.hpp" +#include "stdexec/execution.hpp" +#include "ut/ut.hpp" + +using namespace ut; + +#define TEST_ALL + +#ifdef TEST_ALL +suite generator = [] { + std::atomic result{}; + auto task = [&](uint64_t count_to) -> exec::task { + // Create a generator function that will yield and incrementing + // number each time its called. + auto gen = []() -> glz::generator { + uint64_t i = 0; + while (true) { + co_yield i; + ++i; + } + }; + + // Generate the next number until its greater than count to. 
+ for (auto val : gen()) { + // std::cout << val << ", "; + result += val; + + if (val >= count_to) { + break; + } + } + co_return; + }; + + stdexec::sync_wait(task(100)); + + expect(result == 5050) << result; +}; + +suite thread_pool = [] { + // This lambda will create a glz::task that returns a unit64_t. + // It can be invoked many times with different arguments. + auto make_task_inline = [](uint64_t x) -> exec::task { co_return x + x; }; + + // This will block the calling thread until the created task completes. + // Since this task isn't scheduled on any glz::thread_pool or glz::scheduler + // it will execute directly on the calling thread. + auto result = stdexec::sync_wait(make_task_inline(5)); + expect(std::get<0>(result.value()) == 10); + // std::cout << "Inline Result = " << result << "\n"; + + exec::static_thread_pool pool(1); + auto sched = pool.get_scheduler(); + + auto make_task_offload = [&sched](uint64_t x) { + return stdexec::on(sched, stdexec::just() | stdexec::then([=] { return x + x; })); + }; + + result = stdexec::sync_wait(make_task_offload(10)); + expect(std::get<0>(result.value()) == 20); +}; + +suite when_all = [] { + // Create a thread pool to execute all the tasks in parallel. + exec::static_thread_pool tp{4}; + auto scheduler = tp.get_scheduler(); + auto twice = [&](uint64_t x) { + return x + x; // Executed on the thread pool. + }; + + exec::async_scope scope; + std::mutex mtx{}; + std::vector results{}; + for (std::size_t i = 0; i < 5; ++i) { + scope.spawn(stdexec::on(scheduler, stdexec::just() | stdexec::then([&, i] { + const auto value = twice(i + 1); + std::unique_lock lock{mtx}; + results.emplace_back(value); + }))); + } + + // Synchronously wait on this thread for the thread pool to finish executing all the tasks in parallel. + [[maybe_unused]] auto ret = stdexec::sync_wait(scope.on_empty()); + std::ranges::sort(results); + expect(results[0] == 2); + expect(results[1] == 4); + expect(results[2] == 6); + expect(results[3] == 8); + expect(results[4] == 10); + + // Use var args instead of a container as input to glz::when_all. + auto square = [](uint64_t x) { return [=] { return x * x; }; }; + + auto parallel = [](auto& scheduler, auto... fs) { + return stdexec::when_all(stdexec::on(scheduler, stdexec::just() | stdexec::then(fs))...); + }; + + /*auto chain(auto& scheduler, auto... fs) { + return stdexec::on(scheduler, stdexec::just() | (stdexec::then(fs) | ...)); + }*/ + + // Var args allows you to pass in tasks with different return types and returns + // the result as a std::tuple. + // auto tuple_results = stdexec::sync_wait(chain_workers(scheduler, square(2), square(10))).value(); + auto tuple_results = stdexec::sync_wait(parallel(scheduler, square(2), [&] { return twice(10); })).value(); + + auto first = std::get<0>(tuple_results); + auto second = std::get<1>(tuple_results); + + expect(first == 4); + expect(second == 20); +}; + +suite event = [] { + std::cout << "\nEvent test:\n"; + glz::event e; + + // These tasks will wait until the given event has been set before advancing. + auto make_wait_task = [](const glz::event& e, uint64_t i) -> exec::task { + std::cout << "task " << i << " is waiting on the event...\n"; + co_await e; + std::cout << "task " << i << " event triggered, now resuming.\n"; + co_return; + }; + + // This task will trigger the event allowing all waiting tasks to proceed. 
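+   // (glz::event presumably follows the libcoro-style design: set() resumes every coroutine
+   //  currently awaiting the event and later co_awaits complete immediately, which is why the
+   //  single set task below unblocks all three wait tasks.)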
+ auto make_set_task = [](glz::event& e) -> exec::task { + std::cout << "set task is triggering the event\n"; + e.set(); + co_return; + }; + + // Given more than a single task to synchronously wait on, use when_all() to execute all the + // tasks concurrently on this thread and then sync_wait() for them all to complete. + stdexec::sync_wait( + stdexec::when_all(make_wait_task(e, 1), make_wait_task(e, 2), make_wait_task(e, 3), make_set_task(e))); +}; + +using namespace std::chrono_literals; + +/*struct timer { + std::chrono::milliseconds duration_; + + auto operator co_await() const noexcept { + struct awaiter { + std::chrono::milliseconds duration; + + bool await_ready() const noexcept { return false; } + + void await_suspend(std::coroutine_handle<> h) const { + std::thread([h, this]() { + std::this_thread::sleep_for(duration); + h.resume(); + }).detach(); + } + + void await_resume() const noexcept {} + }; + + return awaiter{duration_}; + } +};*/ + +/* +suite latch = [] { + std::cout << "\nLatch test:\n"; + // Complete worker tasks faster on a thread pool, using the scheduler version so the worker + // tasks can yield for a specific amount of time to mimic difficult work. The pool is only + // setup with a single thread to showcase yield_for(). + exec::static_thread_pool tp{1}; + auto scheduler = tp.get_scheduler(); + + // This task will wait until the given latch setters have completed. + auto make_latch_task = [](glz::latch& l) -> exec::task { + // It seems like the dependent worker tasks could be created here, but in that case it would + // be superior to simply do: `co_await coro::when_all(tasks);` + // It is also important to note that the last dependent task will resume the waiting latch + // task prior to actually completing -- thus the dependent task's frame could be destroyed + // by the latch task completing before it gets a chance to finish after calling resume() on + // the latch task! + + std::cout << "latch task is now waiting on all children tasks...\n"; + co_await l; + std::cout << "latch task dependency tasks completed, resuming.\n"; + co_return; + }; + + // This task does 'work' and counts down on the latch when completed. The final child task to + // complete will end up resuming the latch task when the latch's count reaches zero. + auto make_worker_task = [](auto& sched, glz::latch& l, int64_t i) -> exec::task { + // Schedule the worker task onto the thread pool. + co_await sched.schedule(); + std::cout << "worker task " << i << " is working...\n"; + // Do some expensive calculations, yield to mimic work...! Its also important to never use + // std::this_thread::sleep_for() within the context of coroutines, it will block the thread + // and other tasks that are ready to execute will be blocked. + //co_await timer{std::chrono::milliseconds{i * 20}}(sched); + co_await [&]() -> exec::task { + co_return std::thread([i]{ std::this_thread::sleep_for(std::chrono::milliseconds(i * 20)); }).detach(); + }(); + std::cout << "worker task " << i << " is done, counting down on the latch\n"; + l.count_down(); + co_return; + }; + + const int64_t num_tasks{5}; + glz::latch l{num_tasks}; + + // Make the latch task first so it correctly waits for all worker tasks to count down. + auto work = [&]() -> exec::task { + for (int64_t i = 1; i <= num_tasks; ++i) { + co_await make_worker_task(scheduler, l, i); + } + co_return; + }; + + // Wait for all tasks to complete. 
+ stdexec::sync_wait(stdexec::when_all(make_latch_task(l), work())); +}; +*/ + +suite latch = [] { + namespace ex = stdexec; + + std::cout << "\nLatch test:\n"; + + // Create a thread pool with a single thread + exec::static_thread_pool tp{1}; + auto scheduler = tp.get_scheduler(); + + // Create a latch with count 3 (for 3 worker tasks) + std::latch l(3); + + std::atomic wait_start{false}; + + // Define the latch task + auto latch_task = ex::let_value(ex::schedule(scheduler), [&]() { + return ex::just() | ex::then([&]() { + std::cout << "latch task is now waiting on all children tasks...\n"; + wait_start = true; + l.wait(); + std::cout << "latch task dependency tasks completed, resuming.\n"; + }); + }); + + // Define the worker task + auto make_worker_task = [&](int64_t i) { + while (wait_start) std::this_thread::sleep_for(std::chrono::milliseconds(10)); + return ex::let_value(ex::schedule(scheduler), [i, &l]() { + return ex::just() | ex::then([i]() { + std::cout << "worker task " << i << " is working...\n"; + std::this_thread::sleep_for(std::chrono::milliseconds(i * 20)); + std::cout << "worker task " << i << " is done, counting down on the latch\n"; + }) | + ex::then([&l]() { l.count_down(); }); + }); + }; + + // Create and schedule tasks + auto scheduled_latch_task = latch_task; + + /* TODO: + std::vector worker_tasks; + for (int i = 1; i <= 3; ++i) { + worker_tasks.push_back(make_worker_task(i)); + } + */ + + auto scheduled_worker_tasks = ex::when_all(make_worker_task(1), make_worker_task(2), make_worker_task(3)); + + auto start = std::chrono::high_resolution_clock::now(); + ex::sync_wait(ex::when_all(scheduled_latch_task, scheduled_worker_tasks)); + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration diff = end - start; + std::cout << "Total execution time: " << diff.count() << " seconds\n"; + + return 0; +}; + +suite mutex_test = [] { + std::cout << "\nMutex test:\n"; + + exec::static_thread_pool tp{4}; + auto scheduler = tp.get_scheduler(); + std::vector output{}; + std::mutex mtx{}; + + auto make_critical_section_task = [&](uint64_t i) { + std::unique_lock lock{mtx}; + output.emplace_back(i); + }; + + const size_t num_tasks{100}; + exec::async_scope scope; + for (std::size_t i = 1; i < num_tasks; ++i) { + scope.spawn(stdexec::on(scheduler, stdexec::just() | stdexec::then([&, i]() { make_critical_section_task(i); }))); + } + + stdexec::sync_wait(scope.on_empty()); + + // The output will be variable per run depending on how the tasks are picked up on the + // thread pool workers. + for (const auto& value : output) { + std::cout << value << ", "; + } +}; + +/*suite shared_mutex_test = [] { + std::cout << "\nShared Mutex test:\n"; + // Shared mutexes require an excutor type to be able to wake up multiple shared waiters when + // there is an exclusive lock holder releasing the lock. This example uses a single thread + // to also show the interleaving of coroutines acquiring the shared lock in shared and + // exclusive mode as they resume and suspend in a linear manner. Ideally the thread pool + // executor would have more than 1 thread to resume all shared waiters in parallel. 
+ auto tp = std::make_shared(glz::thread_pool::options{.thread_count = 1}); + glz::shared_mutex mutex{tp}; + + auto make_shared_task = [&](uint64_t i) -> glz::task { + co_await tp->schedule(); + { + std::cerr << "shared task " << i << " lock_shared()\n"; + auto scoped_lock = co_await mutex.lock_shared(); + std::cerr << "shared task " << i << " lock_shared() acquired\n"; + /// Immediately yield so the other shared tasks also acquire in shared state + /// while this task currently holds the mutex in shared state. + co_await tp->yield(); + std::cerr << "shared task " << i << " unlock_shared()\n"; + } + co_return; + }; + + auto make_exclusive_task = [&]() -> glz::task { + co_await tp->schedule(); + + std::cerr << "exclusive task lock()\n"; + auto scoped_lock = co_await mutex.lock(); + std::cerr << "exclusive task lock() acquired\n"; + // Do the exclusive work.. + std::cerr << "exclusive task unlock()\n"; + co_return; + }; + + // Create 3 shared tasks that will acquire the mutex in a shared state. + const size_t num_tasks{3}; + std::vector> tasks{}; + for (size_t i = 1; i <= num_tasks; ++i) { + tasks.emplace_back(make_shared_task(i)); + } + // Create an exclusive task. + tasks.emplace_back(make_exclusive_task()); + // Create 3 more shared tasks that will be blocked until the exclusive task completes. + for (size_t i = num_tasks + 1; i <= num_tasks * 2; ++i) { + tasks.emplace_back(make_shared_task(i)); + } + + glz::sync_wait(glz::when_all(std::move(tasks))); +};*/ + +/*suite semaphore_test = [] { + std::cout << "\nSemaphore test:\n"; + // Have more threads/tasks than the semaphore will allow for at any given point in time. + glz::thread_pool tp{glz::thread_pool::options{.thread_count = 8}}; + glz::semaphore semaphore{1}; + + auto make_rate_limited_task = [&](uint64_t task_num) -> glz::task { + co_await tp.schedule(); + + // This will only allow 1 task through at any given point in time, all other tasks will + // await the resource to be available before proceeding. + auto result = co_await semaphore.acquire(); + if (result == glz::semaphore::acquire_result::acquired) { + std::cout << task_num << ", "; + semaphore.release(); + } + else { + std::cout << task_num << " failed to acquire semaphore [" << glz::semaphore::to_string(result) << "],"; + } + co_return; + }; + + const size_t num_tasks{100}; + std::vector> tasks{}; + for (size_t i = 1; i <= num_tasks; ++i) { + tasks.emplace_back(make_rate_limited_task(i)); + } + + glz::sync_wait(glz::when_all(std::move(tasks))); +};*/ + +/*suite ring_buffer_test = [] { + std::cout << "\nRing Buffer test:\n"; + + const size_t iterations = 100; + const size_t consumers = 4; + glz::thread_pool tp{glz::thread_pool::options{.thread_count = 4}}; + glz::ring_buffer rb{}; + glz::mutex m{}; + + std::vector> tasks{}; + + auto make_producer_task = [&]() -> glz::task { + co_await tp.schedule(); + + for (size_t i = 1; i <= iterations; ++i) { + co_await rb.produce(i); + } + + // Wait for the ring buffer to clear all items so its a clean stop. + while (!rb.empty()) { + co_await tp.yield(); + } + + // Now that the ring buffer is empty signal to all the consumers its time to stop. Note that + // the stop signal works on producers as well, but this example only uses 1 producer. 
+ { + auto scoped_lock = co_await m.lock(); + std::cerr << "\nproducer is sending stop signal"; + } + rb.notify_waiters(); + co_return; + }; + + auto make_consumer_task = [&](size_t id) -> glz::task { + co_await tp.schedule(); + + while (true) { + auto expected = co_await rb.consume(); + auto scoped_lock = co_await m.lock(); // just for synchronizing std::cout/cerr + if (!expected) { + std::cerr << "\nconsumer " << id << " shutting down, stop signal received"; + break; // while + } + else { + auto item = std::move(*expected); + std::cout << "(id=" << id << ", v=" << item << "), "; + } + + // Mimic doing some work on the consumed value. + co_await tp.yield(); + } + + co_return; + }; + + // Create N consumers + for (size_t i = 0; i < consumers; ++i) { + tasks.emplace_back(make_consumer_task(i)); + } + // Create 1 producer. + tasks.emplace_back(make_producer_task()); + + // Wait for all the values to be produced and consumed through the ring buffer. + glz::sync_wait(glz::when_all(std::move(tasks))); +};*/ +#endif + +// #define SERVER_CLIENT_TEST + +#ifdef SERVER_CLIENT_TEST +suite server_client_test = [] { + std::cout << "\n\nServer/Client test:\n"; + + auto scheduler = std::make_shared(glz::scheduler::options{ + // The scheduler will spawn a dedicated event processing thread. This is the default, but + // it is possible to use 'manual' and call 'process_events()' to drive the scheduler yourself. + .thread_strategy = glz::thread_strategy::spawn, + // If the scheduler is in spawn mode this functor is called upon starting the dedicated + // event processor thread. + .on_io_thread_start_functor = [] { std::cout << "scheduler::process event thread start\n"; }, + // If the scheduler is in spawn mode this functor is called upon stopping the dedicated + // event process thread. + .on_io_thread_stop_functor = [] { std::cout << "scheduler::process event thread stop\n"; }, + // The io scheduler can use a coro::thread_pool to process the events or tasks it is given. + // You can use an execution strategy of `process_tasks_inline` to have the event loop thread + // directly process the tasks, this might be desirable for small tasks vs a thread pool for large tasks. + .pool = + glz::thread_pool::options{ + .thread_count = 1, + .on_thread_start_functor = + [](size_t i) { std::cout << "scheduler::thread_pool worker " << i << " starting\n"; }, + .on_thread_stop_functor = + [](size_t i) { std::cout << "scheduler::thread_pool worker " << i << " stopping\n"; }, + }, + .execution_strategy = glz::scheduler::execution_strategy::process_tasks_on_thread_pool}); + + auto make_server_task = [&]() -> glz::task { + // Start by creating a tcp server, we'll do this before putting it into the scheduler so + // it is immediately available for the client to connect since this will create a socket, + // bind the socket and start listening on that socket. See tcp::server for more details on + // how to specify the local address and port to bind to as well as enabling SSL/TLS. + glz::server server{scheduler}; + + // Now scheduler this task onto the scheduler. + co_await scheduler->schedule(); + + // Wait for an incoming connection and accept it. + auto poll_status = co_await server.poll(); + if (poll_status != glz::poll_status::event) { + std::cerr << "Incoming client connection failed!\n" + << "Poll Status Detail: " << glz::nameof(poll_status) << '\n'; + co_return; // Handle error, see poll_status for detailed error states. 
+ } + + auto client = server.accept(); + + if (not client.socket->valid()) { + std::cerr << "Incoming client connection failed!\n"; + co_return; // Handle error. + } + + // Now wait for the client message, this message is small enough it should always arrive + // with a single recv() call. + poll_status = co_await client.poll(glz::poll_op::read); + if (poll_status != glz::poll_status::event) { + if (glz::poll_status::closed == glz::poll_status::event) { + std::cerr << "Error on: co_await client.poll(glz::poll_op::read): client Id, " << client.socket->socket_fd + << ", the socket is closed.\n"; + } + else { + std::cerr << "Error on: co_await client.poll(glz::poll_op::read): client Id, " << client.socket->socket_fd + << ".\nDetails: " << glz::nameof(poll_status) << '\n'; + } + co_return; // Handle error. + } + + // Prepare a buffer and recv() the client's message. This function returns the recv() status + // as well as a span that overlaps the given buffer for the bytes that were read. This + // can be used to resize the buffer or work with the bytes without modifying the buffer at all. + std::string request(256, '\0'); + auto [ip_status, recv_bytes] = client.recv(request); + if (ip_status != glz::ip_status::ok) { + std::cerr << "client::recv error:\n" + << "Details: " << glz::nameof(poll_status) << '\n'; + co_return; // Handle error, see net::ip_status for detailed error states. + } + + request.resize(recv_bytes.size()); + std::cout << "server: " << request << "\n"; + + // Make sure the client socket can be written to. + poll_status = co_await client.poll(glz::poll_op::write); + if (poll_status != glz::poll_status::event) { + std::cerr << "Error on: co_await client.poll(glz::poll_op::write): client Id" << client.socket->socket_fd + << ".\nDetails: " << glz::nameof(poll_status) << '\n'; + co_return; // Handle error. + } + + // Send the server response to the client. + // This message is small enough that it will be sent in a single send() call, but to demonstrate + // how to use the 'remaining' portion of the send() result this is wrapped in a loop until + // all the bytes are sent. + std::string response = "Hello from server."; + std::span remaining = response; + do { + // Optimistically send() prior to polling. + auto [ips, r] = client.send(remaining); + if (ips != glz::ip_status::ok) { + co_return; // Handle error, see net::ip_status for detailed error states. + } + + if (r.empty()) { + break; // The entire message has been sent. + } + + // Re-assign remaining bytes for the next loop iteration and poll for the socket to be + // able to be written to again. + remaining = r; + poll_status = co_await client.poll(glz::poll_op::write); + if (poll_status != glz::poll_status::event) { + co_return; // Handle error. + } + } while (true); + + co_return; + }; + + auto make_client_task = [&]() -> glz::task { + // Immediately schedule onto the scheduler. + co_await scheduler->schedule(); + + // Create the tcp::client with the default settings, see tcp::client for how to set the + // ip address, port, and optionally enabling SSL/TLS. + glz::client client{scheduler}; + + // Ommitting error checking code for the client, each step should check the status and + // verify the number of bytes sent or received. + + // Connect to the server. + if (auto ip_status = co_await client.connect(std::chrono::milliseconds(100)); ip_status != glz::ip_status::ok) { + std::cerr << "ip_status: " << glz::nameof(ip_status) << '\n'; + } + + // Make sure the client socket can be written to. 
+ if (auto status = co_await client.poll(glz::poll_op::write); bool(status)) { + std::cerr << "poll_status: " << glz::nameof(status) << '\n'; + } + + // Send the request data. + client.send(std::string_view{"Hello from client."}); + + // Wait for the response and receive it. + co_await client.poll(glz::poll_op::read); + std::string response(256, '\0'); + auto [ip_status, recv_bytes] = client.recv(response); + response.resize(recv_bytes.size()); + + std::cout << "client id " << client.socket->socket_fd << ", recieved: " << response << '\n'; + co_return; + }; + + // Create and wait for the server and client tasks to complete. + glz::sync_wait(glz::when_all(make_server_task(), make_client_task())); +}; +#endif + +/*Old with co-routines +template +exec::task async_answer(S1 s1, S2 s2) +{ + // Senders are implicitly awaitable (in this coroutine type): + co_await static_cast(s2); + co_return co_await static_cast(s1); +} + +template +exec::task> async_answer2(S1 s1, S2 s2) +{ + co_return co_await stdexec::stopped_as_optional(async_answer(s1, s2)); +} + +// tasks have an associated stop token +exec::task> async_stop_token() +{ + co_return co_await stdexec::stopped_as_optional(stdexec::get_stop_token()); +} + +suite stdexec_coroutine_test = [] { + try { + // Awaitables are implicitly senders: + auto [i] = stdexec::sync_wait(async_answer2(stdexec::just(42), stdexec::just())).value(); + std::cout << "The answer is " << i.value() << '\n'; + } + catch (std::exception& e) { + std::cout << e.what() << '\n'; + } +}; +*/ + +template +auto async_answer(S1&& s1, S2&& s2) +{ + return stdexec::let_value(std::forward(s2), [s1 = std::forward(s1)]() mutable { return std::move(s1); }); +} + +template +auto async_answer2(S1&& s1, S2&& s2) +{ + return stdexec::stopped_as_optional(async_answer(std::forward(s1), std::forward(s2))); +} + +inline auto async_stop_token(exec::async_scope& scope) { return stdexec::just(scope.get_stop_token()); } + +suite stdexec_sender_composition_test = [] { + exec::static_thread_pool pool(4); + auto scheduler = pool.get_scheduler(); + + auto s1 = stdexec::just(42); + auto s2 = stdexec::just(); + + auto result = stdexec::sync_wait(stdexec::on(scheduler, async_answer2(std::move(s1), std::move(s2)))); + + if (result) { + if (auto& value_opt = std::get<0>(*result)) { + std::cout << "Result: " << *value_opt << std::endl; + } + else { + std::cout << "Operation was stopped" << std::endl; + } + } + else { + std::cout << "Operation failed" << std::endl; + } + + exec::async_scope scope; + auto stop_token_result = stdexec::sync_wait(stdexec::on(scheduler, async_stop_token(scope))); + if (stop_token_result) { + std::cout << "Stop token obtained" << std::endl; + } + else { + std::cout << "Stop token operation failed" << std::endl; + } +}; + +int main() +{ + std::cout << '\n'; + return 0; +} diff --git a/tests/repe_server_client/CMakeLists.txt b/tests/repe_server_client/CMakeLists.txt new file mode 100644 index 0000000000..ff08e8207a --- /dev/null +++ b/tests/repe_server_client/CMakeLists.txt @@ -0,0 +1,4 @@ +project(repe_server_client) + +add_subdirectory(server) +add_subdirectory(client) diff --git a/tests/asio_repe/client/CMakeLists.txt b/tests/repe_server_client/client/CMakeLists.txt similarity index 67% rename from tests/asio_repe/client/CMakeLists.txt rename to tests/repe_server_client/client/CMakeLists.txt index 67ac43d3d4..19f9c0b70b 100644 --- a/tests/asio_repe/client/CMakeLists.txt +++ b/tests/repe_server_client/client/CMakeLists.txt @@ -2,7 +2,7 @@ project(repe_client) 
 add_executable(${PROJECT_NAME} ${PROJECT_NAME}.cpp)
 
-target_include_directories(${PROJECT_NAME} PRIVATE include ${asio_SOURCE_DIR}/asio/include)
+target_include_directories(${PROJECT_NAME} PRIVATE include)
 
 target_link_libraries(${PROJECT_NAME} PRIVATE glz_test_exceptions)
 target_code_coverage(${PROJECT_NAME} AUTO ALL)
\ No newline at end of file
diff --git a/tests/asio_repe/client/repe_client.cpp b/tests/repe_server_client/client/repe_client.cpp
similarity index 72%
rename from tests/asio_repe/client/repe_client.cpp
rename to tests/repe_server_client/client/repe_client.cpp
index 6295935451..40e1907d15 100644
--- a/tests/asio_repe/client/repe_client.cpp
+++ b/tests/repe_server_client/client/repe_client.cpp
@@ -3,23 +3,24 @@
 
 #include
 
-#include "glaze/ext/glaze_asio.hpp"
-#include "glaze/glaze.hpp"
-#include "glaze/rpc/repe.hpp"
+#include "glaze/network/repe_client.hpp"
 
 void asio_client_test()
 {
    try {
       constexpr auto N = 100;
-      std::vector<glz::asio_client<>> clients;
+      std::vector<glz::repe_client<>> clients;
       clients.reserve(N);
 
       std::vector<std::future<void>> threads;
       threads.reserve(N);
 
       for (size_t i = 0; i < N; ++i) {
-         clients.emplace_back(glz::asio_client<>{"localhost", "8080"});
+         clients.emplace_back(glz::repe_client<>{"127.0.0.1", 8080});
       }
+
+      std::mutex mtx{};
+      std::vector<int> results{};
 
       for (size_t i = 0; i < N; ++i) {
         threads.emplace_back(std::async([&, i] {
@@ -38,11 +39,13 @@ void asio_client_test()
             }
 
            int sum{};
-           if (auto e_call = client.call({"/sum"}, data, sum); e_call) {
+           if (auto e_call = client.call({"/sum"}, data, sum)) {
              std::cerr << glz::write_json(e_call).value_or("error") << '\n';
           }
           else {
+             std::unique_lock lock{mtx};
             std::cout << "i: " << i << ", " << sum << '\n';
+             results.emplace_back(sum);
          }
        }));
      }
@@ -50,6 +53,12 @@ void asio_client_test()
      for (auto& t : threads) {
        t.get();
      }
+
+     for (auto v : results) {
+        if (v != 4950) {
+           std::abort();
+        }
+     }
   }
   catch (const std::exception& e) {
      std::cerr << e.what() << '\n';
diff --git a/tests/asio_repe/server/CMakeLists.txt b/tests/repe_server_client/server/CMakeLists.txt
similarity index 67%
rename from tests/asio_repe/server/CMakeLists.txt
rename to tests/repe_server_client/server/CMakeLists.txt
index 0ee53d4889..50dfb91273 100644
--- a/tests/asio_repe/server/CMakeLists.txt
+++ b/tests/repe_server_client/server/CMakeLists.txt
@@ -2,7 +2,7 @@ project(repe_server)
 
 add_executable(${PROJECT_NAME} ${PROJECT_NAME}.cpp)
 
-target_include_directories(${PROJECT_NAME} PRIVATE include ${asio_SOURCE_DIR}/asio/include)
+target_include_directories(${PROJECT_NAME} PRIVATE include)
 
 target_link_libraries(${PROJECT_NAME} PRIVATE glz_test_exceptions)
 target_code_coverage(${PROJECT_NAME} AUTO ALL)
\ No newline at end of file
diff --git a/tests/asio_repe/server/repe_server.cpp b/tests/repe_server_client/server/repe_server.cpp
similarity index 63%
rename from tests/asio_repe/server/repe_server.cpp
rename to tests/repe_server_client/server/repe_server.cpp
index 89056171c0..12a731d55c 100644
--- a/tests/asio_repe/server/repe_server.cpp
+++ b/tests/repe_server_client/server/repe_server.cpp
@@ -1,26 +1,24 @@
 // Glaze Library
 // For the license information refer to glaze.hpp
 
-#include "glaze/ext/glaze_asio.hpp"
-#include "glaze/glaze.hpp"
-#include "glaze/rpc/repe.hpp"
+#include "glaze/network/repe_server.hpp"
 
 struct api
 {
    std::function<int(std::vector<int>& vec)> sum = [](std::vector<int>& vec) { return std::reduce(vec.begin(), vec.end()); };
-   std::function<int(std::vector<int>& vec)> max = [](std::vector<int>& vec) { return std::ranges::max(vec); };
+   std::function<int(std::vector<int>& vec)> maximum = [](std::vector<int>& vec) { return (std::ranges::max)(vec); };
 };
 
 #include
 
-void run_server()
+int main()
 {
    std::cout << "Server active...\n";
    try {
-      glz::asio_server<> server{.port = 8080};
+      glz::repe_server<> server{.port = 8080, .print_errors = true};
 
       api methods{};
       server.on(methods);
       server.run();
@@ -30,11 +28,6 @@ void run_server()
    }
 
    std::cout << "Server closed...\n";
-}
-
-int main()
-{
-   run_server();
 
    return 0;
 }
diff --git a/tests/socket_test/CMakeLists.txt b/tests/socket_test/CMakeLists.txt
new file mode 100644
index 0000000000..8685725812
--- /dev/null
+++ b/tests/socket_test/CMakeLists.txt
@@ -0,0 +1,9 @@
+project(socket_test)
+
+add_executable(${PROJECT_NAME} ${PROJECT_NAME}.cpp)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE glz_test_common)
+
+add_test(NAME ${PROJECT_NAME} COMMAND ${PROJECT_NAME})
+
+target_code_coverage(${PROJECT_NAME} AUTO ALL)
diff --git a/tests/socket_test/socket_test.cpp b/tests/socket_test/socket_test.cpp
new file mode 100644
index 0000000000..c5fbdbf58e
--- /dev/null
+++ b/tests/socket_test/socket_test.cpp
@@ -0,0 +1,125 @@
+// Glaze Library
+// For the license information refer to glaze.hpp
+
+#define UT_RUN_TIME_ONLY
+
+#include "glaze/network/server.hpp"
+#include "glaze/network/socket_io.hpp"
+
+#include
+#include
+#include
+#include
+
+#include "ut/ut.hpp"
+
+using namespace ut;
+
+constexpr bool user_input = false;
+
+constexpr auto n_clients = 10;
+constexpr auto service_0_port{8080};
+constexpr auto service_0_ip{"127.0.0.1"};
+
+// std::latch is broken on MSVC:
+// std::latch working_clients{n_clients};
+static std::atomic_int working_clients{n_clients};
+
+glz::windows_socket_startup_t<> wsa; // wsa_startup (ignored on macOS and Linux)
+
+glz::server server{.port = service_0_port};
+
+suite make_server = [] {
+   std::cout << std::format("Server started on port: {}\n", server.port);
+
+   const auto future = server.async_accept([](glz::socket&& client, auto& active) {
+      std::cout << "New client connected!\n";
+
+      if (auto ec = glz::send(client, "Welcome!", std::string{})) {
+         std::cerr << ec.message() << '\n';
+         return;
+      }
+
+      while (active) {
+         std::string received{};
+         if (auto ec = glz::receive(client, received, std::string{}, 5000)) {
+            std::cerr << ec.message() << '\n';
+            return;
+         }
+         std::cout << std::format("Server: {}\n", received);
+         std::ignore = glz::send(client, std::format("Hello to {} from server.\n", received), std::string{});
+      }
+   });
+
+   if (future.wait_for(std::chrono::milliseconds(10)) == std::future_status::ready) {
+      std::cerr << future.get().message() << '\n';
+   }
+};
+
+suite socket_test = [] {
+   std::vector<glz::socket> sockets(n_clients);
+   std::vector<std::future<void>> threads(n_clients);
+   for (size_t id{}; id < n_clients; ++id) {
+      threads.emplace_back(std::async([id, &sockets] {
+         glz::socket& socket = sockets[id];
+
+         if (connect(socket, service_0_ip, service_0_port)) {
+            std::cerr << std::format("Failed to connect to server.\nDetails: {}\n", glz::get_socket_error().message());
+         }
+         else {
+            std::string received{};
+            if (auto ec = glz::receive(socket, received, std::string{}, 100)) {
+               std::cerr << ec.message() << '\n';
+               return;
+            }
+            std::cout << std::format("Received from server: {}\n", received);
+
+            size_t tick{};
+            std::string result;
+            while (tick < 3) {
+               if (auto ec = glz::send(socket, std::format("Client {}, {}", id, tick), std::string{})) {
+                  std::cerr << ec.message() << '\n';
+                  return;
+               }
+               if (auto ec = glz::receive(socket, result, std::string{}, 100)) {
+                  continue;
+               }
+               else {
+                  expect(result.size() > 0);
+                  std::cout << result;
+               }
+
+               std::this_thread::sleep_for(std::chrono::seconds(2));
+               ++tick;
+            }
+            // working_clients.count_down();
+            --working_clients;
+         }
+      }));
+   }
+
+   // working_clients.arrive_and_wait();
+   while (working_clients) std::this_thread::sleep_for(std::chrono::milliseconds(1));
+
+   if constexpr (user_input) {
+      std::cout << "\nFinished! Press any key to exit.";
+      std::cin.get();
+   }
+
+   server.active = false;
+};
+
+int main()
+{
+   std::signal(SIGINT, [](int) {
+      server.active = false;
+      std::exit(0);
+   });
+
+   // GCC needs this sleep
+   std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+   return 0;
+}