diff --git a/.gitmodules b/.gitmodules index e3b3ed5e..9a1f96ad 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,15 +10,9 @@ [submodule "external/zstd"] path = external/zstd url = https://github.com/facebook/zstd.git -[submodule "external/nlohmann-json"] - path = external/nlohmann-json - url = https://github.com/nlohmann/json.git -[submodule "external/oxen-libquic"] - path = external/oxen-libquic - url = https://github.com/oxen-io/oxen-libquic.git [submodule "external/protobuf"] path = external/protobuf url = https://github.com/protocolbuffers/protobuf.git -[submodule "external/oxen-logging"] - path = external/oxen-logging - url = https://github.com/oxen-io/oxen-logging.git +[submodule "external/lokinet"] + path = external/lokinet + url = https://github.com/oxen-io/lokinet.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 2271b4bc..e27165f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,7 +82,7 @@ option(STATIC_LIBSTD "Statically link libstdc++/libgcc" ${default_static_libstd} option(USE_LTO "Use Link-Time Optimization" ${use_lto_default}) # Provide this as an option for now because GMP and Desktop are sometimes unhappy with each other. -option(ENABLE_ONIONREQ "Build with onion request functionality" ON) +option(ENABLE_NETWORKING "Build with networking functionality" ON) if(USE_LTO) include(CheckIPOSupported) @@ -119,7 +119,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) add_subdirectory(external) -if(ENABLE_ONIONREQ) +if(ENABLE_NETWORKING) if(NOT TARGET nettle::nettle) if(BUILD_STATIC_DEPS) message(FATAL_ERROR "Internal error: nettle::nettle target (expected via libquic BUILD_STATIC_DEPS) not found") diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 5506fd41..e939190c 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -1,4 +1,5 @@ option(SUBMODULE_CHECK "Enables checking that vendored library submodules are up to date" ON) + if(SUBMODULE_CHECK) find_package(Git) if(GIT_FOUND) @@ -26,8 +27,7 @@ if(SUBMODULE_CHECK) message(STATUS "Checking submodules") check_submodule(ios-cmake) check_submodule(libsodium-internal) - check_submodule(oxen-libquic external/oxen-logging external/oxen-encoding) - check_submodule(nlohmann-json) + check_submodule(lokinet) check_submodule(zstd) check_submodule(protobuf) endif() @@ -37,28 +37,30 @@ if(NOT BUILD_STATIC_DEPS AND NOT FORCE_ALL_SUBMODULES) find_package(PkgConfig REQUIRED) endif() -macro(libsession_system_or_submodule BIGNAME smallname pkgconf subdir) - option(FORCE_${BIGNAME}_SUBMODULE "force using ${smallname} submodule" OFF) - if(NOT BUILD_STATIC_DEPS AND NOT FORCE_${BIGNAME}_SUBMODULE AND NOT FORCE_ALL_SUBMODULES) - pkg_check_modules(${BIGNAME} ${pkgconf} IMPORTED_TARGET GLOBAL) - endif() - if(${BIGNAME}_FOUND) - add_library(${smallname} INTERFACE IMPORTED GLOBAL) - if(NOT TARGET PkgConfig::${BIGNAME} AND CMAKE_VERSION VERSION_LESS "3.21") - # Work around cmake bug 22180 (PkgConfig::THING not set if no flags needed) +macro(libsession_system_or_submodule BIGNAME smallname target pkgconf subdir) + if(NOT TARGET ${target}) + option(FORCE_${BIGNAME}_SUBMODULE "force using ${smallname} submodule" OFF) + if(NOT BUILD_STATIC_DEPS AND NOT FORCE_${BIGNAME}_SUBMODULE AND NOT FORCE_ALL_SUBMODULES) + pkg_check_modules(${BIGNAME} ${pkgconf} IMPORTED_TARGET GLOBAL) + endif() + if(${BIGNAME}_FOUND) + add_library(${smallname} INTERFACE IMPORTED GLOBAL) + if(NOT TARGET PkgConfig::${BIGNAME} AND CMAKE_VERSION VERSION_LESS "3.21") + # Work around cmake bug 22180 (PkgConfig::THING not set if no flags needed) + else() + target_link_libraries(${smallname} INTERFACE PkgConfig::${BIGNAME}) + endif() + message(STATUS "Found system ${smallname} ${${BIGNAME}_VERSION}") else() - target_link_libraries(${smallname} INTERFACE PkgConfig::${BIGNAME}) + message(STATUS "using ${smallname} submodule ${subdir}") + add_subdirectory(${subdir}) endif() - message(STATUS "Found system ${smallname} ${${BIGNAME}_VERSION}") - else() - message(STATUS "using ${smallname} submodule") - add_subdirectory(${subdir}) - endif() - if(TARGET ${smallname} AND NOT TARGET ${smallname}::${smallname}) - add_library(${smallname}::${smallname} ALIAS ${smallname}) - endif() - if(BUILD_STATIC_DEPS AND STATIC_BUNDLE) + if(NOT TARGET ${target}) + add_library(${target} ALIAS ${smallname}) + endif() + if(BUILD_STATIC_DEPS AND STATIC_BUNDLE) libsession_static_bundle(${smallname}::${smallname}) + endif() endif() endmacro() @@ -100,31 +102,25 @@ if(CMAKE_CROSSCOMPILING) endif() endif() -set(LIBQUIC_BUILD_TESTS OFF CACHE BOOL "") -if(ENABLE_ONIONREQ) - libsession_system_or_submodule(OXENQUIC quic liboxenquic>=1.3.0 oxen-libquic) +if(ENABLE_NETWORKING) + set(LIBQUIC_BUILD_TESTS OFF CACHE BOOL "") + libsession_system_or_submodule(OXENQUIC quic oxen::quic liboxenquic>=1.6 lokinet/external/oxen-libquic) endif() -if(NOT TARGET oxenc::oxenc) - # The oxenc target will already exist if we load libquic above via submodule - set(OXENC_BUILD_TESTS OFF CACHE BOOL "") - set(OXENC_BUILD_DOCS OFF CACHE BOOL "") - libsession_system_or_submodule(OXENC oxenc liboxenc>=1.3.0 oxen-libquic/external/oxen-encoding) -endif() +libsession_system_or_submodule(OXENC oxenc oxenc::oxenc liboxenc>=1.5.0 lokinet/external/oxen-libquic/external/oxen-encoding) -if(NOT TARGET oxen::logging) - add_subdirectory(oxen-libquic/external/oxen-logging) +libsession_system_or_submodule(OXENLOGGING oxenlogging oxen::logging liboxen-logging>=1.2.0 lokinet/external/oxen-libquic/external/oxen-logging) +if(OXENLOGGING_FOUND) + # If we load oxen-logging via system lib then we won't necessarily have fmt/spdlog targets, + # but this script will give us them: + include(lokinet/external/oxen-libquic/external/oxen-logging/cmake/load_fmt_spdlog.cmake) + + add_library(oxen-logging-fmt-spdlog INTERFACE) + target_link_libraries(oxen-logging-fmt-spdlog INTERFACE oxenlogging::oxenlogging ${OXEN_LOGGING_FMT_TARGET} ${OXEN_LOGGING_SPDLOG_TARGET}) + add_library(oxen::logging ALIAS oxen-logging-fmt-spdlog) endif() -oxen_logging_add_source_dir("${PROJECT_SOURCE_DIR}") -# Apple xcode 15 has a completely broken std::source_location; we can't fix it, but at least we can -# hack up the source locations to hide the path that it uses (which is the useless path to -# oxen/log.hpp where the info/critical/etc. bodies are). -if(APPLE AND CMAKE_CXX_COMPILER_ID STREQUAL AppleClang AND NOT CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 16) - message(WARNING "${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION} is broken: filenames in logging statements will not display properly") - oxen_logging_add_source_dir("${CMAKE_CURRENT_SOURCE_DIR}/oxen-libquic/external/oxen-logging/include/oxen") -endif() if(CMAKE_C_COMPILER_LAUNCHER) set(deps_cc "${CMAKE_C_COMPILER_LAUNCHER} ${deps_cc}") @@ -144,12 +140,12 @@ if(APPLE) endforeach() endif() - -function(libsodium_internal_subdir) +function(add_static_subdirectory dir) set(BUILD_SHARED_LIBS OFF) - add_subdirectory(libsodium-internal) + add_subdirectory(${dir} ${ARGN}) endfunction() -libsodium_internal_subdir() + +add_static_subdirectory(libsodium-internal) libsession_static_bundle(libsodium::sodium-internal) @@ -163,7 +159,7 @@ set(protobuf_BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) set(protobuf_ABSL_PROVIDER "module" CACHE STRING "" FORCE) set(protobuf_BUILD_PROTOC_BINARIES OFF CACHE BOOL "") set(protobuf_BUILD_PROTOBUF_BINARIES ON CACHE BOOL "" FORCE) -libsession_system_or_submodule(PROTOBUF_LITE protobuf_lite protobuf-lite>=3.21 protobuf) +libsession_system_or_submodule(PROTOBUF_LITE protobuf_lite protobuf::libprotobuf-lite protobuf-lite>=3.21 protobuf) if(TARGET PkgConfig::PROTOBUF_LITE AND NOT TARGET protobuf::libprotobuf-lite) add_library(protobuf::libprotobuf-lite ALIAS PkgConfig::PROTOBUF_LITE) endif() @@ -192,4 +188,16 @@ libsession_static_bundle(libzstd_static) set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_Install ON CACHE INTERNAL "") # Required to export targets that we use -libsession_system_or_submodule(NLOHMANN nlohmann_json nlohmann_json>=3.7.0 nlohmann-json) +libsession_system_or_submodule(NLOHMANN nlohmann_json nlohmann_json::nlohmann_json nlohmann_json>=3.7.0 lokinet/external/nlohmann) + +if(ENABLE_NETWORKING) + set(LOKINET_FULL OFF CACHE BOOL "") + set(LOKINET_DAEMON OFF CACHE BOOL "") + set(LOKINET_NATIVE_BUILD OFF CACHE BOOL "") + set(LOKINET_JEMALLOC OFF CACHE BOOL "") + + add_library(sodium INTERFACE) + target_link_libraries(sodium INTERFACE libsodium::sodium-internal) + add_static_subdirectory(lokinet EXCLUDE_FROM_ALL) + libsession_static_bundle(lokinet::liblokinet) +endif() diff --git a/external/lokinet b/external/lokinet new file mode 160000 index 00000000..15b0608a --- /dev/null +++ b/external/lokinet @@ -0,0 +1 @@ +Subproject commit 15b0608a87f32ddecf75c92e73dd23b8c6991f51 diff --git a/external/nlohmann-json b/external/nlohmann-json deleted file mode 160000 index 9cca280a..00000000 --- a/external/nlohmann-json +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 diff --git a/external/oxen-libquic b/external/oxen-libquic deleted file mode 160000 index 793bf5be..00000000 --- a/external/oxen-libquic +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 793bf5be12dc26ae08a4585b7bff4e0b0d23e278 diff --git a/external/oxen-logging b/external/oxen-logging deleted file mode 160000 index 6ae91a24..00000000 --- a/external/oxen-logging +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6ae91a2417c4a9e55a3b312ba4b43019a13f003b diff --git a/include/session/file.hpp b/include/session/file.hpp index a2bf747d..a9e91170 100644 --- a/include/session/file.hpp +++ b/include/session/file.hpp @@ -3,6 +3,7 @@ #include #include #include +#include // Utility functions for working with files @@ -19,8 +20,8 @@ std::ofstream open_for_writing(const fs::path& filename); /// enabled for any failures. This also throws if the file cannot be opened. std::ifstream open_for_reading(const fs::path& filename); -/// Reads a (binary) file from disk into the string `contents`. -std::string read_whole_file(const fs::path& filename); +/// Reads a (binary) file from disk. +std::vector read_whole_file(const fs::path& filename); /// Dumps (binary) string contents to disk. The file is overwritten if it already exists. void write_whole_file(const fs::path& filename, std::string_view contents = ""); diff --git a/include/session/network/backends/session_file_server.h b/include/session/network/backends/session_file_server.h new file mode 100644 index 00000000..3256248d --- /dev/null +++ b/include/session/network/backends/session_file_server.h @@ -0,0 +1,78 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +#include "session/export.h" +#include "session/network/session_network_types.h" +#include "session/onionreq/builder.h" +#include "session/platform.h" + +/// API: file_server/session_file_server_upload +/// +/// Constructs a request to upload a file to the session file server. +/// +/// Inputs: +/// - `data` -- [in] data to upload to the file server. +/// - `data_len` -- [in] size of the `data`. +/// - `file_name` -- [in, optional] name of the file being uploaded. MUST be null terminated. +/// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take any +/// pre-flight operations into account so the request will never timeout if pre-flight operations +/// never complete. +/// - `overall_timeout` -- [in] timeout in milliseconds to use for the request and any pre-flight +/// operations that may need to occur (eg. path building). This value takes presedence over +/// `request_timeout` if provided, the request itself will be given a timeout of this value +/// subtracting however long the pre-flight operations took. +LIBSESSION_EXPORT session_request_params* session_file_server_upload( + const unsigned char* data, + size_t data_len, + const char* file_name, + int64_t request_timeout_ms, + int64_t overall_timeout_ms); + +/// API: network/session_file_server_download +/// +/// Constructs a request to download a file from the session file server. +/// +/// Inputs: +/// - `file_id` -- [in] the id of the file to download, NULL terminated. +/// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take any +/// pre-flight operations into account so the request will never timeout if pre-flight operations +/// never complete. +/// - `overall_timeout` -- [in] timeout in milliseconds to use for the request and any pre-flight +/// operations that may need to occur (eg. path building). This value takes presedence over +/// `request_timeout` if provided, the request itself will be given a timeout of this value +/// subtracting however long the pre-flight operations took. +LIBSESSION_EXPORT session_request_params* session_file_server_download( + const char* file_id, int64_t request_timeout_ms, int64_t overall_timeout_ms); + +/// API: network/session_file_server_get_client_version +/// +/// Constructs a request to retrieve the version information for the given platform. +/// +/// Inputs: +/// - `platform` -- [in] the platform to retrieve the client version for. +/// - `ed25519_secret` -- [in] the users ed25519 secret key (used for blinded auth - 64 bytes). +/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take +/// the path build into account so if the path build takes forever then this request will never +/// timeout. +/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and +/// path build (if required). This value takes presedence over `request_timeout_ms` if provided, +/// the request itself will be given a timeout of this value subtracting however long it took to +/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. +/// - `callback` -- [in] callback to be called with the result of the request. +/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set +/// to NULL if unused. +LIBSESSION_EXPORT session_request_params* session_file_server_get_client_version( + CLIENT_PLATFORM platform, + const unsigned char* ed25519_secret, /* 64 bytes */ + int64_t request_timeout_ms, + int64_t overall_timeout_ms); + +#ifdef __cplusplus +} +#endif diff --git a/include/session/network/backends/session_file_server.hpp b/include/session/network/backends/session_file_server.hpp new file mode 100644 index 00000000..4c1600d6 --- /dev/null +++ b/include/session/network/backends/session_file_server.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include "session/network/key_types.hpp" +#include "session/network/session_network_types.hpp" +#include "session/platform.hpp" + +namespace session::network::file_server { + +/// API: file_server/upload +/// +/// Constructs a request to upload a file to the session file server. +/// +/// Inputs: +/// - 'data' - [in] the data to be uploaded to a server. +/// - `file_name` -- [in, optional] optional name to use for the file. +/// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take any +/// pre-flight operations into account so the request will never timeout if pre-flight operations +/// never complete. +/// - `overall_timeout` -- [in] timeout in milliseconds to use for the request and any pre-flight +/// operations that may need to occur (eg. path building). This value takes presedence over +/// `request_timeout` if provided, the request itself will be given a timeout of this value +/// subtracting however long the pre-flight operations took. +Request upload( + std::vector data, + std::optional file_name, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout = std::nullopt); + +/// API: file_server/download +/// +/// Constructs a request to download a file from the session file server. +/// +/// Inputs: +/// - `file_id` -- [in] the id of the file to download. +/// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take any +/// pre-flight operations into account so the request will never timeout if pre-flight operations +/// never complete. +/// - `overall_timeout` -- [in] timeout in milliseconds to use for the request and any pre-flight +/// operations that may need to occur (eg. path building). This value takes presedence over +/// `request_timeout` if provided, the request itself will be given a timeout of this value +/// subtracting however long the pre-flight operations took. +Request download( + std::string file_id, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout = std::nullopt); + +/// API: file_server/get_client_version +/// +/// Constructs a request to retrieve the version information for the given platform. +/// +/// Inputs: +/// - `platform` -- [in] the platform to retrieve the client version for. +/// - `seckey` -- [in] the users ed25519 secret key (to generated blinded auth). +/// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take any +/// pre-flight operations into account so the request will never timeout if pre-flight operations +/// never complete. +/// - `overall_timeout` -- [in] timeout in milliseconds to use for the request and any pre-flight +/// operations that may need to occur (eg. path building). This value takes presedence over +/// `request_timeout` if provided, the request itself will be given a timeout of this value +/// subtracting however long the pre-flight operations took. +Request get_client_version( + Platform platform, + network::ed25519_seckey seckey, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout = std::nullopt); + +} // namespace session::network::file_server diff --git a/include/session/onionreq/key_types.hpp b/include/session/network/key_types.hpp similarity index 87% rename from include/session/onionreq/key_types.hpp rename to include/session/network/key_types.hpp index 5f22c71c..d3d2b1ef 100644 --- a/include/session/onionreq/key_types.hpp +++ b/include/session/network/key_types.hpp @@ -14,7 +14,7 @@ #include "../types.hpp" #include "../util.hpp" -namespace session::onionreq { +namespace session::network { using namespace std::literals; @@ -100,12 +100,12 @@ ed25519_pubkey parse_ed25519_pubkey(std::string_view pubkey_in); x25519_pubkey parse_x25519_pubkey(std::string_view pubkey_in); x25519_pubkey compute_x25519_pubkey(std::span ed25519_pk); -} // namespace session::onionreq +} // namespace session::network namespace std { template -struct hash> { - size_t operator()(const session::onionreq::pubkey_base& pk) const { +struct hash> { + size_t operator()(const session::network::pubkey_base& pk) const { // pubkeys are already random enough to use the first bytes directly as a good (and fast) // hash value static_assert(alignof(decltype(pk)) >= alignof(size_t)); @@ -114,13 +114,11 @@ struct hash> { }; template <> -struct hash : hash { -}; +struct hash : hash {}; template <> -struct hash : hash { -}; +struct hash : hash {}; template <> -struct hash - : hash {}; +struct hash : hash { +}; } // namespace std diff --git a/include/session/network/network_config.hpp b/include/session/network/network_config.hpp new file mode 100644 index 00000000..21cb2825 --- /dev/null +++ b/include/session/network/network_config.hpp @@ -0,0 +1,115 @@ +#pragma once + +#include +#include +#include +#include + +#include "session/network/network_opt.hpp" +#include "session/types.hpp" + +namespace session::network::config { + +using namespace std::chrono_literals; +namespace fs = std::filesystem; + +struct Config { + public: + opt::netid::Target netid = opt::netid::Target::mainnet; + opt::router::Type router = opt::router::Type::onion_requests; + opt::transport::Type transport = opt::transport::Type::quic; + uint8_t path_length = 3; + bool enforce_subnet_diversity = true; + uint8_t redirect_retry_count = 1; + opt::retry_delay retry_delay = opt::retry_delay(200ms, 5s); + std::chrono::milliseconds request_timeout_check_frequency = 250ms; + + // Netid Options + std::vector seed_nodes; + + // Snode Pool Options + std::optional cache_directory; + std::chrono::minutes cache_expiration = 2h; + std::chrono::milliseconds cache_min_lifetime = 2s; + size_t cache_min_size = 12; + uint8_t cache_num_nodes_to_use_for_refresh = 3; + uint8_t cache_node_failure_threshold = 3; + bool cache_refresh_using_legacy_endpoint = false; + + // Onion Request Router Options + uint8_t onionreq_path_failure_threshold = 3; + uint8_t onionreq_path_build_retry_limit = 10; + std::unordered_map onionreq_min_path_counts = { + {RequestCategory::standard, 2}, + {RequestCategory::download, 2}, + {RequestCategory::upload, 2}}; + bool onionreq_single_path_mode = false; + bool onionreq_disable_pre_build_paths = false; + + // Quic Transport Options + std::chrono::milliseconds quic_handshake_timeout{3s}; + std::chrono::seconds quic_keep_alive{10s}; + bool quic_disable_mtu_discovery = false; + + // Callback Transport Options + std::optional callbacks_callback; + + template + requires( + sizeof...(Opt) > 0 && + std::conjunction_v>...>) + Config(Opt&&... opts) { + // parse all options + ((void)handle_config_opt(std::forward(opts)), ...); + _init(); + } + explicit Config(const std::vector& opts); + + Config() = default; + Config(const Config&) = default; + Config(Config&&) = default; + Config& operator=(const Config&) = default; + Config& operator=(Config&&) = default; + ~Config() = default; + + private: + void _init(); + + void handle_config_opt(opt::netid netid); + void handle_config_opt(opt::router router); + void handle_config_opt(opt::transport transport); + void handle_config_opt(opt::path_length pl); + void handle_config_opt(opt::disable_subnet_diversity dsd); + void handle_config_opt(opt::redirect_retry_count rrc); + void handle_config_opt(opt::retry_delay rd); + void handle_config_opt(opt::request_timeout_check_frequency rtcf); + + // Snode pool options + void handle_config_opt(opt::cache_directory dir); + void handle_config_opt(opt::cache_expiration ce); + void handle_config_opt(opt::cache_min_lifetime mcl); + void handle_config_opt(opt::cache_min_size mcs); + void handle_config_opt(opt::cache_num_nodes_to_use_for_refresh nnr); + void handle_config_opt(opt::cache_node_failure_threshold nft); + void handle_config_opt(opt::cache_refresh_using_legacy_endpoint rule); + + // Quic transport options + void handle_config_opt(opt::quic_handshake_timeout qht); + void handle_config_opt(opt::quic_keep_alive qka); + void handle_config_opt(opt::quic_disable_mtu_discovery qdmd); + + // Onion request router options + void handle_config_opt(opt::onionreq_path_failure_threshold pft); + void handle_config_opt(opt::onionreq_path_build_retry_limit pbrl); + void handle_config_opt(opt::onionreq_min_path_count mpc); + void handle_config_opt(opt::onionreq_single_path_mode spm); + void handle_config_opt(opt::onionreq_disable_pre_build_paths dpbp); + + template + void handle_config_opt(std::optional option) { + if (option) + handle_config_opt(std::move(*option)); + } +}; + +} // namespace session::network::config diff --git a/include/session/network/network_opt.hpp b/include/session/network/network_opt.hpp new file mode 100644 index 00000000..39ec6fb7 --- /dev/null +++ b/include/session/network/network_opt.hpp @@ -0,0 +1,353 @@ +#pragma once + +#include + +#include "session/network/service_node.hpp" +#include "session/network/session_network_types.hpp" +#include "session/types.hpp" + +namespace session::network { +class Endpoint; +class Stream; + +namespace opt { + namespace fs = std::filesystem; + using namespace std::chrono_literals; + + namespace { + inline std::vector from_hex(std::string_view s) { + std::vector out; + out.reserve(s.size() / 2); + oxenc::from_hex(s.begin(), s.end(), std::back_inserter(out)); + + return out; + } + } // namespace + + struct base {}; + + /// Can be used to override the default (mainnet) netid that the network will populate it's + /// internal caches from, 'devnet' allows for specifying a custom server. + struct netid : base { + enum class Target { + mainnet, + testnet, + devnet, + }; + + Target target; + std::vector seed_nodes; + + private: + explicit netid(Target t, std::vector seed_nodes = {}) : + target{t}, seed_nodes{std::move(seed_nodes)} {} + + public: + netid() = delete; + + static netid mainnet() { + auto seed_nodes = { + service_node{ + from_hex("1f000f09a7b07828dcb72af7cd16857050c10c02bd58afb0e38111fb6cda1" + "fef"), + oxen::quic::ipv4{"95.216.33.113"}, + uint16_t{22100}, + uint16_t{20200}, + {2, 11, 0}, + swarm::INVALID_SWARM_ID}, + service_node{ + from_hex("1f101f0acee4db6f31aaa8b4df134e85ca8a4878efaef7f971e88ab144c1a" + "7ce"), + oxen::quic::ipv4{"37.27.236.229"}, + uint16_t{22101}, + uint16_t{20201}, + {2, 11, 0}, + swarm::INVALID_SWARM_ID}, + service_node{ + from_hex("1f202f00f4d2d4acc01e20773999a291cf3e3136c325474d159814e061999" + "19f"), + oxen::quic::ipv4{"172.96.140.124"}, + uint16_t{22102}, + uint16_t{20202}, + {2, 11, 0}, + swarm::INVALID_SWARM_ID}, + service_node{ + from_hex("1f303f1d7523c46fa5398826740d13282d26b5de90fbae5749442f66afb6d" + "78b"), + oxen::quic::ipv4{"208.73.207.54"}, + uint16_t{22103}, + uint16_t{20203}, + {2, 11, 0}, + swarm::INVALID_SWARM_ID}, + service_node{ + from_hex("1f604f1c858a121a681d8f9b470ef72e6946ee1b9c5ad15a35e16b50c28db" + "7b0"), + oxen::quic::ipv4{"104.194.8.115"}, + uint16_t{22104}, + uint16_t{20204}, + {2, 11, 0}, + swarm::INVALID_SWARM_ID}, + }; + + return netid(Target::mainnet, seed_nodes); + } + + static netid testnet() { + auto seed_nodes = { + // service_node{ + // from_hex("decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9"), + // oxen::quic::ipv4{"144.76.164.202"}, + // uint16_t{35500}, + // uint16_t{35400}, + // {2, 10, 0}, + // swarm::INVALID_SWARM_ID}, // This is the original one + + service_node{ + from_hex("decaf20025ca6389d8225bda6a32d7fc4ee5176d21e3b2e9e08c3505a48a8" + "11a"), + oxen::quic::ipv4{"23.88.6.250"}, + uint16_t{35520}, + uint16_t{35420}, + {2, 10, 0}, + swarm::INVALID_SWARM_ID}, // lokinet one + }; + + return netid(Target::testnet, seed_nodes); + } + + static netid devnet(std::vector seed_nodes) { + if (seed_nodes.empty()) + throw std::invalid_argument( + "devnet must be configured with at least one seed node."); + + return netid(Target::devnet, std::move(seed_nodes)); + } + + static std::string to_string(Target target) { + switch (target) { + case Target::mainnet: return "mainnet"; + case Target::testnet: return "testnet"; + case Target::devnet: return "devnet"; + } + + return "mainnet"; // Shouldn't happen + } + }; + + /// Can be used to override the default (onion_requests) routing method for requests. + struct router : base { + enum class Type { + onion_requests, + lokinet, + direct, + }; + + Type type; + + private: + explicit router(Type t) : type{t} {} + + public: + router() = delete; + + static router onion_requests() { return router(Type::onion_requests); } + static router lokinet() { return router(Type::lokinet); } + static router direct() { return router(Type::direct); } + }; + + /// Can be used to override the default (quic_onionreq) transport layer used to send requests. + struct transport : base { + enum class Type { + quic, + callbacks, + }; + // TODO: Add in "HTTP" as an option + + using network_callback_t = std::function; + + Type type; + std::optional callback; + + private: + explicit transport(Type t, std::optional callback = std::nullopt) : + type{t}, callback{std::move(callback)} {} + + public: + transport() = delete; + + static transport quic() { return transport(Type::quic); } + static transport callbacks(network_callback_t callback) { + return transport(Type::callbacks, std::move(callback)); + } + }; + + /// Can be used to override the default (3) path length used when building onion request or + /// lokinet paths. + struct path_length : base { + uint8_t length; + + explicit path_length(uint8_t length) : length{length} {} + }; + + /// Can be used to prevent the code from excluding nodes within the same `/24` subnet from being + /// included in the same path when building onion request or lokinet paths. + struct disable_subnet_diversity : base {}; + + /// Can be used to override the default (1) number of request retries that will occur when + /// receiving a 421 error. + struct redirect_retry_count : base { + uint8_t count; + + explicit redirect_retry_count(uint8_t count) : count{count} {} + }; + + struct retry_delay : base { + std::chrono::milliseconds base_delay; + std::chrono::milliseconds max_delay; + + explicit retry_delay( + std::chrono::milliseconds base_delay, std::chrono::milliseconds max_delay) : + base_delay{base_delay}, max_delay{max_delay} {} + + /// API: retry_delay/exponential + /// + /// A function which generates an exponential delay to wait before retrying a request/action + /// based on the provided failure count. + /// + /// Inputs: + /// - 'failure_count' - [in] the number of times the request has already failed. + inline std::chrono::milliseconds exponential(int failure_count) { + if (failure_count <= 0) + return base_delay; + + double delay_ms = base_delay.count() * std::pow(2.0, failure_count - 1); + auto final_delay = std::chrono::milliseconds(static_cast(delay_ms)); + + return std::min(final_delay, max_delay); + } + }; + + /// Can be used to override the default (250ms) fequency that is used to check if queued + /// requests have timed out due to transport/router setup. + struct request_timeout_check_frequency : base { + std::chrono::milliseconds frequency; + explicit request_timeout_check_frequency(std::chrono::milliseconds f) : frequency{f} {} + }; + + // MARK: Snode Pool Options + + /// Can be used to override the default ('.') path the network uses to cache files (eg. snode + /// pool and lokinet bootstrap). + struct cache_directory : base { + fs::path path; + explicit cache_directory(fs::path p) : path{p} {} + }; + + /// Can be used to override the default (2h) duration that the snode cache can be used for + /// before it needs to be refreshed. + struct cache_expiration : base { + std::chrono::minutes duration; + explicit cache_expiration(std::chrono::minutes duration) : duration{duration} {} + }; + + /// Can be used to override the default (2s) minimum duration that the snode cache should live + /// for, if a refresh is triggered within this period it will be delayed until the minimum + /// duration has passed to prevent excessive looping. + struct cache_min_lifetime : base { + std::chrono::milliseconds duration; + explicit cache_min_lifetime(std::chrono::milliseconds duration) : duration{duration} {} + }; + + /// Can be used to override the default (12) minimum number of unused nodes before we trigger a + /// snode cache refresh. + /// + /// Note: If the cache size is somehow smaller than this value (eg. Testnet is having issues) + /// then the minimum size will be the full cache size (minus enough to build a path) or at least + /// the size of a single path. + struct cache_min_size : base { + size_t size; + explicit cache_min_size(size_t size) : size{size} {} + }; + + /// Can be used to override the default (3) number of cached nodes used to refresh the cache for + /// any subsequent refreshes after populating from a seed node. + /// + /// Note: Providing a value of `0` will result in the cache _always_ being refreshed using a + /// seed node. + struct cache_num_nodes_to_use_for_refresh : base { + uint8_t count; + explicit cache_num_nodes_to_use_for_refresh(uint8_t count) : count{count} {} + }; + + /// Can be used to override the default (3) number of times a specific node in a path can + /// receive an error before it is removed from the path and replaced by a new node (or the path + /// is rebuilt if it happens to be the guard node). + struct cache_node_failure_threshold : base { + uint16_t count; + explicit cache_node_failure_threshold(uint16_t count) : count{count} {} + }; + + /// Can be used to make the snode cache use the legacy endpoint when refreshing. + struct cache_refresh_using_legacy_endpoint : base { + explicit cache_refresh_using_legacy_endpoint() {} + }; + + // MARK: Quic Transport Options + + /// Can be used to override the default (10s) handshake timeout duration for Quic connections. + struct quic_handshake_timeout : base { + std::chrono::milliseconds duration; + explicit quic_handshake_timeout(std::chrono::milliseconds duration) : duration{duration} {} + }; + + /// Can be used to override the default (0ms) keep alive duration for Quic connections. + struct quic_keep_alive : base { + std::chrono::seconds duration; + explicit quic_keep_alive(std::chrono::seconds duration) : duration{duration} {} + }; + + /// Can be used to disable Quic MTU discovery. + struct quic_disable_mtu_discovery : base {}; + + // MARK: Onion Request Router Options + + /// Can be used to override the default (3) number of times a path can receive an error before + /// it is dropped and replaced by a new path. + struct onionreq_path_failure_threshold : base { + uint16_t count; + + explicit onionreq_path_failure_threshold(uint16_t count) : count{count} {} + }; + + /// Can be used to override the default (3) number of times a path can receive an error before + /// it is dropped and replaced by a new path. + struct onionreq_path_build_retry_limit : base { + uint16_t count; + + explicit onionreq_path_build_retry_limit(uint16_t count) : count{count} {} + }; + + /// Can be used to override the default (2) minimum number of paths that are maintained for each + /// request category when using onion requests. If `onionreq_single_path_mode` is provided this + /// will be ignored. + struct onionreq_min_path_count : base { + RequestCategory category; + uint8_t min_count; + + explicit onionreq_min_path_count(RequestCategory category, uint8_t min_count) : + category{category}, min_count{min_count} {} + }; + + /// Can be used to force the onion request router to only use a single path regardless of what + /// category the requests sent have. When this option is provided `onionreq_min_path_count` will + /// be ignored. + struct onionreq_single_path_mode : base {}; + + /// Can be used to prevent the network instance from building onion request paths when + /// initialised, when this option is provided paths will be built when the first request it + /// made. + struct onionreq_disable_pre_build_paths : base {}; + +} // namespace opt +} // namespace session::network diff --git a/include/session/network/request_queue.hpp b/include/session/network/request_queue.hpp new file mode 100644 index 00000000..5df577b3 --- /dev/null +++ b/include/session/network/request_queue.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include +#include + +#include "session/network/session_network_types.hpp" +#include "session/network/transport/network_transport.hpp" + +namespace session::network::detail { + +class RequestQueue : public std::enable_shared_from_this { + private: + friend class TestRequestQueue; + + std::shared_ptr _loop; + std::chrono::milliseconds _check_frequency; + + std::deque> _queue; + bool _checker_active = false; + + public: + RequestQueue( + std::shared_ptr loop, std::chrono::milliseconds check_frequency) : + _loop{loop}, _check_frequency{check_frequency} {}; + ~RequestQueue(); + + bool is_empty() const { return _queue.empty(); }; + + virtual void add(Request request, network_response_callback_t callback); + virtual void add_front(std::pair req_pair); + + virtual std::deque> pop_all(); + + private: + virtual void check_timeouts(); +}; + +} // namespace session::network::detail diff --git a/include/session/network/routing/direct_router.hpp b/include/session/network/routing/direct_router.hpp new file mode 100644 index 00000000..0d034732 --- /dev/null +++ b/include/session/network/routing/direct_router.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "session/network/request_queue.hpp" +#include "session/network/routing/network_router.hpp" +#include "session/network/snode_pool.hpp" + +namespace session::network { + +class DirectRouter : public IRouter, public std::enable_shared_from_this { + private: + bool _suspended = false; + std::shared_ptr _loop; + std::weak_ptr _transport; + + public: + DirectRouter(std::shared_ptr loop, std::weak_ptr transport); + ~DirectRouter() override; + + void suspend() override; + void resume(bool automatically_reconnect = true) override; + void close_connections() override {}; + void clear_cache() override {}; + + ConnectionStatus get_status() const override { return _status.load(); }; + void send_request(Request request, network_response_callback_t callback) override; + + private: + std::atomic _status{ConnectionStatus::unknown}; + void _update_status(ConnectionStatus new_status); + void _send_request_internal(Request request, network_response_callback_t callback); + void _handle_transport_response( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response_body, + network_response_callback_t callback); +}; + +} // namespace session::network diff --git a/include/session/network/routing/lokinet_router.hpp b/include/session/network/routing/lokinet_router.hpp new file mode 100644 index 00000000..913f648e --- /dev/null +++ b/include/session/network/routing/lokinet_router.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "session/network/request_queue.hpp" +#include "session/network/routing/network_router.hpp" +#include "session/network/snode_pool.hpp" + +namespace lokinet { +class Lokinet; +struct tunnel_info; +}; // namespace lokinet + +namespace session::network { + +namespace config { + struct LokinetRouterConfig { + opt::netid::Target netid; + fs::path cache_directory; + std::chrono::milliseconds request_timeout_check_frequency; + + uint8_t path_length; + }; +} // namespace config + +class LokinetRouter : public IRouter, public std::enable_shared_from_this { + private: + bool _ready = false; + bool _suspended = false; + config::LokinetRouterConfig _config; + std::shared_ptr _loop; + std::shared_ptr lokinet; + std::weak_ptr _snode_pool; + std::weak_ptr _transport; + + std::unordered_map _active_tunnels; + std::unordered_map>> + _pending_requests; + + public: + LokinetRouter( + config::LokinetRouterConfig config, + std::shared_ptr loop, + std::weak_ptr snode_pool, + std::weak_ptr transport); + ~LokinetRouter() override; + + void suspend() override; + void resume(bool automatically_reconnect = true) override; + void close_connections() override; + void clear_cache() override; + + ConnectionStatus get_status() const override { return _status.load(); }; + std::vector get_active_paths() override; + void send_request(Request request, network_response_callback_t callback) override; + + private: + std::atomic _status{ConnectionStatus::unknown}; + + // All of the below functions should only be called from within `_loop` + void _finish_setup(); + void _close_connections(); + void _update_status(ConnectionStatus new_status); + void _send_request_internal(Request request, network_response_callback_t callback); + void _send_direct_request(Request request, network_response_callback_t callback); + void _send_proxy_request(Request request, network_response_callback_t callback); + void _establish_tunnel( + const oxen::quic::RemoteAddress& address, const std::string& initiating_req_id); + void _send_via_tunnel( + lokinet::tunnel_info tunnel, Request request, network_response_callback_t callback); +}; + +} // namespace session::network diff --git a/include/session/network/routing/network_router.hpp b/include/session/network/routing/network_router.hpp new file mode 100644 index 00000000..e2120dbf --- /dev/null +++ b/include/session/network/routing/network_router.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "session/network/transport/network_transport.hpp" + +namespace session::network { + +class IRouter { + public: + std::function on_status_changed; + + virtual ~IRouter() = default; + + virtual void suspend() = 0; + virtual void resume(bool automatically_reconnect = true) = 0; + virtual void close_connections() = 0; + virtual void clear_cache() = 0; + + virtual ConnectionStatus get_status() const = 0; + virtual std::vector get_active_paths() { return {}; }; + virtual std::vector get_all_used_nodes() { return {}; }; + virtual void send_request(Request request, network_response_callback_t callback) = 0; +}; + +} // namespace session::network \ No newline at end of file diff --git a/include/session/network/routing/onion_request_router.hpp b/include/session/network/routing/onion_request_router.hpp new file mode 100644 index 00000000..5096576d --- /dev/null +++ b/include/session/network/routing/onion_request_router.hpp @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "session/network/request_queue.hpp" +#include "session/network/routing/network_router.hpp" +#include "session/network/snode_pool.hpp" + +namespace session::network { + +namespace config { + struct OnionRequestRouterConfig { + network::opt::retry_delay retry_delay; + std::chrono::milliseconds request_timeout_check_frequency; + + uint8_t path_length; + uint8_t path_failure_threshold; + uint8_t path_build_retry_limit; + bool disable_pre_build_paths; + bool single_path_mode; + std::unordered_map min_path_counts; + }; +} // namespace config + +struct OnionPath { + std::string id; + std::vector nodes; + + size_t pending_requests = 0; + uint16_t failure_count = 0; + + std::string to_string() const; +}; + +class OnionRequestRouter : public IRouter, public std::enable_shared_from_this { + private: + friend class TestOnionRequestRouter; + + bool _ready = false; + bool _suspended = false; + config::OnionRequestRouterConfig _config; + std::shared_ptr _loop; + std::weak_ptr _snode_pool; + std::weak_ptr _transport; + + std::unordered_map> _paths; + std::unordered_map> _paths_pending_drop; + std::unordered_map> _request_queues; + + std::unordered_map _in_progress_path_builds; + std::unordered_map _path_build_retries; + std::unordered_map> _pending_paths; + + public: + OnionRequestRouter( + config::OnionRequestRouterConfig config, + std::shared_ptr loop, + std::weak_ptr snode_pool, + std::weak_ptr transport); + ~OnionRequestRouter() override; + + void suspend() override; + void resume(bool automatically_reconnect = true) override; + void close_connections() override; + void clear_cache() override {} + + ConnectionStatus get_status() const override { return _status.load(); }; + std::vector get_active_paths() override; + std::vector get_all_used_nodes() override; + void send_request(Request request, network_response_callback_t callback) override; + + private: + std::atomic _status{ConnectionStatus::unknown}; + + // All of the below functions should only be called from within `_loop` + void _finish_setup(); + void _pre_build_paths_if_needed(); + void _close_connections(); + void _update_status(); + void _send_request_internal(Request request, network_response_callback_t callback); + + void _build_path( + RequestCategory category, + std::optional initiating_req_id, + const std::vector& nodes_to_exclude, + std::optional original_path_id = std::nullopt); + void _on_guard_connectivity_response( + const std::string& path_id, + RequestCategory category, + std::optional initiating_req_id, + bool success); + + OnionPath* _find_valid_path(const Request& request); + + void _send_on_path(OnionPath& path, Request request, network_response_callback_t callback); + void _handle_transport_response( + std::string path_id, + Request original_request, + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional decrypted_body, + network_response_callback_t callback); + + void _decrement_and_cleanup_path(const std::string& path_id, RequestCategory category); + void _handle_path_failure( + const std::string& path_id, + const RequestCategory& request_category, + const std::optional& error_body); +}; + +} // namespace session::network diff --git a/include/session/network/service_node.h b/include/session/network/service_node.h new file mode 100644 index 00000000..fa541c09 --- /dev/null +++ b/include/session/network/service_node.h @@ -0,0 +1,20 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +typedef struct network_service_node { + char ed25519_pubkey_hex[65]; // The 64-byte ed25519 pubkey in hex + null terminator. + uint8_t ip[4]; + uint16_t https_port; + uint16_t omq_port; + uint16_t version[3]; + uint64_t swarm_id; +} network_service_node; + +#ifdef __cplusplus +} +#endif diff --git a/include/session/network/service_node.hpp b/include/session/network/service_node.hpp new file mode 100644 index 00000000..b07ea052 --- /dev/null +++ b/include/session/network/service_node.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include +#include + +#include +#include + +#include "session/network/service_node.h" +#include "session/network/swarm.hpp" + +namespace session::network { + +using namespace session::network::swarm; + +namespace service_node_disk_format { + constexpr size_t PUBKEY_HEX = 64; // 32 bytes * 2 hex chars + constexpr size_t IP_MAX = 15; // 255.255.255.255 + constexpr size_t PORT_MAX = 5; // 65535 + constexpr size_t VERSION_MAX = 17; // 65535.65535.65535 + constexpr size_t SWARM_ID_MAX = 20; // uint64_t max value + constexpr size_t FIELD_COUNT = 6; + constexpr size_t SEPARATORS = FIELD_COUNT - 1; // 5 pipes + constexpr size_t LINE_ENDING = 2; // \n\r (just in case) + + constexpr size_t MAX_LINE_SIZE = PUBKEY_HEX + IP_MAX + (PORT_MAX * 2) + VERSION_MAX + + SWARM_ID_MAX + SEPARATORS + LINE_ENDING; +} // namespace service_node_disk_format + +struct fork_versions { + uint16_t hardfork; + uint16_t softfork; + + bool operator==(const fork_versions& other) const { + return hardfork == other.hardfork && softfork == other.softfork; + } +}; + +struct service_node { + std::vector _remote_pubkey; + oxen::quic::ipv4 ip; + uint16_t https_port; + uint16_t omq_port; + std::array storage_server_version; + swarm_id_t swarm_id; + + oxen::quic::RemoteAddress to_https_address() const { + return oxen::quic::RemoteAddress{_remote_pubkey, ip, https_port}; + } + + oxen::quic::RemoteAddress to_omq_address() const { + return oxen::quic::RemoteAddress{_remote_pubkey, ip, omq_port}; + } + + std::span view_remote_key() const { return _remote_pubkey; } + std::string host() const { return ip.to_string(); } + session::network::x25519_pubkey swarm_pubkey() const; + + std::string to_string() const; + std::string to_https_string() const; + std::string to_omq_string() const; + + static service_node from(const network_service_node& node); + void into(network_service_node& n) const; + + template + void to_disk(OutputIt out) const { + fmt::format_to( + out, + "{}|{}|{}|{}|{}.{}.{}|{}\n", + oxenc::to_hex(view_remote_key()), + host(), + https_port, + omq_port, + storage_server_version[0], + storage_server_version[1], + storage_server_version[2], + swarm_id); + } + + static service_node from_disk(std::string_view str); + static std::pair, int> process_snode_cache_bin( + std::vector cache_bin); + + static service_node legacy_from_json(nlohmann::json json); + static service_node legacy_from_disk(std::string_view str); + std::string legacy_to_disk() const; + + bool operator==(const service_node& other) const = default; + auto operator<=>(const service_node& other) const = default; +}; + +inline std::ostream& operator<<(std::ostream& os, const service_node& sn) { + return os << sn.to_string(); +} + +} // namespace session::network diff --git a/include/session/network/session_network.h b/include/session/network/session_network.h new file mode 100644 index 00000000..520745e2 --- /dev/null +++ b/include/session/network/session_network.h @@ -0,0 +1,198 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +#include "session/export.h" +#include "session/log_level.h" +#include "session/network/session_network_types.h" +#include "session/onionreq/builder.h" +#include "session/platform.h" + +typedef struct network_object { + // Internal opaque object pointer; calling code should leave this alone. + void* internals; +} network_object; +typedef struct session_response_handle_cpp_t session_response_handle_t; + +typedef enum { + SESSION_NETWORK_MAINNET = 0, + SESSION_NETWORK_TESTNET = 1, + SESSION_NETWORK_DEVNET = 2 +} SESSION_NETWORK_NETID; + +typedef enum { + SESSION_NETWORK_ROUTER_ONION_REQUESTS = 0, + SESSION_NETWORK_ROUTER_LOKINET = 1, + SESSION_NETWORK_ROUTER_DIRECT = 2, +} SESSION_NETWORK_ROUTER; + +typedef enum { + SESSION_NETWORK_TRANSPORT_QUIC = 0, + SESSION_NETWORK_TRANSPORT_CALLBACKS = 1, +} SESSION_NETWORK_TRANSPORT; + +typedef void (*session_network_request_t)( + const char* url, + const char* body_data, + size_t body_size, + session_response_handle_t* response_handle, + void* ctx); + +typedef struct { + // Basic options + SESSION_NETWORK_NETID netid; + SESSION_NETWORK_ROUTER router; + SESSION_NETWORK_TRANSPORT transport; + uint8_t path_length; + bool enforce_subnet_diversity; + uint8_t redirect_retry_count; + uint64_t min_retry_delay_ms; + uint64_t max_retry_delay_ms; + uint64_t request_timeout_check_frequency_ms; + + // Devnet options (only used when netid_target == SESSION_NETWORK_DEVNET) + const network_service_node* devnet_seed_nodes; + size_t devnet_seed_nodes_size; + + // Snode pool options + const char* cache_dir; + uint32_t cache_expiration_minutes; + uint64_t cache_min_lifetime_ms; + size_t cache_min_size; + uint8_t cache_num_nodes_to_use_for_refresh; + uint8_t cache_node_failure_threshold; + bool cache_refresh_using_legacy_endpoint; + + // Onion request router options (only used when router == + // SESSION_NETWORK_ROUTER_ONION_REQUESTS) + uint8_t onionreq_path_failure_threshold; + uint8_t onionreq_path_build_retry_limit; + uint8_t onionreq_min_path_count_standard; + uint8_t onionreq_min_path_count_upload; + uint8_t onionreq_min_path_count_download; + bool onionreq_single_path_mode; + bool onionreq_disable_pre_build_paths; + + // Quic transport options (for transport == SESSION_NETWORK_TRANSPORT_QUIC) + uint32_t quic_handshake_timeout_seconds; + uint32_t quic_keep_alive_seconds; + bool quic_disable_mtu_discovery; + + // Callback options (for transport == SESSION_NETWORK_TRANSPORT_CALLBACKS) + session_network_request_t transport_callback; + + // A user-defined context pointer passed back to every invocation of `transport_callback` + void* transport_callback_ctx; + +} session_network_config; + +typedef void (*session_network_response_t)( + bool success, + bool timeout, + int16_t status_code, + const char* const* headers_kv_pairs, + size_t headers_kv_pairs_len, + const unsigned char* response, + size_t response_size, + void* ctx); + +/// API: network/session_network_default_config +/// +/// Populates an instance with the default configuration options. +/// +/// Inputs: +/// - `config` -- [in] Pointer to session_network_config object +LIBSESSION_EXPORT session_network_config session_network_config_default(); + +LIBSESSION_EXPORT bool session_network_init( + network_object** network, + const session_network_config* config, + char* error) LIBSESSION_WARN_UNUSED; + +/// API: network/session_network_free +/// +/// Frees a network object. +/// +/// Inputs: +/// - `network` -- [in] Pointer to network_object object +LIBSESSION_EXPORT void session_network_free(network_object* network); + +/// API: network/session_request_params_free +/// +/// Frees a request params object. +/// +/// Inputs: +/// - `params` -- [in] Pointer to session_request_params object +LIBSESSION_EXPORT void session_request_params_free(session_request_params* params); + +LIBSESSION_EXPORT void session_network_suspend(network_object* network); +LIBSESSION_EXPORT void session_network_resume( + network_object* network, bool automatically_reconnect); +LIBSESSION_EXPORT void session_network_close_connections(network_object* network); +LIBSESSION_EXPORT void session_network_clear_cache(network_object* network); + +LIBSESSION_EXPORT int64_t session_network_time_offset(network_object* network); +LIBSESSION_EXPORT uint16_t session_network_hardfork(network_object* network); +LIBSESSION_EXPORT uint16_t session_network_softfork(network_object* network); + +/// API: network/network_set_status_changed_callback +/// +/// Registers a callback to be called whenever the network connection status changes. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object +/// - `callback` -- [in] callback to be called when the network connection status changes. +/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. +LIBSESSION_EXPORT void session_network_set_status_changed_callback( + network_object* network, void (*callback)(CONNECTION_STATUS status, void* ctx), void* ctx); + +LIBSESSION_EXPORT void session_network_set_network_info_changed_callback( + network_object* netowrk, + void (*callback)( + int64_t network_time_offset, uint16_t hardfork, uint16_t softfork, void* ctx), + void* ctx); + +LIBSESSION_EXPORT void session_network_callbacks_respond( + network_object* network, + session_response_handle_t* response_handle, + bool success, + bool timeout, + int16_t status_code, + const char* const* headers, + const char* const* header_values, + size_t headers_size, + const char* body, + size_t body_len); + +LIBSESSION_EXPORT CONNECTION_STATUS session_network_get_status(network_object* network); + +LIBSESSION_EXPORT void session_network_get_active_paths( + network_object* network, session_path_info** out_paths, size_t* out_paths_len); + +LIBSESSION_EXPORT void session_network_paths_free(session_path_info* paths); + +LIBSESSION_EXPORT void session_network_get_swarm( + network_object* network, + const char* swarm_pubkey_hex, + void (*callback)(network_service_node* nodes, size_t nodes_len, void*), + void* ctx); + +LIBSESSION_EXPORT void session_network_get_random_nodes( + network_object* network, + uint16_t count, + void (*callback)(network_service_node*, size_t, void*), + void* ctx); + +LIBSESSION_EXPORT void session_network_send_request( + network_object* network, + const session_request_params* params, + session_network_response_t callback, + void* ctx); + +#ifdef __cplusplus +} +#endif diff --git a/include/session/network/session_network.hpp b/include/session/network/session_network.hpp new file mode 100644 index 00000000..58a20fe1 --- /dev/null +++ b/include/session/network/session_network.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include +#include + +#include "session/network/network_config.hpp" +#include "session/network/routing/network_router.hpp" +#include "session/network/snode_pool.hpp" +#include "session/network/transport/network_transport.hpp" +#include "session/platform.hpp" +#include "session/types.hpp" + +namespace session::network { + +namespace fs = std::filesystem; + +class Network { + private: + const config::Config config; + std::shared_ptr _loop; + std::shared_ptr _snode_pool; + std::shared_ptr _transport; + std::shared_ptr _router; + bool _suspended = false; + + public: + // Hook to be notified whenever the network connection status changes. + std::function on_status_changed; + std::function + on_network_info_changed; + + template + requires(!std::is_same_v< + std::decay_t>>, + config::Config>) + Network(Opt&&... opts) : Network(Config(std::forward(opts)...)){}; + explicit Network(config::Config config); + virtual ~Network(); + + std::chrono::milliseconds network_time_offset() const { return _network_time_offset; }; + fork_versions fork() const { return _fork_versions.load(); }; + uint16_t hardfork() const { return _fork_versions.load().hardfork; }; + uint16_t softfork() const { return _fork_versions.load().softfork; }; + + void suspend(); + void resume(bool automatically_reconnect = true); + void close_connections(); + void clear_cache(); + + ConnectionStatus get_status(); + std::vector get_active_paths(); + + /// API: network/get_swarm + /// + /// Retrieves the swarm for the given pubkey. If there is already an entry in the cache for the + /// swarm then that will be returned, otherwise a network request will be made to retrieve the + /// swarm and save it to the cache. + /// + /// Inputs: + /// - 'swarm_pubkey' - [in] public key for the swarm. + /// - 'callback' - [in] callback to be called with the retrieved swarm (in the case of an error + /// the callback will be called with an empty list). + void get_swarm( + session::network::x25519_pubkey swarm_pubkey, + std::function swarm)> callback); + + /// API: network/get_random_nodes + /// + /// Retrieves a number of random nodes from the snode pool. If the are no nodes in the pool a + /// new pool will be populated and the nodes will be retrieved from that. + /// + /// Inputs: + /// - 'count' - [in] the number of nodes to retrieve. + /// - 'callback' - [in] callback to be called with the retrieved nodes (in the case of an error + /// the callback will be called with an empty list). + void get_random_nodes( + uint16_t count, std::function nodes)> callback); + + void send_request(Request request, network_response_callback_t callback); + + private: + std::atomic _status{ConnectionStatus::unknown}; + std::atomic _network_time_offset{0ms}; + std::atomic _fork_versions{{0, 0}}; + + void configure(); + + void _close_connections(); + void _recalculate_status(); + void _update_status(ConnectionStatus new_status); + void _update_network_state(const std::string& body); + void _handle_421_retry(Request original_request, network_response_callback_t final_callback); + Request _preprocess_request(Request request); +}; + +} // namespace session::network diff --git a/include/session/network/session_network_types.h b/include/session/network/session_network_types.h new file mode 100644 index 00000000..6858d727 --- /dev/null +++ b/include/session/network/session_network_types.h @@ -0,0 +1,83 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +#include "../export.h" +#include "session/network/service_node.h" + +typedef enum CONNECTION_STATUS { + CONNECTION_STATUS_UNKNOWN, + CONNECTION_STATUS_CONNECTING, + CONNECTION_STATUS_CONNECTED, + CONNECTION_STATUS_DISCONNECTED, +} CONNECTION_STATUS; + +typedef enum { + SESSION_NETWORK_REQUEST_CATEGORY_STANDARD, + SESSION_NETWORK_REQUEST_CATEGORY_UPLOAD, + SESSION_NETWORK_REQUEST_CATEGORY_DOWNLOAD, +} SESSION_NETWORK_REQUEST_CATEGORY; + +typedef struct network_server_destination { + const char* method; + const char* protocol; + const char* host; + uint16_t port; + const char* x25519_pubkey_hex; + const char* const* headers_kv_pairs; + size_t headers_kv_pairs_len; +} network_server_destination; + +typedef struct { + char ed25519_pubkey_hex[65]; // The 64-byte ed25519 pubkey in hex + null terminator. + uint8_t ip[4]; + uint16_t port; +} session_remote_address; + +typedef struct { + // Only ONE of these pointers should be set, the other should be left null + const network_service_node* snode_dest; + const network_server_destination* server_dest; + const session_remote_address* remote_addr_dest; + + const char* endpoint; + const unsigned char* body; + size_t body_size; + + SESSION_NETWORK_REQUEST_CATEGORY category; + uint64_t request_timeout_ms; + uint64_t overall_timeout_ms; // Use 0 for no overall timeout + + const char* upload_file_name; // Optional name for file uploads, null terminated + + const char* request_id; // Optional id for the request to trace through logs, null terminated + +} session_request_params; + +typedef struct { + SESSION_NETWORK_REQUEST_CATEGORY category; +} session_onion_path_metadata; + +typedef struct { + char destination_pubkey[65]; // The 64-byte ed25519 pubkey in hex + null terminator. + char destination_snode_address[65]; // The 64-byte .snode address + null terminator. +} session_lokinet_tunnel_metadata; + +typedef struct { + const network_service_node* nodes; + size_t nodes_count; + + // Only ONE of these pointers should be set, the other should be left null + const session_onion_path_metadata* onion_metadata; + const session_lokinet_tunnel_metadata* lokinet_metadata; + +} session_path_info; + +#ifdef __cplusplus +} +#endif diff --git a/include/session/network/session_network_types.hpp b/include/session/network/session_network_types.hpp new file mode 100644 index 00000000..b3433fd9 --- /dev/null +++ b/include/session/network/session_network_types.hpp @@ -0,0 +1,177 @@ +#pragma once + +#include +#include +#include +#include + +#include "session/network/key_types.hpp" +#include "session/network/service_node.hpp" +#include "session/network/session_network_types.h" + +namespace session::network { + +constexpr int16_t ERROR_NETWORK_SUSPENDED = -10001; +constexpr int16_t ERROR_BUILD_TIMEOUT = -10003; + +const std::pair content_type_plain_text = { + "Content-Type", "text/plain; charset=UTF-8"}; +const std::pair content_type_json = {"Content-Type", "application/json"}; + +class status_code_exception : public std::runtime_error { + public: + int16_t status_code; + std::vector> headers; + + status_code_exception( + int16_t status_code, + std::vector> headers, + std::string message) : + std::runtime_error(message), status_code{status_code}, headers{headers} {} +}; + +enum class ConnectionStatus { + unknown = CONNECTION_STATUS_UNKNOWN, + connecting = CONNECTION_STATUS_CONNECTING, + connected = CONNECTION_STATUS_CONNECTED, + disconnected = CONNECTION_STATUS_DISCONNECTED, +}; + +enum class RequestCategory { + standard = SESSION_NETWORK_REQUEST_CATEGORY_STANDARD, + upload = SESSION_NETWORK_REQUEST_CATEGORY_UPLOAD, + download = SESSION_NETWORK_REQUEST_CATEGORY_DOWNLOAD, +}; + +inline std::string to_string(RequestCategory category) { + switch (category) { + case RequestCategory::standard: return "standard"; + case RequestCategory::upload: return "upload"; + case RequestCategory::download: return "download"; + } + return "unknown"; // Should not be reached +} + +struct ServerDestination { + std::string protocol; + std::string host; + session::network::x25519_pubkey x25519_pubkey; + std::optional port; + std::optional>> headers; + std::string method; + + ServerDestination( + std::string protocol, + std::string host, + session::network::x25519_pubkey x25519_pubkey, + std::optional port = std::nullopt, + std::optional>> headers = std::nullopt, + std::string method = "GET") : + protocol{std::move(protocol)}, + host{std::move(host)}, + x25519_pubkey{std::move(x25519_pubkey)}, + port{std::move(port)}, + headers{std::move(headers)}, + method{std::move(method)} {} +}; + +using network_destination = + std::variant; + +struct UploadInfo { + std::optional file_name; +}; + +using RequestDetails = std::variant; + +struct Request { + std::string request_id; + network_destination destination; + std::string endpoint; + std::optional> body; + RequestCategory category; + + /// Timeout for an in-flight request after it has been sent via the transport mechanism. + std::chrono::milliseconds request_timeout; + + /// An optional, overall timeout for the entire operation, starting from the moment the request + /// is created. This includes time spent in queues waiting for a path to be built or a + /// connection to be established. If this timeout is exceeded while the request is still in a + /// queue, it will be timed out. + std::optional overall_timeout; + + /// Any extra request details which may modify the structure of the request. + RequestDetails details; + + /// The time the request was created, this is used primarily for determining whether the + /// `overall_timeout` has been exceeded. + std::chrono::system_clock::time_point creation_time = std::chrono::system_clock::now(); + + // If true, the transport should not cache/pool the connection used for this request, this is + // for one-shot requests like bootstrapping. + bool ephemeral_connection; + + int retry_count = 0; + + Request(std::string request_id, + network_destination destination, + std::string endpoint, + std::optional> body, + RequestCategory category, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout = std::nullopt, + RequestDetails details = std::monostate{}, + bool ephemeral_connection = false); + + Request(network_destination destination, + std::string endpoint, + std::optional> body, + RequestCategory category, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout = std::nullopt, + std::optional request_id = std::nullopt, + RequestDetails details = std::monostate{}, + bool ephemeral_connection = false); + + std::chrono::milliseconds time_remaining() const { + if (!overall_timeout) + return request_timeout; + + auto elapsed = std::chrono::duration_cast( + std::chrono::system_clock::now() - creation_time); + auto remaining = *overall_timeout - elapsed; + + return (remaining > std::chrono::milliseconds::zero() ? remaining + : std::chrono::milliseconds::zero()); + } +}; + +using node_failure_reporter_t = std::function; +using network_response_callback_t = std::function> headers, + std::optional response)>; + +struct Response { + static std::optional> parse_text_error(const std::string& body); + static std::optional find_uniform_batch_error(const std::string& body); +}; + +struct OnionPathMetadata { + RequestCategory category; +}; +struct LokinetTunnelMetadata { + std::string destination_pubkey; + std::string destination_snode_address; +}; + +using PathMetadata = std::variant; + +struct PathInfo { + std::vector nodes; + PathMetadata metadata; +}; + +} // namespace session::network diff --git a/include/session/network/snode_pool.hpp b/include/session/network/snode_pool.hpp new file mode 100644 index 00000000..bf7e12a3 --- /dev/null +++ b/include/session/network/snode_pool.hpp @@ -0,0 +1,136 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "session/network/key_types.hpp" +#include "session/network/network_config.hpp" +#include "session/network/service_node.hpp" +#include "swarm.hpp" + +namespace session::network { + +namespace config { + struct SnodePoolConfig { + std::optional cache_directory; + std::chrono::minutes cache_expiration; + std::chrono::milliseconds cache_min_lifetime; + bool enforce_subnet_diversity; + network::opt::retry_delay retry_delay; + + opt::netid::Target netid; + std::vector seed_nodes; + + size_t cache_min_size; + uint8_t cache_num_nodes_to_use_for_refresh; + uint16_t cache_node_failure_threshold; + bool cache_refresh_using_legacy_endpoint; + }; +} // namespace config + +class SnodePool : public std::enable_shared_from_this { + public: + using network_fetcher_t = std::function; + using fetcher_connectivity_check_t = std::function; + + SnodePool( + config::SnodePoolConfig config, + std::shared_ptr loop, + network_fetcher_t direct_fetcher); + ~SnodePool(); + + void suspend(); + void resume(); + + // Sets the network fetcher which should be used once the snode cache exists + void set_routed_fetcher( + network_fetcher_t routed_fetcher, fetcher_connectivity_check_t connectivity_check); + + // Returns the number of nodes currently in the pool + size_t size(); + + // Forcibly clears the cache from memory and disk + void clear_cache(); + + // Records that a specific node has failed a request + virtual void record_node_failure(const service_node& node, bool permanent = false); + virtual void record_node_failure(const ed25519_pubkey& key, bool permanent = false); + uint16_t node_failure_count(const service_node& node); + uint16_t node_failure_count(const ed25519_pubkey& key); + void clear_node_failure_counts(); + + // Checks if the pool is empty or stale and triggers a refresh if needed + virtual void refresh_if_needed( + const std::vector& in_use_nodes, + std::function on_refresh_complete = nullptr); + + virtual void get_swarm( + session::network::x25519_pubkey swarm_pubkey, + std::function)> callback); + + virtual std::vector get_unused_nodes( + size_t count, const std::vector& exclude = {}); + + private: + friend class TestSnodePool; + + bool _suspended = false; + config::SnodePoolConfig _config; + std::shared_ptr _loop; + network_fetcher_t _direct_fetcher; + std::optional _routed_fetcher; + std::optional _routed_fetcher_connectivity_check; + + // Data (protected by '_cache_mutex') + std::vector _snode_cache; + std::vector>> _all_swarms; + std::unordered_map>> + _swarm_cache; + std::unordered_map _snode_failure_counts; + + // Disk I/O + std::filesystem::path _snode_cache_file_path; + std::thread _disk_write_thread; + std::condition_variable _disk_write_cv; + std::mutex _cache_mutex; + bool _need_write = false; + bool _need_clear_cache = false; + bool _shut_down_disk_thread = false; + + // Refresh logic (protected by '_cache_mutex') + std::chrono::system_clock::time_point _last_snode_cache_update; + std::optional _current_snode_cache_refresh_id; + int _snode_cache_refresh_failure_count = 0; + std::vector _refresh_candidate_nodes; + std::vector> _snode_refresh_results; + std::vector> _after_snode_cache_refresh; + + // Disk I/O functions + void _load_from_disk(); + void _disk_write_loop(); + + // Refresh functions + void _refresh_snode_cache(std::optional request_id = std::nullopt); + void _launch_next_refresh_request( + const std::string& request_id, + const bool use_direct_fetcher, + const uint8_t total_requests); + void _retry_refresh_request( + const std::string& request_id, + const bool use_direct_fetcher, + const uint8_t total_requests); + void _on_refresh_complete( + std::string refresh_id, + std::vector> raw_results, + const bool use_direct_fetcher, + const uint8_t total_requests, + const bool from_legacy_endpoint); +}; + +} // namespace session::network diff --git a/include/session/network/swarm.hpp b/include/session/network/swarm.hpp new file mode 100644 index 00000000..1c0ba8b0 --- /dev/null +++ b/include/session/network/swarm.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "session/network/key_types.hpp" + +namespace session::network { +struct service_node; +} // namespace session::network + +namespace session::network::swarm { + +using swarm_id_t = uint64_t; +constexpr swarm_id_t INVALID_SWARM_ID = std::numeric_limits::max(); + +swarm_id_t pubkey_to_swarm_space(const session::network::x25519_pubkey& pk); +std::vector>> generate_swarms( + const std::vector nodes); +std::pair> get_swarm( + const session::network::x25519_pubkey swarm_pubkey, + const std::vector>> all_swarms); + +} // namespace session::network::swarm diff --git a/include/session/network/transport/network_transport.hpp b/include/session/network/transport/network_transport.hpp new file mode 100644 index 00000000..ef780a6a --- /dev/null +++ b/include/session/network/transport/network_transport.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include "session/network/session_network_types.hpp" + +namespace session::network { + +class ITransport { + public: + std::function on_status_changed; + + virtual ~ITransport() = default; + + virtual void suspend() = 0; + virtual void resume(bool automatically_reconnect = true) = 0; + virtual void close_connections() = 0; + + virtual ConnectionStatus get_status() const = 0; + virtual void set_node_failure_reporter(node_failure_reporter_t reporter) {} + virtual void verify_connectivity( + service_node node, + std::chrono::milliseconds timeout, + const std::string& request_id, + std::function callback) = 0; + virtual void add_failure_listener( + const ed25519_pubkey& pubkey, std::function listener) = 0; + virtual void remove_failure_listeners(const ed25519_pubkey& pubkey) = 0; + + virtual void send_request(Request request, network_response_callback_t callback) = 0; +}; + +} // namespace session::network \ No newline at end of file diff --git a/include/session/network/transport/quic_transport.hpp b/include/session/network/transport/quic_transport.hpp new file mode 100644 index 00000000..81825754 --- /dev/null +++ b/include/session/network/transport/quic_transport.hpp @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "session/network/key_types.hpp" +#include "session/network/network_config.hpp" +#include "session/network/transport/network_transport.hpp" + +namespace oxen::quic { +class Loop; +class Endpoint; +struct ConnectionID; +} // namespace oxen::quic + +namespace session::network { + +namespace config { + struct QuicTransportConfig { + std::chrono::milliseconds handshake_timeout; + std::chrono::seconds keep_alive; + + bool disable_mtu_discovery; + }; +} // namespace config + +class QuicTransport : public ITransport, public std::enable_shared_from_this { + private: + bool _suspended = false; + config::QuicTransportConfig _config; + std::shared_ptr _loop; + std::shared_ptr _endpoint; + + std::unordered_set _ephemeral_connection_ids; + std::unordered_map _active_connection_ids; + std::unordered_map _active_stream_ids; + std::unordered_map>> + _pending_verification_callbacks; + std::unordered_map>> + _pending_requests; + std::unordered_map>> _failure_listeners; + + public: + explicit QuicTransport( + config::QuicTransportConfig config, std::shared_ptr loop); + ~QuicTransport() override; + + void suspend() override; + void resume(bool automatically_reconnect = true) override; + void close_connections() override; + + ConnectionStatus get_status() const override { return _status.load(); }; + void set_node_failure_reporter(node_failure_reporter_t reporter) override; + void verify_connectivity( + service_node node, + std::chrono::milliseconds timeout, + const std::string& request_id, + std::function callback) override; + void add_failure_listener( + const ed25519_pubkey& pubkey, std::function listener) override; + void remove_failure_listeners(const ed25519_pubkey& pubkey) override; + void send_request(Request request, network_response_callback_t callback) override; + + private: + // The current connection status of this transport layer + std::atomic _status{ConnectionStatus::unknown}; + + // Callback which will be called when failing to connect to a node + std::optional _report_node_failure; + + // True if we have already transitioned to "connecting" since the last time we were fully + // disconnected + bool _has_attempted_reconnect = false; + + void _recreate_endpoint(); + void _close_connections(); + void _update_status(ConnectionStatus new_status); + void _send_request_internal(Request request, network_response_callback_t callback); + void _establish_connection( + const oxen::quic::RemoteAddress& address, const std::string& initiating_req_id); + void _send_on_connection( + oxen::quic::ConnectionID conn_id, + Request request, + network_response_callback_t callback); + void _fail_connection( + const std::string& address_pubkey_hex, + const std::string& initiating_req_id, + std::optional conn_id, + std::optional error_code, + std::optional custom_error); +}; + +} // namespace session::network \ No newline at end of file diff --git a/include/session/onionreq/builder.h b/include/session/onionreq/builder.h index 4320276e..7dbed35c 100644 --- a/include/session/onionreq/builder.h +++ b/include/session/onionreq/builder.h @@ -76,7 +76,6 @@ LIBSESSION_EXPORT void onion_request_builder_set_snode_destination( /// - `builder` -- [in] Pointer to the builder object /// - `protocol` -- [in] The protocol to use /// - `host` -- [in] The server host -/// - `endpoint` -- [in] The endpoint to call /// - `method` -- [in] The HTTP method to use /// - `port` -- [in] The port to use /// - `x25519_pubkey` -- [in] The x25519 public key for server @@ -84,22 +83,10 @@ LIBSESSION_EXPORT void onion_request_builder_set_server_destination( onion_request_builder_object* builder, const char* protocol, const char* host, - const char* endpoint, const char* method, uint16_t port, const char* x25519_pubkey); -/// API: onion_request_builder_set_destination_pubkey -/// -/// Wrapper around session::onionreq::Builder::set_destination_pubkey. -/// -/// Inputs: -/// - `builder` -- [in] Pointer to the builder object -/// - `x25519_pubkey` -- [in] The x25519 public key for server (Hex string of exactly 64 -/// characters). -LIBSESSION_EXPORT void onion_request_builder_set_destination_pubkey( - onion_request_builder_object* builder, const char* x25519_pubkey); - /// API: onion_request_builder_add_hop /// /// Wrapper around session::onionreq::Builder::add_hop. ed25519_pubkey and diff --git a/include/session/onionreq/builder.hpp b/include/session/onionreq/builder.hpp index fe87ec51..99890cbe 100644 --- a/include/session/onionreq/builder.hpp +++ b/include/session/onionreq/builder.hpp @@ -5,46 +5,20 @@ #include #include -#include "key_types.hpp" +#include "session/network/session_network_types.hpp" namespace session::network { struct service_node; struct request_info; +struct Request; } // namespace session::network namespace session::onionreq { -struct ServerDestination { - std::string protocol; - std::string host; - std::string endpoint; - session::onionreq::x25519_pubkey x25519_pubkey; - std::optional port; - std::optional>> headers; - std::string method; - - ServerDestination( - std::string protocol, - std::string host, - std::string endpoint, - session::onionreq::x25519_pubkey x25519_pubkey, - std::optional port = std::nullopt, - std::optional>> headers = std::nullopt, - std::string method = "GET") : - protocol{std::move(protocol)}, - host{std::move(host)}, - endpoint{std::move(endpoint)}, - x25519_pubkey{std::move(x25519_pubkey)}, - port{std::move(port)}, - headers{std::move(headers)}, - method{std::move(method)} {} -}; - -using network_destination = std::variant; - namespace detail { - session::onionreq::x25519_pubkey pubkey_for_destination(network_destination destination); + session::network::x25519_pubkey pubkey_for_destination( + network::network_destination destination); } enum class EncryptType { @@ -66,43 +40,51 @@ inline constexpr std::string_view to_string(EncryptType type) { // Builder class for preparing onion request payloads. class Builder { - Builder(const network_destination& destination, + public: + Builder(const network::network_destination& destination, + const std::string& endpoint, const std::vector& nodes, - const EncryptType enc_type_); + const EncryptType enc_type_ = EncryptType::xchacha20); - public: static Builder make( - const network_destination& destination, + const network::network_destination& destination, + const std::string& endpoint, const std::vector& nodes, const EncryptType enc_type_ = EncryptType::xchacha20); EncryptType enc_type; - std::optional destination_x25519_public_key = std::nullopt; - std::optional final_hop_x25519_keypair = std::nullopt; + bool is_v4_request; + std::optional final_hop_x25519_keypair = std::nullopt; Builder(EncryptType enc_type_ = EncryptType::xchacha20) : enc_type{enc_type_} {} void set_enc_type(EncryptType enc_type_) { enc_type = enc_type_; } + std::optional get_destination_x25519_public_key() const { + return destination_x25519_public_key_; + }; - void set_destination(network_destination destination); - void set_destination_pubkey(session::onionreq::x25519_pubkey x25519_pubkey); + void set_destination(network::network_destination destination); void add_hop(std::span remote_key); - void add_hop(std::pair keys) { hops_.push_back(keys); } + void add_hop(std::pair keys) { + hops_.push_back(keys); + } - void generate(network::request_info& info); std::vector build(std::vector payload); + std::vector generate_onion_blob( + const std::optional>& plaintext_body); private: - std::vector> hops_ = {}; + std::vector> hops_ = {}; + std::string endpoint_; + std::optional destination_x25519_public_key_ = std::nullopt; // Snode request values - std::optional ed25519_public_key_ = std::nullopt; + std::optional ed25519_public_key_ = std::nullopt; // Proxied request values std::optional host_ = std::nullopt; - std::optional endpoint_ = std::nullopt; std::optional protocol_ = std::nullopt; std::optional method_ = std::nullopt; std::optional port_ = std::nullopt; diff --git a/include/session/onionreq/hop_encryption.hpp b/include/session/onionreq/hop_encryption.hpp index a5e937de..47bb1f28 100644 --- a/include/session/onionreq/hop_encryption.hpp +++ b/include/session/onionreq/hop_encryption.hpp @@ -4,14 +4,17 @@ #include #include "builder.hpp" -#include "key_types.hpp" +#include "session/network/key_types.hpp" namespace session::onionreq { // Encryption/decription class for encryption/decrypting outgoing/incoming messages. class HopEncryption { public: - HopEncryption(x25519_seckey private_key, x25519_pubkey public_key, bool server = true) : + HopEncryption( + network::x25519_seckey private_key, + network::x25519_pubkey public_key, + bool server = true) : private_key_{std::move(private_key)}, public_key_{std::move(public_key)}, server_{server} {} @@ -25,17 +28,17 @@ class HopEncryption { std::vector encrypt( EncryptType type, std::vector plaintext, - const x25519_pubkey& pubkey) const; + const network::x25519_pubkey& pubkey) const; std::vector decrypt( EncryptType type, std::vector ciphertext, - const x25519_pubkey& pubkey) const; + const network::x25519_pubkey& pubkey) const; // AES-GCM encryption. std::vector encrypt_aesgcm( - std::vector plainText, const x25519_pubkey& pubKey) const; + std::vector plainText, const network::x25519_pubkey& pubKey) const; std::vector decrypt_aesgcm( - std::vector cipherText, const x25519_pubkey& pubKey) const; + std::vector cipherText, const network::x25519_pubkey& pubKey) const; // xchacha20-poly1305 encryption; for a message sent from client Alice to server Bob we use a // shared key of a Blake2B 32-byte (i.e. crypto_aead_xchacha20poly1305_ietf_KEYBYTES) hash of @@ -46,13 +49,13 @@ class HopEncryption { // H(bA || A || B) (note that this is *different* that what would result if Bob was a client // sending to Alice the client). std::vector encrypt_xchacha20( - std::vector plaintext, const x25519_pubkey& pubKey) const; + std::vector plaintext, const network::x25519_pubkey& pubKey) const; std::vector decrypt_xchacha20( - std::vector ciphertext, const x25519_pubkey& pubKey) const; + std::vector ciphertext, const network::x25519_pubkey& pubKey) const; private: - const x25519_seckey private_key_; - const x25519_pubkey public_key_; + const network::x25519_seckey private_key_; + const network::x25519_pubkey public_key_; bool server_; // True if we are the server (i.e. the snode). }; diff --git a/include/session/onionreq/parser.hpp b/include/session/onionreq/parser.hpp index 233a97e0..91857904 100644 --- a/include/session/onionreq/parser.hpp +++ b/include/session/onionreq/parser.hpp @@ -11,10 +11,10 @@ constexpr size_t DEFAULT_MAX_SIZE = 10'485'760; // 10 MiB class OnionReqParser { private: - x25519_keypair keys; + network::x25519_keypair keys; HopEncryption enc; EncryptType enc_type = EncryptType::aes_gcm; - x25519_pubkey remote_pk; + network::x25519_pubkey remote_pk; std::vector payload_; public: diff --git a/include/session/onionreq/response_parser.hpp b/include/session/onionreq/response_parser.hpp index 46f76a43..6a1d7c0a 100644 --- a/include/session/onionreq/response_parser.hpp +++ b/include/session/onionreq/response_parser.hpp @@ -3,34 +3,48 @@ #include #include "hop_encryption.hpp" -#include "key_types.hpp" +#include "session/network/key_types.hpp" +#include "session/network/session_network_types.hpp" namespace session::onionreq { constexpr auto decryption_failed_error = "Decryption failed (both XChaCha20-Poly1305 and AES256-GCM)"sv; +struct DecryptedResponse { + int16_t status_code; + std::vector> headers; + std::optional body; +}; + class ResponseParser { public: /// Constructs a parser, parsing the given request sent to us. Throws if parsing or decryption /// fails. ResponseParser(session::onionreq::Builder builder); ResponseParser( - x25519_pubkey destination_x25519_public_key, - x25519_keypair x25519_keypair, - EncryptType enc_type = EncryptType::xchacha20) : + network::x25519_pubkey destination_x25519_public_key, + network::x25519_keypair x25519_keypair, + EncryptType enc_type = EncryptType::xchacha20, + bool v4_request = false) : destination_x25519_public_key_{std::move(destination_x25519_public_key)}, x25519_keypair_{std::move(x25519_keypair)}, - enc_type_{enc_type} {} + enc_type_{enc_type}, + v4_request_{v4_request} {} static bool response_long_enough(EncryptType enc_type, size_t response_size); std::vector decrypt(std::vector ciphertext) const; + DecryptedResponse decrypted_response(const std::string& encrypted_response); private: - x25519_pubkey destination_x25519_public_key_; - x25519_keypair x25519_keypair_; + network::x25519_pubkey destination_x25519_public_key_; + network::x25519_keypair x25519_keypair_; EncryptType enc_type_; + bool v4_request_; + + DecryptedResponse _decrypt_v3_response(const std::string& response); + DecryptedResponse _decrypt_v4_response(const std::string& response); }; } // namespace session::onionreq diff --git a/include/session/random.hpp b/include/session/random.hpp index 54b33fba..3bf63a44 100644 --- a/include/session/random.hpp +++ b/include/session/random.hpp @@ -50,4 +50,23 @@ std::vector random(size_t size); /// - random base32 string of the specified length. std::string random_base32(size_t size); +/// API: random/get_uniform_distribution +/// +/// Generates a cryptographically secure random integer within a given range (inclusive). +/// +/// Inputs: +/// - `min` -- the minimum value for the range. +/// - `max` -- the maximum value for the range. +/// +/// Outputs: +/// - A random integer in the specified range +template +T get_uniform_distribution(T min, T max) { + if (min > max) + return min; + + const uint64_t range = static_cast(max) - static_cast(min) + 1; + return static_cast(static_cast(min) + (csrng() % range)); +} + } // namespace session::random diff --git a/include/session/session_network.h b/include/session/session_network.h deleted file mode 100644 index 534d93b2..00000000 --- a/include/session/session_network.h +++ /dev/null @@ -1,351 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include - -#include "export.h" -#include "log_level.h" -#include "onionreq/builder.h" -#include "platform.h" - -typedef enum CONNECTION_STATUS { - CONNECTION_STATUS_UNKNOWN = 0, - CONNECTION_STATUS_CONNECTING = 1, - CONNECTION_STATUS_CONNECTED = 2, - CONNECTION_STATUS_DISCONNECTED = 3, -} CONNECTION_STATUS; - -typedef struct network_object { - // Internal opaque object pointer; calling code should leave this alone. - void* internals; -} network_object; - -typedef struct network_service_node { - uint8_t ip[4]; - uint16_t quic_port; - char ed25519_pubkey_hex[65]; // The 64-byte ed25519 pubkey in hex + null terminator. -} network_service_node; - -typedef struct network_server_destination { - const char* method; - const char* protocol; - const char* host; - const char* endpoint; - uint16_t port; - const char* x25519_pubkey; - const char* const* headers; - const char* const* header_values; - size_t headers_size; -} network_server_destination; - -typedef struct onion_request_path { - const network_service_node* nodes; - const size_t nodes_count; -} onion_request_path; - -/// API: network/network_init -/// -/// Constructs a new network object. -/// -/// When done with the object the `network_object` must be destroyed by passing the pointer to -/// network_free(). -/// -/// Inputs: -/// - `network` -- [out] Pointer to the network object -/// - `cache_path` -- [in] Path where the snode cache files should be stored. Should be -/// NULL-terminated. -/// - `use_testnet` -- [in] Flag indicating whether the network should connect to testnet or -/// mainnet. -/// - `single_path_mode` -- [in] Flag indicating whether the network should be in "single path mode" -/// (ie. use a single path for everything - this is useful for iOS App Extensions which perform a -/// single action and then close so we don't waste time building other paths). -/// - `pre_build_paths` -- [in] Flag indicating whether the network should pre-build it's paths. -/// - `error` -- [out] the pointer to a buffer in which we will write an error string if an error -/// occurs; error messages are discarded if this is given as NULL. If non-NULL this must be a -/// buffer of at least 256 bytes. -/// -/// Outputs: -/// - `bool` -- Returns true on success; returns false and write the exception message as a C-string -/// into `error` (if not NULL) on failure. -LIBSESSION_EXPORT bool network_init( - network_object** network, - const char* cache_path, - bool use_testnet, - bool single_path_mode, - bool pre_build_paths, - char* error) LIBSESSION_WARN_UNUSED; - -/// API: network/network_free -/// -/// Frees a network object. -/// -/// Inputs: -/// - `network` -- [in] Pointer to network_object object -LIBSESSION_EXPORT void network_free(network_object* network); - -/// API: network/network_suspend -/// -/// Suspends the network preventing any further requests from creating new connections and paths. -/// This function also calls the `close_connections` function. -LIBSESSION_EXPORT void network_suspend(network_object* network); - -/// API: network/network_resume -/// -/// Resumes the network allowing new requests to creating new connections and paths. -LIBSESSION_EXPORT void network_resume(network_object* network); - -/// API: network/network_close_connections -/// -/// Closes any currently active connections. -LIBSESSION_EXPORT void network_close_connections(network_object* network); - -/// API: network/network_clear_cache -/// -/// Clears the cached from memory and from disk (if a cache path was provided during -/// initialization). -LIBSESSION_EXPORT void network_clear_cache(network_object* network); - -/// API: network/network_get_cache_size -/// -/// Retrieves the current size of the snode cache from memory (if a cache doesn't exist or -/// hasn't been loaded then this will return 0). -LIBSESSION_EXPORT size_t network_get_snode_cache_size(network_object* network); - -/// API: network/network_set_status_changed_callback -/// -/// Registers a callback to be called whenever the network connection status changes. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object -/// - `callback` -- [in] callback to be called when the network connection status changes. -/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. -LIBSESSION_EXPORT void network_set_status_changed_callback( - network_object* network, void (*callback)(CONNECTION_STATUS status, void* ctx), void* ctx); - -/// API: network/network_set_paths_changed_callback -/// -/// Registers a callback to be called whenever the onion request paths are updated. -/// -/// The pointer provided to the callback belongs to the caller and must be freed via `free()` when -/// done with it. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object -/// - `callback` -- [in] callback to be called when the onion request paths are updated. -/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. -LIBSESSION_EXPORT void network_set_paths_changed_callback( - network_object* network, - void (*callback)(onion_request_path* paths, size_t paths_len, void* ctx), - void* ctx); - -/// API: network/network_get_swarm -/// -/// Retrieves the swarm for the given pubkey. If there is already an entry in the cache for the -/// swarm then that will be returned, otherwise a network request will be made to retrieve the -/// swarm and save it to the cache. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object -/// - 'swarm_pubkey_hex' - [in] x25519 pubkey for the swarm in hex (64 characters). -/// - 'callback' - [in] callback to be called with the retrieved swarm (in the case of an error -/// the callback will be called with an empty list). -/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. -LIBSESSION_EXPORT void network_get_swarm( - network_object* network, - const char* swarm_pubkey_hex, - void (*callback)(network_service_node* nodes, size_t nodes_len, void*), - void* ctx); - -/// API: network/network_get_random_nodes -/// -/// Retrieves a number of random nodes from the snode pool. If the are no nodes in the pool a -/// new pool will be populated and the nodes will be retrieved from that. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object -/// - 'count' - [in] the number of nodes to retrieve. -/// - 'callback' - [in] callback to be called with the retrieved nodes (in the case of an error -/// the callback will be called with an empty list). -/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. -LIBSESSION_EXPORT void network_get_random_nodes( - network_object* network, - uint16_t count, - void (*callback)(network_service_node*, size_t, void*), - void* ctx); - -/// API: network/network_onion_response_callback_t -/// -/// Function pointer typedef for the callback function pointer given to -/// network_send_onion_request_to_snode_destination and -/// network_send_onion_request_to_server_destination. -/// -/// Fields: -/// - `success` -- true if the request was successful, false if it failed. -/// - `timeout` -- true if the request failed because of a timeout -/// - `status_code` -- the HTTP numeric status code of the request, e.g. 200 for OK -/// - `headers` -- the response headers, array of null-terminated C strings -/// - `header_values` -- the response header values, array of null-terminated C strings -/// - `headers_size` -- the number of `headers`/`header_values` -/// - `response` -- pointer to the beginning of the response body -/// - `response_size` -- length of the response body -/// - `ctx` -- the context pointer passed to the function that initiated the request. -typedef void (*network_onion_response_callback_t)( - bool success, - bool timeout, - int16_t status_code, - const char* const* headers, - const char* const* header_values, - size_t headers_size, - const char* response, - size_t response_size, - void* ctx); - -/// API: network/network_send_onion_request_to_snode_destination -/// -/// Sends a request via onion routing to the provided service node. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object. -/// - `node` -- [in] address information about the service node the request should be sent to. -/// - `body` -- [in] data to send to the specified node. -/// - `body_size` -- [in] size of the `body`. -/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take -/// the path build into account so if the path build takes forever then this request will never -/// timeout. -/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and -/// path build (if required). This value takes presedence over `request_timeout_ms` if provided, -/// the request itself will be given a timeout of this value subtracting however long it took to -/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. -/// - `callback` -- [in] callback to be called with the result of the request. -/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set to -/// NULL if unused. -LIBSESSION_EXPORT void network_send_onion_request_to_snode_destination( - network_object* network, - const network_service_node node, - const unsigned char* body, - size_t body_size, - const char* swarm_pubkey_hex, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx); - -/// API: network/network_send_onion_request_to_server_destination -/// -/// Sends a request via onion routing to the provided server. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object. -/// - `server` -- [in] struct containing information about the server the request should be sent to. -/// - `body` -- [in] data to send to the specified endpoint. -/// - `body_size` -- [in] size of the `body`. -/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take -/// the path build into account so if the path build takes forever then this request will never -/// timeout. -/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and -/// path build (if required). This value takes presedence over `request_timeout_ms` if provided, -/// the request itself will be given a timeout of this value subtracting however long it took to -/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. -/// - `callback` -- [in] callback to be called with the result of the request. -/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set -/// to NULL if unused. -LIBSESSION_EXPORT void network_send_onion_request_to_server_destination( - network_object* network, - const network_server_destination server, - const unsigned char* body, - size_t body_size, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx); - -/// API: network/network_upload_to_server -/// -/// Uploads a file to a server. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object. -/// - `server` -- [in] struct containing information about the server the request should be sent to. -/// - `data` -- [in] data to upload to the file server. -/// - `data_len` -- [in] size of the `data`. -/// - `file_name` -- [in, optional] name of the file being uploaded. MUST be null terminated. -/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take -/// the path build into account so if the path build takes forever then this request will never -/// timeout. -/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and -/// path build (if required). This value takes presedence over `request_timeout_ms` if provided, -/// the request itself will be given a timeout of this value subtracting however long it took to -/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. -/// - `callback` -- [in] callback to be called with the result of the request. -/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set -/// to NULL if unused. -LIBSESSION_EXPORT void network_upload_to_server( - network_object* network, - const network_server_destination server, - const unsigned char* data, - size_t data_len, - const char* file_name, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx); - -/// API: network/network_download_from_server -/// -/// Downloads a file from a server. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object. -/// - `server` -- [in] struct containing information about file to be downloaded. -/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take -/// the path build into account so if the path build takes forever then this request will never -/// timeout. -/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and -/// path build (if required). This value takes presedence over `request_timeout_ms` if provided, -/// the request itself will be given a timeout of this value subtracting however long it took to -/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. -/// - `callback` -- [in] callback to be called with the result of the request. -/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set -/// to NULL if unused. -LIBSESSION_EXPORT void network_download_from_server( - network_object* network, - const network_server_destination server, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx); - -/// API: network/network_get_client_version -/// -/// Retrieves the version information for the given platform. -/// -/// Inputs: -/// - `network` -- [in] Pointer to the network object. -/// - `platform` -- [in] the platform to retrieve the client version for. -/// - `ed25519_secret` -- [in] the users ed25519 secret key (used for blinded auth - 64 bytes). -/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take -/// the path build into account so if the path build takes forever then this request will never -/// timeout. -/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and -/// path build (if required). This value takes presedence over `request_timeout_ms` if provided, -/// the request itself will be given a timeout of this value subtracting however long it took to -/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. -/// - `callback` -- [in] callback to be called with the result of the request. -/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set -/// to NULL if unused. -LIBSESSION_EXPORT void network_get_client_version( - network_object* network, - CLIENT_PLATFORM platform, - const unsigned char* ed25519_secret, /* 64 bytes */ - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx); - -#ifdef __cplusplus -} -#endif diff --git a/include/session/session_network.hpp b/include/session/session_network.hpp deleted file mode 100644 index f526ee21..00000000 --- a/include/session/session_network.hpp +++ /dev/null @@ -1,755 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "onionreq/builder.hpp" -#include "onionreq/key_types.hpp" -#include "platform.hpp" -#include "session/random.hpp" -#include "types.hpp" - -namespace session::network { - -namespace fs = std::filesystem; - -using network_response_callback_t = std::function> headers, - std::optional response)>; - -enum class ConnectionStatus { - unknown, - connecting, - connected, - disconnected, -}; - -enum class PathType { - standard, - upload, - download, -}; - -using swarm_id_t = uint64_t; -constexpr swarm_id_t INVALID_SWARM_ID = std::numeric_limits::max(); - -struct service_node : public oxen::quic::RemoteAddress { - public: - std::vector storage_server_version; - swarm_id_t swarm_id; - - service_node() = delete; - - template - service_node( - std::string_view remote_pk, - std::vector storage_server_version, - swarm_id_t swarm_id, - Opt&&... opts) : - oxen::quic::RemoteAddress{remote_pk, std::forward(opts)...}, - storage_server_version{storage_server_version}, - swarm_id{swarm_id} {} - - template - service_node( - std::span remote_pk, - std::vector storage_server_version, - swarm_id_t swarm_id, - Opt&&... opts) : - oxen::quic::RemoteAddress{remote_pk, std::forward(opts)...}, - storage_server_version{storage_server_version}, - swarm_id{swarm_id} {} - - service_node(const service_node& obj) : - oxen::quic::RemoteAddress{obj}, - storage_server_version{obj.storage_server_version}, - swarm_id{obj.swarm_id} {} - service_node& operator=(const service_node& obj) { - storage_server_version = obj.storage_server_version; - swarm_id = obj.swarm_id; - oxen::quic::RemoteAddress::operator=(obj); - _copy_internals(obj); - return *this; - } - - auto operator<=>(const service_node& other) const = delete; - bool operator==(const service_node& other) const { - return RemoteAddress::operator==(other) && - storage_server_version == other.storage_server_version && swarm_id == other.swarm_id; - } -}; - -struct connection_info { - service_node node; - std::shared_ptr pending_requests; - std::shared_ptr conn; - std::shared_ptr stream; - - bool is_valid() const { return conn && stream && !stream->is_closing(); }; - bool has_pending_requests() const { return (pending_requests && (*pending_requests) > 0); }; - - void add_pending_request() { - if (!pending_requests) - pending_requests = std::make_shared(0); - (*pending_requests)++; - }; - - // This is weird but since we are modifying the shared_ptr we aren't mutating - // the object so it can be a const function - void remove_pending_request() const { - if (!pending_requests) - return; - (*pending_requests)--; - }; -}; - -struct onion_path { - std::string id; - connection_info conn_info; - std::vector nodes; - uint8_t failure_count; - - bool is_valid() const { return !nodes.empty() && conn_info.is_valid(); }; - bool has_pending_requests() const { return conn_info.has_pending_requests(); } - size_t num_pending_requests() const { - if (!conn_info.pending_requests) - return 0; - return (*conn_info.pending_requests); - } - - std::string to_string() const; - - bool contains_node(const service_node& sn) const; - - bool operator==(const onion_path& other) const { - // The `conn_info` and failure/timeout counts can be reset for a path in a number - // of situations so just use the nodes to determine if the paths match - return nodes == other.nodes; - } -}; - -namespace detail { - swarm_id_t pubkey_to_swarm_space(const session::onionreq::x25519_pubkey& pk); - std::vector>> generate_swarms( - std::vector nodes); - - std::optional node_for_destination(onionreq::network_destination destination); - - session::onionreq::x25519_pubkey pubkey_for_destination( - onionreq::network_destination destination); - -} // namespace detail - -struct request_info { - static request_info make( - onionreq::network_destination _dest, - std::optional> _original_body, - std::optional _swarm_pk, - std::chrono::milliseconds _request_timeout, - std::optional _request_and_path_build_timeout = std::nullopt, - PathType _type = PathType::standard, - std::optional _req_id = std::nullopt, - std::optional endpoint = "onion_req", - std::optional> _body = std::nullopt); - - enum class RetryReason { - none, - decryption_failure, - redirect, - redirect_swarm_refresh, - }; - - std::string request_id; - session::onionreq::network_destination destination; - std::string endpoint; - std::optional> body; - std::optional> original_body; - std::optional swarm_pubkey; - PathType path_type; - std::chrono::milliseconds request_timeout; - std::optional request_and_path_build_timeout; - std::chrono::system_clock::time_point creation_time = std::chrono::system_clock::now(); - - /// The reason we are retrying the request (if it's a retry). Generally only used for internal - /// purposes (like receiving a `421`) in order to prevent subsequent retries. - std::optional retry_reason{}; - - bool node_destination{detail::node_for_destination(destination).has_value()}; -}; - -class Network { - private: - const bool use_testnet; - const bool should_cache_to_disk; - const bool single_path_mode; - const fs::path cache_path; - - // Disk thread state - std::mutex snode_cache_mutex; // This guards all the below: - std::condition_variable snode_cache_cv; - bool has_pending_disk_write = false; - bool shut_down_disk_thread = false; - bool need_write = false; - bool need_clear_cache = false; - - // Values persisted to disk - std::optional seed_node_cache_size; - std::vector snode_cache; - std::chrono::system_clock::time_point last_snode_cache_update{}; - - std::thread disk_write_thread; - - // General values - bool suspended = false; - ConnectionStatus status; - - std::shared_ptr loop; - std::shared_ptr endpoint; - std::unordered_map> paths; - std::vector> paths_pending_drop; - std::vector unused_nodes; - std::unordered_map snode_failure_counts; - std::vector>> all_swarms; - std::unordered_map>> swarm_cache; - - // Snode refresh state - int snode_cache_refresh_failure_count = 0; - int in_progress_snode_cache_refresh_count = 0; - std::optional current_snode_cache_refresh_request_id; - std::vector> after_snode_cache_refresh; - std::optional> unused_snode_refresh_nodes; - std::shared_ptr>> snode_refresh_results; - - // First hop state - int connection_failures = 0; - std::deque unused_connections; - std::unordered_map in_progress_connections; - - // Path build state - int path_build_failures = 0; - std::deque path_build_queue; - std::unordered_map in_progress_path_builds; - - // Request state - bool has_scheduled_resume_queues = false; - std::optional request_timeout_id; - std::chrono::system_clock::time_point last_resume_queues_timestamp{}; - std::unordered_map>> - request_queue; - - public: - friend class TestNetwork; - friend class TestNetworkWrapper; - - // Hook to be notified whenever the network connection status changes. - std::function status_changed; - - // Hook to be notified whenever the onion request paths are updated. - std::function> paths)> paths_changed; - - // Constructs a new network with the given cache path and a flag indicating whether it should - // use testnet or mainnet, all requests should be made via a single Network instance. - Network(std::optional cache_path, - bool use_testnet, - bool single_path_mode, - bool pre_build_paths); - virtual ~Network(); - - /// API: network/suspend - /// - /// Suspends the network preventing any further requests from creating new connections and - /// paths. This function also calls the `close_connections` function. - void suspend(); - - /// API: network/resume - /// - /// Resumes the network allowing new requests to creating new connections and paths. - void resume(); - - /// API: network/close_connections - /// - /// Closes any currently active connections. - void close_connections(); - - /// API: network/clear_cache - /// - /// Clears the cached from memory and from disk (if a cache path was provided during - /// initialization). - void clear_cache(); - - /// API: network/snode_cache_size - /// - /// Retrieves the current size of the snode cache from memory (if a cache doesn't exist or - /// hasn't been loaded then this will return 0). - size_t snode_cache_size(); - - /// API: network/get_swarm - /// - /// Retrieves the swarm for the given pubkey. If there is already an entry in the cache for the - /// swarm then that will be returned, otherwise a network request will be made to retrieve the - /// swarm and save it to the cache. - /// - /// Inputs: - /// - 'swarm_pubkey' - [in] public key for the swarm. - /// - 'callback' - [in] callback to be called with the retrieved swarm (in the case of an error - /// the callback will be called with an empty list). - void get_swarm( - session::onionreq::x25519_pubkey swarm_pubkey, - std::function swarm)> callback); - - /// API: network/get_random_nodes - /// - /// Retrieves a number of random nodes from the snode pool. If the are no nodes in the pool a - /// new pool will be populated and the nodes will be retrieved from that. - /// - /// Inputs: - /// - 'count' - [in] the number of nodes to retrieve. - /// - 'callback' - [in] callback to be called with the retrieved nodes (in the case of an error - /// the callback will be called with an empty list). - void get_random_nodes( - uint16_t count, std::function nodes)> callback); - - /// API: network/send_onion_request - /// - /// Sends a request via onion routing to the provided service node or server destination. - /// - /// Inputs: - /// - `destination` -- [in] service node or server destination information. - /// - `body` -- [in] data to send to the specified destination. - /// - `swarm_pubkey` -- [in, optional] pubkey for the swarm the request is associated with. - /// Should be NULL if the request is not associated with a swarm. - /// - `handle_response` -- [in] callback to be called with the result of the request. - /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take - /// the path build into account so if the path build takes forever then this request will never - /// timeout. - /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request - /// and path build (if required). This value takes presedence over `request_timeout` if - /// provided, the request itself will be given a timeout of this value subtracting however long - /// it took to build the path. - /// - 'type' - [in] the type of paths to send the request across. - void send_onion_request( - onionreq::network_destination destination, - std::optional> body, - std::optional swarm_pubkey, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout = std::nullopt, - PathType type = PathType::standard); - - /// API: network/upload_file_to_server - /// - /// Uploads a file to a given server destination. - /// - /// Inputs: - /// - 'data' - [in] the data to be uploaded to a server. - /// - `server` -- [in] the server destination to upload the file to. - /// - `file_name` -- [in, optional] optional name to use for the file. - /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take - /// the path build into account so if the path build takes forever then this request will never - /// timeout. - /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request - /// and path build (if required). This value takes presedence over `request_timeout` if - /// provided, the request itself will be given a timeout of this value subtracting however long - /// it took to build the path. - /// - `handle_response` -- [in] callback to be called with the result of the request. - void upload_file_to_server( - std::vector data, - onionreq::ServerDestination server, - std::optional file_name, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout = std::nullopt); - - /// API: network/download_file - /// - /// Download a file from a given server destination. - /// - /// Inputs: - /// - `server` -- [in] the server destination to download the file from. - /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take - /// the path build into account so if the path build takes forever then this request will never - /// timeout. - /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request - /// and path build (if required). This value takes presedence over `request_timeout` if - /// provided, the request itself will be given a timeout of this value subtracting however long - /// it took to build the path. - /// - `handle_response` -- [in] callback to be called with the result of the request. - void download_file( - onionreq::ServerDestination server, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout = std::nullopt); - - /// API: network/download_file - /// - /// Convenience function to download a file from a given url and x25519 pubkey combination. - /// Calls through to the above `download_file` function after constructing a server destination - /// from the provided values. - /// - /// Inputs: - /// - `download_url` -- [in] the url to download the file from. - /// - `x25519_pubkey` -- [in] the server destination to download the file from. - /// - `timeout` -- [in] timeout in milliseconds to use for the request. - /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take - /// the path build into account so if the path build takes forever then this request will never - /// timeout. - /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request - /// and path build (if required). This value takes presedence over `request_timeout` if - /// provided, the request itself will be given a timeout of this value subtracting however long - /// it took to build the path. - /// - `handle_response` -- [in] callback to be called with the result of the request. - void download_file( - std::string_view download_url, - onionreq::x25519_pubkey x25519_pubkey, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout = std::nullopt); - - /// API: network/get_client_version - /// - /// Retrieves the version information for the given platform. - /// - /// Inputs: - /// - `platform` -- [in] the platform to retrieve the client version for. - /// - `seckey` -- [in] the users ed25519 secret key (to generated blinded auth). - /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take - /// the path build into account so if the path build takes forever then this request will never - /// timeout. - /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request - /// and path build (if required). This value takes presedence over `request_timeout` if - /// provided, the request itself will be given a timeout of this value subtracting however long - /// it took to build the path. - /// - `handle_response` -- [in] callback to be called with the result of the request. - void get_client_version( - Platform platform, - onionreq::ed25519_seckey seckey, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout = std::nullopt); - - private: - /// API: network/all_path_ips - /// - /// Internal function to retrieve all of the node ips current used in paths - std::vector all_path_ips() const { - std::vector result; - - for (const auto& [path_type, paths_for_type] : paths) - for (const auto& path : paths_for_type) - for (const auto& node : path.nodes) - result.emplace_back(node.to_ipv4()); - - return result; - }; - - /// API: network/update_disk_cache_throttled - /// - /// Function which can be used to notify the disk write thread that a write can be performed. - /// This function has a very basic throttling mechanism where it triggers the write a small - /// delay after it is called, any subsequent calls to the function within the same period will - /// be ignored. This is done to avoid excessive disk writes which probably aren't needed for - /// the cached network data. - virtual void update_disk_cache_throttled(bool force_immediate_write = false); - - /// API: network/disk_write_thread_loop - /// - /// Body of the disk writer which runs until signalled to stop. This is intended to run in its - /// own thread. The thread monitors a number of private variables and persists the snode pool - /// and swarm caches to disk if a `cache_path` was provided during initialization. - void disk_write_thread_loop(); - - /// API: network/load_cache_from_disk - /// - /// Loads the snode pool and swarm caches from disk if a `cache_path` was provided and cached - /// data exists. - void load_cache_from_disk(); - - /// API: network/_close_connections - /// - /// Triggered via the close_connections function but actually contains the logic to clear out - /// paths, requests and connections. This function is not thread safe so should should be - /// called with that in mind. - void _close_connections(); - - /// API: network/update_status - /// - /// Internal function to update the connection status and trigger the `status_changed` hook if - /// provided, this method ignores invalid or unchanged status changes. - /// - /// Inputs: - /// - 'updated_status' - [in] the updated connection status. - void update_status(ConnectionStatus updated_status); - - /// API: network/retry_delay - /// - /// A function which generates an exponential delay to wait before retrying a request/action - /// based on the provided failure count. - /// - /// Inputs: - /// - 'num_failures' - [in] the number of times the request has already failed. - /// - 'max_delay' - [in] the maximum amount of time to delay for. - virtual std::chrono::milliseconds retry_delay( - int num_failures, - std::chrono::milliseconds max_delay = std::chrono::milliseconds{5000}); - - /// API: network/get_endpoint - /// - /// Retrieves or creates a new endpoint pointer. - std::shared_ptr get_endpoint(); - - /// API: network/min_snode_cache_size - /// - /// When talking to testnet it's occassionally possible for the cache size to be smaller than - /// the `min_snode_cache_count` value (which would result in an endless loop re-fetching the - /// node cache) so instead this function will return the smaller of the two if we've done a - /// fetch from a seed node. - size_t min_snode_cache_size() const; - - /// API: network/get_unused_nodes - /// - /// Retrieves a list of all nodes in the cache which are currently unused (ie. not present in an - /// exising or pending path, connection or request). - /// - /// Outputs: - /// - The list of unused nodes. - std::vector get_unused_nodes(); - - /// API: network/establish_connection - /// - /// Establishes a connection to the target node and triggers the callback once the connection is - /// established (or closed in case it fails). - /// - /// Inputs: - /// - 'id' - [in] id for the request or path build which triggered the call. - /// - `target` -- [in] the target service node to connect to. - /// - `timeout` -- [in, optional] optional timeout for the request, if NULL the - /// `quic::DEFAULT_HANDSHAKE_TIMEOUT` will be used. - /// - `callback` -- [in] callback to be called with connection info once the connection is - /// established or fails. - void establish_connection( - std::string id, - service_node target, - std::optional timeout, - std::function error)> callback); - - /// API: network/establish_and_store_connection - /// - /// Establishes a connection to a random unused node and stores it in the `unused_connections` - /// list. - /// - /// Inputs: - /// - 'path_id' - [in] id for the path build which triggered the call. - virtual void establish_and_store_connection(std::string path_id); - - /// API: network/refresh_snode_cache_complete - /// - /// This function will be called from either `refresh_snode_cache` or - /// `refresh_snode_cache_from_seed_nodes` and will actually update the state and persist the - /// updated cache to disk. - /// - /// Inputs: - /// - 'nodes' - [in] the nodes to use as the updated cache. - void refresh_snode_cache_complete(std::vector nodes); - - /// API: network/refresh_snode_cache_from_seed_nodes - /// - /// This function refreshes the snode cache for a random seed node. Unlike the - /// `refresh_snode_cache` function this will update the cache with the response from a single - /// seed node since it's a trusted source. - /// - /// Inputs: - /// - 'request_id' - [in] id for an existing refresh_snode_cache request. - /// - 'reset_unused_nodes' - [in] flag to indicate whether this should reset the unused nodes - /// before kicking off the request. - virtual void refresh_snode_cache_from_seed_nodes( - std::string request_id, bool reset_unused_nodes); - - /// API: network/refresh_snode_cache - /// - /// This function refreshes the snode cache. If the current cache is to small (or not present) - /// this will trigger the above `refresh_snode_cache_from_seed_nodes` function, otherwise it - /// will randomly pick a number of nodes from the existing cache and refresh the cache from the - /// intersection of the results. - /// - /// Inputs: - /// - 'existing_request_id' - [in, optional] id for an existing refresh_snode_cache request. - virtual void refresh_snode_cache(std::optional existing_request_id = std::nullopt); - - /// API: network/build_path - /// - /// Build a new onion request path for the specified type. If there are no existing connections - /// this will open a new connection to a random service nodes in the snode cache. - /// - /// Inputs: - /// - 'path_id' - [in] id for the new path. - /// - `path_type` -- [in] the type of path to build. - virtual void build_path(std::string path_id, PathType path_type); - - /// API: network/find_valid_path - /// - /// Find a random path from the provided paths which is valid for the provided request. Note: - /// if the Network is setup in `single_path_mode` then the path returned may include the - /// destination for the request. - /// - /// Inputs: - /// - `info` -- [in] request to select a path for. - /// - `paths` -- [in] paths to select from. - /// - /// Outputs: - /// - The possible path, if found. - virtual std::optional find_valid_path( - const request_info info, const std::vector paths); - - /// API: network/build_path_if_needed - /// - /// Triggers a path build for the specified type if the total current or pending paths is below - /// the minimum threshold for the given type. Note: This may result in more paths than the - /// minimum threshold being built in order to avoid a situation where a request may never get - /// sent due to it's destination being present in the existing path(s) for the type. - /// - /// Inputs: - /// - `path_type` -- [in] the type of path to be built. - /// - `found_path` -- [in] flag indicating whether a valid path was found by calling - /// `find_valid_path` above. - virtual void build_path_if_needed(PathType path_type, bool found_valid_path); - - /// API: network/get_service_nodes - /// - /// Retrieves all or a random subset of service nodes from the given node. - /// - /// Inputs: - /// - 'request_id' - [in] id for the request which triggered the call. - /// - `conn_info` -- [in] the connection info to retrieve service nodes from. - /// - `limit` -- [in, optional] the number of service nodes to retrieve. - /// - `callback` -- [in] callback to be triggered once we receive nodes. NOTE: If an error - /// occurs an empty list and an error will be provided. - void get_service_nodes( - std::string request_id, - connection_info conn_info, - std::optional limit, - std::function nodes, std::optional error)> - callback); - - /// API: network/check_request_queue_timeouts - /// - /// Checks if any of the requests in the request queue have timed out (and fails them if so). - /// - /// Inputs: - /// - 'request_timeout_id' - [in] id for the timeout loop to prevent multiple loops from being - /// scheduled. - virtual void check_request_queue_timeouts( - std::optional request_timeout_id = std::nullopt); - - /// API: network/send_request - /// - /// Send a request via the network. - /// - /// Inputs: - /// - `info` -- [in] wrapper around all of the information required to send a request. - /// - `conn` -- [in] connection information used to send the request. - /// - `handle_response` -- [in] callback to be called with the result of the request. - void send_request( - request_info info, connection_info conn, network_response_callback_t handle_response); - - /// API: network/_send_onion_request - /// - /// Internal function invoked by ::send_onion_request after request_info construction - virtual void _send_onion_request( - request_info info, network_response_callback_t handle_response); - - /// API: network/process_v3_onion_response - /// - /// Processes a v3 onion request response. - /// - /// Inputs: - /// - `builder` -- [in] the builder that was used to build the onion request. - /// - `response` -- [in] the response data returned from the destination. - /// - /// Outputs: - /// - A tuple containing the status code, headers and body of the decrypted onion request - /// response. - virtual std::tuple< - int16_t, - std::vector>, - std::optional> - process_v3_onion_response(session::onionreq::Builder builder, std::string response); - - /// API: network/process_v4_onion_response - /// - /// Processes a v4 onion request response. - /// - /// Inputs: - /// - `builder` -- [in] the builder that was used to build the onion request. - /// - `response` -- [in] the response data returned from the destination. - /// - /// Outputs: - /// - A tuple containing the status code, headers and body of the decrypted onion request - /// response. - virtual std::tuple< - int16_t, - std::vector>, - std::optional> - process_v4_onion_response(session::onionreq::Builder builder, std::string response); - - /// API: network/validate_response - /// - /// Processes a quic response to extract the status code and body or throw if it errored or - /// received a non-successful status code. - /// - /// Inputs: - /// - `resp` -- [in] the quic response. - /// - `is_bencoded` -- [in] flag indicating whether the response will be bencoded or JSON. - /// - /// Returns: - /// - `std::pair` -- the status code and response body (for a bencoded - /// response this is just the direct response body from quic as it simplifies consuming the - /// response elsewhere). - std::pair validate_response(oxen::quic::message resp, bool is_bencoded); - - /// API: network/drop_path_when_empty - /// - /// Flags a path to be dropped once all pending requests have finished. - /// - /// Inputs: - /// - `id` -- [in] id the request or path which triggered the path drop (if the id is a path_id - /// then the drop was triggered by the connection being dropped). - /// - `path_type` -- [in] the type of path to build. - /// - `path` -- [in] the path to be dropped. - void drop_path_when_empty(std::string id, PathType path_type, onion_path path); - - /// API: network/clear_empty_pending_path_drops - /// - /// Iterates through all paths flagged to be dropped and actually drops any which are no longer - /// valid or have no more pending requests. - void clear_empty_pending_path_drops(); - - /// API: network/handle_errors - /// - /// Processes a non-success response to automatically perform any standard operations based on - /// the errors returned from the service node network (ie. updating the service node cache, - /// dropping nodes and/or onion request paths). - /// - /// Inputs: - /// - `info` -- [in] the information for the request that was made. - /// - `conn_info` -- [in] the connection info for the request that failed. - /// - `timeout` -- [in, optional] flag indicating whether the request timed out. - /// - `status_code` -- [in] the status code returned from the network. - /// - `headers` -- [in] the response headers returned from the network. - /// - `response` -- [in, optional] response data returned from the network. - /// - `handle_response` -- [in, optional] callback to be called with updated response - /// information after processing the error. - virtual void handle_errors( - request_info info, - connection_info conn_info, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response, - std::optional handle_response); -}; - -} // namespace session::network diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1e4ad566..3afdf431 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -98,20 +98,33 @@ target_link_libraries(config libzstd::static ) -if(ENABLE_ONIONREQ) - add_libsession_util_library(onionreq +if(ENABLE_NETWORKING) + add_libsession_util_library(network onionreq/builder.cpp onionreq/hop_encryption.cpp - onionreq/key_types.cpp onionreq/parser.cpp onionreq/response_parser.cpp - session_network.cpp + network/key_types.cpp + network/network_config.cpp + network/request_queue.cpp + network/service_node.cpp + network/session_network_internal.cpp + network/session_network_types.cpp + network/session_network.cpp + network/snode_pool.cpp + network/swarm.cpp + network/backends/session_file_server.cpp + network/transport/quic_transport.cpp + network/routing/direct_router.cpp + network/routing/lokinet_router.cpp + network/routing/onion_request_router.cpp ) - target_link_libraries(onionreq + target_link_libraries(network PUBLIC crypto quic + lokinet::liblokinet PRIVATE nlohmann_json::nlohmann_json libsodium::sodium-internal @@ -119,7 +132,7 @@ if(ENABLE_ONIONREQ) ) if (BUILD_STATIC_DEPS) - target_include_directories(onionreq PUBLIC ${CMAKE_BINARY_DIR}/static-deps/include) + target_include_directories(network PUBLIC ${CMAKE_BINARY_DIR}/static-deps/include) endif() endif() diff --git a/src/file.cpp b/src/file.cpp index 8f0ee656..3077800f 100644 --- a/src/file.cpp +++ b/src/file.cpp @@ -19,14 +19,19 @@ std::ifstream open_for_reading(const fs::path& filename) { return in; } -std::string read_whole_file(const fs::path& filename) { +std::vector read_whole_file(const fs::path& filename) { auto in = open_for_reading(filename); - std::string contents; in.seekg(0, std::ios::end); auto size = in.tellg(); in.seekg(0, std::ios::beg); - contents.resize(size); - in.read(contents.data(), size); + + if (size <= 0) + return {}; + + std::vector contents(static_cast(size)); + if (!in.read(reinterpret_cast(contents.data()), size)) + return {}; + return contents; } diff --git a/src/network/backends/session_file_server.cpp b/src/network/backends/session_file_server.cpp new file mode 100644 index 00000000..328a82ef --- /dev/null +++ b/src/network/backends/session_file_server.cpp @@ -0,0 +1,184 @@ +#include "session/network/backends/session_file_server.hpp" + +#include +#include + +#include +#include + +#include "../session_network_internal.hpp" +#include "session/blinding.hpp" +#include "session/network/backends/session_file_server.h" +#include "session/random.hpp" + +using namespace oxen; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace session::network::file_server { + +namespace { + + constexpr auto FILE_SERVER_HOST = "filev2.getsession.org"sv; + constexpr auto FILE_SERVER_PUBKEY_HEX = + "da21e1d886c6fbaea313f75298bd64aab03a97ce985b46bb2dad9f2089c8ee59"sv; + + constexpr auto ENDPOINT_FILE = "file"; +} // namespace + +Request upload( + std::vector data, + std::optional file_name, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout) { + return {"UL-{}"_format(random::random_base32(4)), + ServerDestination{ + "http", // protocol + std::string{FILE_SERVER_HOST}, // host + x25519_pubkey::from_hex(FILE_SERVER_PUBKEY_HEX), + 80, // port + std::nullopt, // headers (Network will add them) + "POST" // method + }, + ENDPOINT_FILE, + std::move(data), + RequestCategory::upload, + request_timeout, + overall_timeout, + UploadInfo{std::move(file_name)}}; +} + +Request download( + std::string file_id, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout) { + return {"DL-{}"_format(random::random_base32(4)), + ServerDestination{ + "http", // protocol + std::string{FILE_SERVER_HOST}, // host + x25519_pubkey::from_hex(FILE_SERVER_PUBKEY_HEX), + 80, // port + std::nullopt, // headers (Network will add them) + "GET" // method + }, + "{}/{}"_format(ENDPOINT_FILE, file_id), + std::nullopt, + RequestCategory::download, + request_timeout, + overall_timeout}; +} + +Request get_client_version( + Platform platform, + network::ed25519_seckey seckey, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout) { + std::string endpoint; + + switch (platform) { + case Platform::android: endpoint = "/session_version?platform=android"; break; + case Platform::desktop: endpoint = "/session_version?platform=desktop"; break; + case Platform::ios: endpoint = "/session_version?platform=ios"; break; + } + + // Generate the auth signature + auto blinded_keys = blind_version_key_pair(to_span(seckey.view())); + auto timestamp = std::chrono::duration_cast( + (std::chrono::system_clock::now()).time_since_epoch()) + .count(); + auto signature = blind_version_sign(to_span(seckey.view()), platform, timestamp); + auto pubkey = x25519_pubkey::from_hex(FILE_SERVER_PUBKEY_HEX); + std::string blinded_pk_hex; + blinded_pk_hex.reserve(66); + blinded_pk_hex += "07"; + oxenc::to_hex( + blinded_keys.first.begin(), + blinded_keys.first.end(), + std::back_inserter(blinded_pk_hex)); + + auto headers = std::vector>{}; + headers.emplace_back("X-FS-Pubkey", blinded_pk_hex); + headers.emplace_back("X-FS-Timestamp", "{}"_format(timestamp)); + headers.emplace_back("X-FS-Signature", oxenc::to_base64(signature.begin(), signature.end())); + + return {"GCV-{}"_format(random::random_base32(4)), + ServerDestination{ + "http", // protocol + std::string{FILE_SERVER_HOST}, // host + x25519_pubkey::from_hex(FILE_SERVER_PUBKEY_HEX), + 80, // port + headers, + "GET" // method + }, + std::move(endpoint), + std::nullopt, + RequestCategory::standard, + request_timeout, + overall_timeout}; +} + +} // namespace session::network::file_server + +extern "C" { + +using namespace session; +using namespace session::network; + +LIBSESSION_C_API session_request_params* session_file_server_upload( + const unsigned char* data, + size_t data_len, + const char* file_name, + int64_t request_timeout_ms, + int64_t overall_timeout_ms) { + try { + auto req = file_server::upload( + {data, data + data_len}, + (file_name ? std::optional{std::string{file_name}} : std::nullopt), + std::chrono::milliseconds{request_timeout_ms}, + (overall_timeout_ms > 0 + ? std::optional{std::chrono::milliseconds{overall_timeout_ms}} + : std::nullopt)); + + return session::network::detail::convert_cpp_request_to_c(req); + } catch (...) { + return nullptr; + } +} + +LIBSESSION_C_API session_request_params* session_file_server_download( + const char* file_id, int64_t request_timeout_ms, int64_t overall_timeout_ms) { + try { + auto req = file_server::download( + file_id, + std::chrono::milliseconds{request_timeout_ms}, + (overall_timeout_ms > 0 + ? std::optional{std::chrono::milliseconds{overall_timeout_ms}} + : std::nullopt)); + + return session::network::detail::convert_cpp_request_to_c(req); + } catch (...) { + return nullptr; + } +} + +LIBSESSION_C_API session_request_params* session_file_server_get_client_version( + CLIENT_PLATFORM platform, + const unsigned char* ed25519_secret, /* 64 bytes */ + int64_t request_timeout_ms, + int64_t overall_timeout_ms) { + try { + auto req = file_server::get_client_version( + static_cast(platform), + network::ed25519_seckey::from_bytes({ed25519_secret, 64}), + std::chrono::milliseconds{request_timeout_ms}, + (overall_timeout_ms > 0 + ? std::optional{std::chrono::milliseconds{overall_timeout_ms}} + : std::nullopt)); + + return session::network::detail::convert_cpp_request_to_c(req); + } catch (...) { + return nullptr; + } +} + +} // extern "C" \ No newline at end of file diff --git a/src/onionreq/key_types.cpp b/src/network/key_types.cpp similarity index 96% rename from src/onionreq/key_types.cpp rename to src/network/key_types.cpp index f99174a2..1e65a048 100644 --- a/src/onionreq/key_types.cpp +++ b/src/network/key_types.cpp @@ -1,4 +1,4 @@ -#include "session/onionreq/key_types.hpp" +#include "session/network/key_types.hpp" #include #include @@ -8,7 +8,7 @@ #include #include -namespace session::onionreq { +namespace session::network { namespace detail { @@ -90,4 +90,4 @@ x25519_pubkey compute_x25519_pubkey(std::span ed25519_pk) { return x25519_pubkey::from_bytes({xpk.data(), 32}); } -} // namespace session::onionreq +} // namespace session::network diff --git a/src/network/network_config.cpp b/src/network/network_config.cpp new file mode 100644 index 00000000..6b3160dd --- /dev/null +++ b/src/network/network_config.cpp @@ -0,0 +1,250 @@ +#include "session/network/network_config.hpp" + +#include +#include +#include + +using namespace oxen; +using namespace oxen::log::literals; + +namespace session::network::config { + +inline auto cat = oxen::log::Cat("network"); + +Config::Config(const std::vector& opts) { + for (const auto& opt_any : opts) { +#define HANDLE_TYPE(T) \ + if (const auto* p = std::any_cast(&opt_any)) { \ + handle_config_opt(*p); \ + continue; \ + } + + HANDLE_TYPE(opt::netid); + HANDLE_TYPE(opt::router); + HANDLE_TYPE(opt::transport); + HANDLE_TYPE(opt::path_length); + HANDLE_TYPE(opt::disable_subnet_diversity); + HANDLE_TYPE(opt::redirect_retry_count); + HANDLE_TYPE(opt::retry_delay); + HANDLE_TYPE(opt::request_timeout_check_frequency); + + // Snode pool options + HANDLE_TYPE(opt::cache_directory); + HANDLE_TYPE(opt::cache_expiration); + HANDLE_TYPE(opt::cache_min_lifetime); + HANDLE_TYPE(opt::cache_min_size); + HANDLE_TYPE(opt::cache_num_nodes_to_use_for_refresh); + HANDLE_TYPE(opt::cache_node_failure_threshold); + HANDLE_TYPE(opt::cache_refresh_using_legacy_endpoint); + + // Quic transport options + HANDLE_TYPE(opt::quic_handshake_timeout); + HANDLE_TYPE(opt::quic_keep_alive); + HANDLE_TYPE(opt::quic_disable_mtu_discovery); + + // Onion request router options + HANDLE_TYPE(opt::onionreq_path_failure_threshold); + HANDLE_TYPE(opt::onionreq_path_build_retry_limit); + HANDLE_TYPE(opt::onionreq_min_path_count); + HANDLE_TYPE(opt::onionreq_disable_pre_build_paths); + + log::warning(cat, "Ignoring unknown option type in Config constructor"); +#undef HANDLE_TYPE + } + + _init(); +} + +void Config::_init() { + log::debug(cat, "Network config created successfully"); +} + +void Config::handle_config_opt(opt::netid netid_) { + netid = netid_.target; + seed_nodes = std::move(netid_.seed_nodes); + + switch (netid_.target) { + case opt::netid::Target::mainnet: + log::debug( + cat, "Network config set to mainnet with {} seed node(s)", seed_nodes.size()); + break; + case opt::netid::Target::testnet: + log::debug( + cat, "Network config set to testnet with {} seed node(s)", seed_nodes.size()); + break; + + case opt::netid::Target::devnet: + log::debug(cat, "Network config set to devnet with {} seed node(s)", seed_nodes.size()); + break; + } +} + +void Config::handle_config_opt(opt::router router_) { + router = router_.type; + + switch (router_.type) { + case opt::router::Type::onion_requests: + log::debug(cat, "Network config set to route requests using Onion Requests"); + break; + + case opt::router::Type::lokinet: + log::debug(cat, "Network config set to route requests using Lokinet"); + break; + + case opt::router::Type::direct: + log::debug(cat, "Network config set to route requests directly"); + break; + } +} + +void Config::handle_config_opt(opt::transport transport_) { + transport = transport_.type; + + switch (transport_.type) { + case opt::transport::Type::quic: + log::debug(cat, "Network config set to transport requests via QUIC"); + break; + + case opt::transport::Type::callbacks: { + if (!transport_.callback) + throw std::invalid_argument{ + "Must provide callback when using the Callbacks to send requests"}; + + callbacks_callback = std::move(transport_.callback); + log::debug(cat, "Network config set to transport requests via Callbacks"); + } + } +} + +void Config::handle_config_opt(opt::path_length pl) { + path_length = pl.length; + log::debug(cat, "Network config path length set to {}", pl.length); +} + +void Config::handle_config_opt(opt::disable_subnet_diversity dsd) { + enforce_subnet_diversity = false; + log::debug(cat, "Network config disabled subnet diversity"); +} + +void Config::handle_config_opt(opt::redirect_retry_count rrc) { + redirect_retry_count = rrc.count; + log::debug(cat, "Network config redirect retry count set to {}", rrc.count); +} + +void Config::handle_config_opt(opt::retry_delay rd) { + retry_delay = std::move(rd); + log::debug( + cat, + "Network config retry delay set to min: {}ms, max: {}ms", + retry_delay.base_delay.count(), + retry_delay.max_delay.count()); +} + +void Config::handle_config_opt(opt::request_timeout_check_frequency rtcf) { + request_timeout_check_frequency = rtcf.frequency; + log::debug( + cat, + "Network config request timeout check frequency set to: {}ms", + rtcf.frequency.count()); +} + +// MARK: Snode Pool Options + +void Config::handle_config_opt(opt::cache_directory dir) { + cache_directory = std::move(dir.path); + + if (cache_directory) + log::debug(cat, "Network config using cache dir {}", cache_directory->string()); +} + +void Config::handle_config_opt(opt::cache_expiration ce) { + cache_expiration = ce.duration; + log::debug( + cat, + "Network config snode pool cache expiration set to {} minutes", + ce.duration.count()); +} + +void Config::handle_config_opt(opt::cache_min_lifetime mcl) { + cache_min_lifetime = mcl.duration; + log::debug( + cat, + "Network config snode pool minimum cache lifetime set to {}ms", + mcl.duration.count()); +} + +void Config::handle_config_opt(opt::cache_min_size mcs) { + cache_min_size = mcs.size; + log::debug(cat, "Network config min snode pool cache size set to {}", mcs.size); +} + +void Config::handle_config_opt(opt::cache_num_nodes_to_use_for_refresh nnr) { + cache_num_nodes_to_use_for_refresh = nnr.count; + log::debug( + cat, + "Network config number of cached nodes to be used for refreshing the snode pool cache " + "set to {}{}", + nnr.count, + (nnr.count > 0 ? "" : ", refreshes will always use a random seed node")); +} + +void Config::handle_config_opt(opt::cache_node_failure_threshold nft) { + cache_node_failure_threshold = nft.count; + log::debug(cat, "Network config snode pool node failure threshold set to {}", nft.count); +} + +void Config::handle_config_opt(opt::cache_refresh_using_legacy_endpoint rule) { + cache_refresh_using_legacy_endpoint = true; + log::debug(cat, "Network config will refresh snode cache using legacy endpoint"); +} + +// MARK: Quic Transport Options + +void Config::handle_config_opt(opt::quic_handshake_timeout qht) { + quic_handshake_timeout = qht.duration; + log::debug(cat, "Network config quic handshake timeout set to {}ms", qht.duration.count()); +} + +void Config::handle_config_opt(opt::quic_keep_alive qka) { + quic_keep_alive = qka.duration; + log::debug(cat, "Network config quic keep alive set to {}s", qka.duration.count()); +} + +void Config::handle_config_opt(opt::quic_disable_mtu_discovery qdmd) { + quic_disable_mtu_discovery = true; + log::debug(cat, "Network config disabled MTU discovery for Quic"); +} + +// MARK: Onion Request Router Options + +void Config::handle_config_opt(opt::onionreq_path_failure_threshold pft) { + onionreq_path_failure_threshold = pft.count; + log::debug(cat, "Network config onion request path failure threshold set to {}", pft.count); +} + +void Config::handle_config_opt(opt::onionreq_path_build_retry_limit pbrl) { + onionreq_path_build_retry_limit = pbrl.count; + log::debug(cat, "Network config onion request path build retry limit set to {}", pbrl.count); +} + +void Config::handle_config_opt(opt::onionreq_min_path_count mpc) { + onionreq_min_path_counts.emplace(mpc.category, mpc.min_count); + + log::debug( + cat, + "Network config min {} onion request path count set to {}", + to_string(mpc.category), + mpc.min_count); +} + +void Config::handle_config_opt(opt::onionreq_single_path_mode spm) { + onionreq_single_path_mode = true; + log::debug(cat, "Network config onion requests set to single path mode"); +} + +void Config::handle_config_opt(opt::onionreq_disable_pre_build_paths dpbp) { + onionreq_disable_pre_build_paths = true; + log::debug(cat, "Network config disabled pre-building onion request paths"); +} + +} // namespace session::network::config diff --git a/src/network/request_queue.cpp b/src/network/request_queue.cpp new file mode 100644 index 00000000..d0302dbd --- /dev/null +++ b/src/network/request_queue.cpp @@ -0,0 +1,112 @@ +#include "session/network/request_queue.hpp" + +#include + +#include +#include +#include + +using namespace oxen; +using namespace oxen::log::literals; + +namespace session::network::detail { + +RequestQueue::~RequestQueue() { + _loop->call_get([this] { + for (auto& [category, callback] : _queue) { + try { + callback( + false, + false, + -1, + {content_type_plain_text}, + "Request cancelled: networking system is shutting down"); + } catch (...) { /* Ignore exceptions during shutdown */ + } + } + }); +} + +void RequestQueue::add(Request request, network_response_callback_t callback) { + _loop->call([self = shared_from_this(), req = std::move(request), cb = std::move(callback)]() { + auto has_timeout = req.overall_timeout.has_value(); + self->_queue.emplace_back(std::move(req), std::move(cb)); + + if (has_timeout && !self->_checker_active) { + self->_checker_active = true; + + auto weak_self = std::weak_ptr(self); + self->_loop->call_later(self->_check_frequency, [weak_self] { + if (auto self = weak_self.lock()) + self->check_timeouts(); + }); + } + }); +} + +void RequestQueue::add_front(std::pair req_pair) { + _loop->call([self = shared_from_this(), pair = std::move(req_pair)] { + auto has_timeout = pair.first.overall_timeout.has_value(); + self->_queue.emplace_front(std::move(pair)); + + if (has_timeout && !self->_checker_active) { + self->_checker_active = true; + + auto weak_self = std::weak_ptr(self); + self->_loop->call_later(self->_check_frequency, [weak_self] { + if (auto self = weak_self.lock()) + self->check_timeouts(); + }); + } + }); +} + +std::deque> RequestQueue::pop_all() { + return _loop->call_get([self = shared_from_this()] { + std::deque> popped_items; + std::swap(self->_queue, popped_items); + + return popped_items; + }); +} + +void RequestQueue::check_timeouts() { + auto time_now = std::chrono::system_clock::now(); + bool has_remaining_timeout_requests = false; + + std::erase_if(_queue, [&has_remaining_timeout_requests, &time_now](const auto& request) { + // If the request doesn't have an overall timeout then ignore it + if (!request.first.overall_timeout) + return false; + + auto duration = std::chrono::duration_cast( + time_now - request.first.creation_time); + + if (duration > *request.first.overall_timeout) { + request.second( + false, + true, + ERROR_BUILD_TIMEOUT, + {content_type_plain_text}, + "Timed out while in build queue."); + return true; + } + + has_remaining_timeout_requests = true; + return false; + }); + + // If there are no more timeout requests then stop looping here + if (!has_remaining_timeout_requests) { + _checker_active = false; + return; + } + + // Otherwise schedule the next check + _loop->call_later(_check_frequency, [weak_self = weak_from_this()] { + if (auto self = weak_self.lock()) + self->check_timeouts(); + }); +} + +} // namespace session::network::detail diff --git a/src/network/routing/direct_router.cpp b/src/network/routing/direct_router.cpp new file mode 100644 index 00000000..c42ed834 --- /dev/null +++ b/src/network/routing/direct_router.cpp @@ -0,0 +1,136 @@ +#include "session/network/routing/direct_router.hpp" + +#include +#include +#include + +#include +#include + +#include "session/network/network_opt.hpp" + +using namespace oxen; +using namespace session; +using namespace session::network; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace session::network { + +namespace { + auto cat = oxen::log::Cat("network"); +} // namespace + +DirectRouter::DirectRouter( + std::shared_ptr loop, std::weak_ptr transport) : + _loop{loop}, _transport{transport} { + log::trace(cat, "[DirectRouter] Initializing."); + _update_status(ConnectionStatus::connected); +} + +DirectRouter::~DirectRouter() { + // Use 'call_get' to force this to be synchronous + if (_loop) + _loop->call_get([this] { _update_status(ConnectionStatus::disconnected); }); + log::debug(cat, "[DirectRouter] Destroyed."); +} + +// MARK: IRouter + +void DirectRouter::suspend() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + _suspended = true; + log::info(cat, "[DirectRouter] Suspended."); + }); +} + +void DirectRouter::resume(bool automatically_reconnect) { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + if (!_suspended) + return; + + _suspended = false; + log::info(cat, "[DirectRouter] Resumed."); + }); +} + +void DirectRouter::send_request(Request request, network_response_callback_t callback) { + _loop->call([weak_self = weak_from_this(), req = std::move(request), cb = std::move(callback)] { + if (auto self = weak_self.lock()) + self->_send_request_internal(std::move(req), std::move(cb)); + }); +} + +// MARK: Internal Logic + +void DirectRouter::_update_status(ConnectionStatus new_status) { + ConnectionStatus old_status = _status.load(); + if (old_status == new_status) + return; + + _status.store(new_status); + + if (on_status_changed) + on_status_changed(); +} + +void DirectRouter::_send_request_internal(Request request, network_response_callback_t callback) { + // If we are suspended then fail immediately + if (_suspended) + return callback( + false, + false, + ERROR_NETWORK_SUSPENDED, + {content_type_plain_text}, + "DirectRouter is suspended."); + + auto transport = _transport.lock(); + if (!transport) { + log::critical(cat, "[DirectRouter] Transport was destroyed, cannot send request."); + return; + } + + transport->send_request( + std::move(request), + [weak_self = weak_from_this(), cb = std::move(callback)]( + bool success, bool timeout, int16_t status_code, auto headers, auto response) { + if (auto self = weak_self.lock()) + self->_handle_transport_response( + success, + timeout, + status_code, + std::move(headers), + std::move(response), + std::move(cb)); + }); +} + +void DirectRouter::_handle_transport_response( + bool success, + bool timeout, + int16_t status_code_, + std::vector> headers, + std::optional response_body, + network_response_callback_t callback) { + // If we weren't given a body then just return the data directly + if (!response_body) + return callback(success, timeout, status_code_, headers, response_body); + + // Otherwise the response will be a json array of [{status_code}, {body}] + try { + nlohmann::json response_json = nlohmann::json::parse(*response_body); + + if (!response_json.is_array() || response_json.size() != 2) + throw std::runtime_error{"Unexpected JSON response structure."}; + + uint16_t status_code = response_json[0].get(); + std::string data = response_json[1].dump(); + return callback(success, timeout, status_code, headers, data); + } catch (const std::exception& e) { + return callback(false, timeout, status_code_, {content_type_plain_text}, e.what()); + } +} + +} // namespace session::network diff --git a/src/network/routing/lokinet_router.cpp b/src/network/routing/lokinet_router.cpp new file mode 100644 index 00000000..48289776 --- /dev/null +++ b/src/network/routing/lokinet_router.cpp @@ -0,0 +1,648 @@ +#include "session/network/routing/lokinet_router.hpp" + +#include +#include +#include + +#include +#include +#include +#include + +#include "session/network/network_opt.hpp" +#include "session/onionreq/builder.hpp" +#include "session/onionreq/response_parser.hpp" + +using namespace oxen; +using namespace session; +using namespace session::network; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace session::network { + +namespace { + auto cat = oxen::log::Cat("network"); + + static constexpr std::string_view PROXIED_REQUESTS_KEY{"proxied_requests"}; + + std::string pending_request_key(const network_destination& dest) { + std::optional key; + + std::visit( + [&key](auto&& arg) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + key = oxenc::to_hex(arg.view_remote_key()); + } else if constexpr (std::is_same_v) { + key = oxenc::to_hex(arg.view_remote_key()); + } else if constexpr (std::is_same_v) { + key = PROXIED_REQUESTS_KEY; + } + }, + dest); + + if (!key) + throw std::runtime_error{"Invalid destination"}; + + return *key; + } + + oxen::quic::RemoteAddress address_for_destination( + const network_destination& dest, const std::string& request_id) { + std::optional address; + + std::visit( + [&address, &request_id](auto&& arg) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + log::trace( + cat, + "[LokinetRouter Request {}]: Using pre-resolved RemoteAddress.", + request_id); + address = arg; + } else if constexpr (std::is_same_v) { + log::trace( + cat, + "[LokinetRouter Request {}]: Resolving service_node to " + "RemoteAddress.", + request_id); + address.emplace(arg.view_remote_key(), arg.host(), arg.omq_port); + } + }, + dest); + + if (!address) + throw std::runtime_error{"Invalid destination"}; + + if (address->view_remote_key().size() != 32) + throw std::runtime_error{"Invalid remote key"}; + + return *address; + } + +} // namespace + +LokinetRouter::LokinetRouter( + config::LokinetRouterConfig config, + std::shared_ptr loop, + std::weak_ptr snode_pool, + std::weak_ptr transport) : + _config{std::move(config)}, _loop{loop}, _snode_pool{snode_pool}, _transport{transport} { + log::trace(cat, "[LokinetRouter] Initializing."); + + auto test_ini = R"( + [router] + netid={} + data-dir={} + [bind] + listen=:0 + [logging] + type=none + level=*=debug,quic=info + )"_format(opt::netid::to_string(_config.netid), _config.cache_directory); + + try { + _update_status(ConnectionStatus::connecting); + + // TODO: Don't pass the loop for now. + lokinet = std::make_shared(test_ini /*, loop*/); + + // TODO: Remove this hack to wait for lokinet to be ready before any requests get sent + _loop->call_later(5000ms, [this] { + auto snode_pool = _snode_pool.lock(); + if (!snode_pool) { + log::critical(cat, "[LokinetRouter] SnodePool was destroyed, cannot setup router."); + return; + } + + if (snode_pool->size() == 0) + snode_pool->refresh_if_needed({}, [weak_self = weak_from_this()] { + if (auto self = weak_self.lock()) + self->_loop->call([weak_self] { + if (auto self = weak_self.lock()) + self->_finish_setup(); + }); + }); + else + _finish_setup(); + }); + } catch (const std::exception& e) { + log::error(cat, "[LokinetRouter] Failed to start lokinet ({}).", e.what()); + _update_status(ConnectionStatus::disconnected); + throw; + } +} + +LokinetRouter::~LokinetRouter() { + // Use 'call_get' to force this to be synchronous + if (_loop) + _loop->call_get([this] { _update_status(ConnectionStatus::disconnected); }); + log::debug(cat, "[LokinetRouter] Destroyed."); +} + +// MARK: IRouter + +void LokinetRouter::suspend() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + _suspended = true; + _close_connections(); + log::info(cat, "[LokinetRouter] Suspended."); + }); +} + +void LokinetRouter::resume(bool automatically_reconnect) { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + if (!_suspended) + return; + + _suspended = false; + log::info(cat, "[LokinetRouter] Resumed."); + }); +} + +void LokinetRouter::close_connections() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { _close_connections(); }); +} + +void LokinetRouter::clear_cache() { + // TODO: Implement this +} + +std::vector LokinetRouter::get_active_paths() { + // TODO: Implement this + return {}; +} + +void LokinetRouter::send_request(Request request, network_response_callback_t callback) { + _loop->call([weak_self = weak_from_this(), req = std::move(request), cb = std::move(callback)] { + if (auto self = weak_self.lock()) + self->_send_request_internal(std::move(req), std::move(cb)); + }); +} + +// MARK: Internal Logic + +void LokinetRouter::_finish_setup() { + // Start processing requests + _ready = true; + log::debug(cat, "[LokinetRouter] Finishing setup, router is now ready."); + + auto requests_to_process = std::move(_pending_requests); + if (requests_to_process.empty()) + return; + + // Process any requests that were queued before we were ready + log::debug( + cat, + "[LokinetRouter] Processing {} requests queued during initialization.", + requests_to_process.size()); + + for (auto& [address, requests] : requests_to_process) { + if (!requests.empty()) { + log::debug( + cat, + "[LokinetRouter] Processing {} queued requests for address {}.", + requests.size(), + address); + + for (auto&& [req, cb] : std::move(requests)) + _send_request_internal(std::move(req), std::move(cb)); + } + } +} + +void LokinetRouter::_close_connections() { + // TODO: Need to close any active connections on the lokinet instance + + // Cancel any pending requests (they can't succeed once the connection is closed) + for (const auto& [pubkey, pupkey_requests] : _pending_requests) + for (const auto& [info, callback] : pupkey_requests) + callback( + false, + false, + ERROR_NETWORK_SUSPENDED, + {content_type_plain_text}, + "Network is suspended."); + + // Clear all storage of requests, paths and connections so that we are in a fresh state on + // relaunch + _active_tunnels.clear(); + _pending_requests.clear(); + _update_status(ConnectionStatus::disconnected); + log::info(cat, "[LokinetRouter] Closed all connections."); +} + +void LokinetRouter::_update_status(ConnectionStatus new_status) { + ConnectionStatus old_status = _status.load(); + if (old_status == new_status) + return; + + _status.store(new_status); + + if (on_status_changed) + on_status_changed(); +} + +void LokinetRouter::_send_request_internal(Request request, network_response_callback_t callback) { + // If we are suspended then fail immediately + if (_suspended) + return callback( + false, + false, + ERROR_NETWORK_SUSPENDED, + {content_type_plain_text}, + "LokinetRouter is suspended."); + + // Queue the request if we aren't ready + auto key = pending_request_key(request.destination); + + if (!_ready) { + log::debug( + cat, + "[LokinetRouter Request {}]: Router not ready, queueing request.", + request.request_id); + + // Queue the request if not ready. We need the pubkey hex as the key. + try { + _pending_requests[key].emplace_back(std::move(request), std::move(callback)); + } catch (const std::exception& e) { + log::critical( + cat, + "[LokinetRouter Request {}]: Dropping after failure to queue due to error: {}.", + request.request_id, + e.what()); + return callback(false, false, -1, {content_type_plain_text}, e.what()); + } + return; + } + + // If the request is being sent to a `ServerDestination` then we need to make a proxied request + // instead + if (std::holds_alternative(request.destination)) { + log::debug( + cat, + "[LokinetRouter Request {}]: Destination is a server, finding a proxy node.", + request.request_id); + _send_proxy_request(std::move(request), std::move(callback)); + return; + } + + // When sending a direct request the response will be a json array of [{status_code}, {body}] so + // we need to process that before triggering the callback + auto json_parsing_callback = + [cb = std::move(callback)]( + bool success, bool timeout, int16_t status_code_, auto headers, auto response) { + if (!response) + return cb(success, timeout, status_code_, headers, response); + + try { + nlohmann::json response_json = nlohmann::json::parse(*response); + + if (!response_json.is_array() || response_json.size() != 2) + throw std::runtime_error{"Unexpected JSON response structure."}; + + uint16_t status_code = response_json[0].get(); + std::string data = response_json[1].dump(); + return cb(success, timeout, status_code, headers, data); + } catch (const std::exception& e) { + return cb(false, timeout, status_code_, {content_type_plain_text}, e.what()); + } + }; + + _send_direct_request(std::move(request), std::move(json_parsing_callback)); +} + +void LokinetRouter::_send_direct_request(Request request, network_response_callback_t callback) { + try { + if (std::holds_alternative(request.destination)) + throw std::runtime_error{"Attempted to send server request directly"}; + + auto address = address_for_destination(request.destination, request.request_id); + const auto address_pubkey_hex = oxenc::to_hex(address.view_remote_key()); + + if (auto it = _active_tunnels.find(address_pubkey_hex); it != _active_tunnels.end()) { + log::trace(cat, "[LokinetRouter Request {}] Found active tunnel.", request.request_id); + _send_via_tunnel(it->second, std::move(request), std::move(callback)); + return; + } + + // Add the request to the pending queue to be picked up once we have a tunnel for it + std::string initiating_req_id = request.request_id; + _pending_requests[address_pubkey_hex].emplace_back(std::move(request), std::move(callback)); + + // If there is only a single pending request then we wouldn't have started establishing a + // tunnel + if (_pending_requests.at(address_pubkey_hex).size() == 1) { + log::info( + cat, + "[LokinetRouter Request {}] No tunnel to {}, initiating new tunnel.", + initiating_req_id, + address_pubkey_hex); + _establish_tunnel(address, initiating_req_id); + } else + log::debug( + cat, + "[LokinetRouter Request {}] Tunnel to {} is pending, queueing request.", + initiating_req_id, + address_pubkey_hex); + } catch (const std::exception& e) { + log::error( + cat, + "[LokinetRouter Request {}] Failed to send request due to error: {}", + request.request_id, + e.what()); + return callback( + false, + false, + -1, + {content_type_plain_text}, + "Failed to send request due to error: {}"_format(e.what())); + } +} + +void LokinetRouter::_send_proxy_request(Request request, network_response_callback_t callback) { + auto snode_pool = _snode_pool.lock(); + if (!snode_pool) { + return callback( + false, + false, + -1, + {content_type_plain_text}, + "SnodePool was destroyed, cannot find proxy."); + } + + auto proxy_nodes = snode_pool->get_unused_nodes(1); + + if (proxy_nodes.empty()) { + log::warning( + cat, + "[LokinetRouter Request {}]: No available proxy nodes, waiting for SnodePool " + "refresh.", + request.request_id); + + snode_pool->refresh_if_needed( + {}, + [weak_self = weak_from_this(), + req = std::move(request), + cb = std::move(callback)]() { + auto self = weak_self.lock(); + if (!self) + return; + + auto snode_pool = self->_snode_pool.lock(); + if (!snode_pool) + return cb( + false, + false, + -1, + {content_type_plain_text}, + "SnodePool was destroyed, cannot find proxy."); + + if (snode_pool->get_unused_nodes(1).empty()) + return cb( + false, + false, + -1, + {content_type_plain_text}, + "SnodePool refresh failed."); + + log::info( + cat, + "[LokinetRouter Request {}]: SnodePool refresh complete, retrying " + "proxy selection.", + req.request_id); + self->_send_proxy_request(std::move(req), std::move(cb)); + }); + return; + } + + service_node proxy_node = proxy_nodes[0]; + std::vector encrypted_blob; + std::shared_ptr parser; + log::debug( + cat, + "[LokinetRouter Request {}]: Selected {} as proxy.", + request.request_id, + proxy_node.to_string()); + + try { + std::vector proxy_path = {proxy_node}; + auto builder = onionreq::Builder(request.destination, request.endpoint, proxy_path); + encrypted_blob = builder.generate_onion_blob(request.body); + parser = std::make_shared(builder); + } catch (const std::exception& e) { + log::warning( + cat, + "[LokinetRouter Request {}]: Failed to build proxy request payload: {}", + request.request_id, + e.what()); + return callback( + false, false, -1, {content_type_plain_text}, "Failed to build proxy request"); + } + + Request proxy_request{ + request.request_id, + network_destination{proxy_node}, // Send to the proxy node + std::string{"onion_req"}, // Send to onion request handling endpoint + std::move(encrypted_blob), // Encrypted payload + request.category, + request.time_remaining(), + request.overall_timeout}; + + auto proxy_callback = + [parser = std::move(parser), cb = std::move(callback)]( + bool success, bool timeout, int16_t status, auto headers, auto response) { + try { + if (!success) + throw std::runtime_error{response.value_or("Unknown request failure")}; + if (timeout) + throw std::runtime_error{response.value_or("Timed out")}; + if (!response) + throw std::runtime_error{"Unexpected empty response"}; + + onionreq::DecryptedResponse decrypted = parser->decrypted_response(*response); + cb(true, + false, + decrypted.status_code, + std::move(decrypted.headers), + std::move(decrypted.body)); + } catch (const std::exception& e) { + cb(false, + timeout, + status, + std::move(headers), + "Failed to handle proxied request response due to error: {}"_format( + e.what())); + } + }; + + // Now that we have a service_node destination we can send a direct request + _send_direct_request(std::move(proxy_request), std::move(proxy_callback)); +} + +void LokinetRouter::_establish_tunnel( + const oxen::quic::RemoteAddress& address, const std::string& initiating_req_id) { + auto key = address.view_remote_key(); + auto address_pubkey_hex = oxenc::to_hex(key); + + if (address_pubkey_hex.size() != 64) { + log::critical( + cat, + "[LokinetRouter] Destination had an invalid remote key, request {} is being " + "dropped.", + initiating_req_id); + // Fail all the pending requests for this connection + if (auto it = _pending_requests.find(address_pubkey_hex); it != _pending_requests.end()) { + auto to_fail = std::move(it->second); + _pending_requests.erase(it); + log::error( + cat, + "[LokinetRouter] Failing {} pending request(s) due to connection failure.", + to_fail.size()); + + for (auto& [req, cb] : to_fail) + cb(false, + false, + -1, + {content_type_plain_text}, + "Failed to establish tunnel to remote."); + } + return; + } + + llarp::RouterID router_id{key.first<32>()}; + // auto snode_address = "34d9udo9ethfcrcaxcgdyxsi1w8gr79jzornsytcfgdw5rpmif8y.loki";// + // address.to_network_address(true); + // auto snode_address = "55fxd8stjrt9g6rsbftx7eesy47pj4751xjghinr3k9ffxh4ieyo.snode"; + auto lokinet_address = router_id.to_network_address(true); + auto test_port = address.port(); // 35519; + + log::debug( + cat, + "[LokinetRouter Request {}] Establishing new tunnel to {}.", + initiating_req_id, + address_pubkey_hex); + lokinet->establish_udp( + lokinet_address.to_string(), + test_port, + [weak_self = weak_from_this(), address_pubkey_hex, initiating_req_id]( + lokinet::tunnel_info info) mutable { + auto self = weak_self.lock(); + if (!self) + return; + + log::info( + cat, + "[LokinetRouter Request {}] Tunnel to remote {} established.", + initiating_req_id, + address_pubkey_hex); + + auto requests_to_process = std::move(self->_pending_requests[address_pubkey_hex]); + self->_pending_requests.erase(address_pubkey_hex); + self->_active_tunnels.insert_or_assign(address_pubkey_hex, info); + + // We had a successful connection so update the status to connected + self->_update_status(ConnectionStatus::connected); + + if (!requests_to_process.empty()) { + log::debug( + cat, + "[LokinetRouter] Processing {} pending requests on new tunnel to " + "{}.", + requests_to_process.size(), + info.remote); + + for (auto&& [req, cb] : std::move(requests_to_process)) + self->_send_via_tunnel(info, std::move(req), std::move(cb)); + } + }, + [weak_self = weak_from_this(), address_pubkey_hex, initiating_req_id]( + std::string errmsg) mutable { + auto self = weak_self.lock(); + if (!self) + return; + + log::info( + cat, + "[LokinetRouter Request {}] Unable to establish lokinet UDP connection " + "to " + "{} due to error: {}.", + initiating_req_id, + address_pubkey_hex, + errmsg); + + self->_active_tunnels.erase(address_pubkey_hex); + + // Fail all the pending requests for this connection + if (auto it = self->_pending_requests.find(address_pubkey_hex); + it != self->_pending_requests.end()) { + auto to_fail = std::move(it->second); + self->_pending_requests.erase(it); + + log::error( + cat, + "[LokinetRouter] Failing {} pending requests due to UDP connection " + "failure.", + to_fail.size()); + + for (auto& [req, cb] : to_fail) + cb(false, false, -1, {content_type_plain_text}, errmsg); + } + + // If we have no longer have any active connections then we are disconnected + if (self->_active_tunnels.empty()) + self->_update_status(ConnectionStatus::disconnected); + }); +} + +void LokinetRouter::_send_via_tunnel( + lokinet::tunnel_info tunnel, Request request, network_response_callback_t callback) { + // TODO: Is there a way to check that the 'tunnel_info' still active? + + // If the request has already timedout at this point then just fail it immediately + auto timeout = request.time_remaining(); + if (timeout <= std::chrono::milliseconds::zero()) + return callback(false, true, 408, {content_type_plain_text}, "Request already timed out"); + + auto transport = _transport.lock(); + if (!transport) { + log::critical(cat, "[LokinetRouter] Transport was destroyed, cannot send request."); + return; + } + + // We have a valid connection and stream so we can send the request + log::debug(cat, "[LokinetRouter Request {}] Sending to {}.", request.request_id, tunnel.remote); + + oxen::quic::RemoteAddress address = + address_for_destination(request.destination, request.request_id); + auto key = address.view_remote_key(); + const auto address_pubkey_hex = oxenc::to_hex(key); + auto test_key = key; + // auto test_key = + // oxenc::from_base64("1n+DAM9hKyJhtXSPR5L/HdemIKPiHs8dZsPn2kEQuMs="); auto test_key + // = oxenc::from_base32z("55fxd8stjrt9g6rsbftx7eesy47pj4751xjghinr3k9ffxh4ieyo"); + auto loki_target = oxen::quic::RemoteAddress{test_key, "127.0.0.1", tunnel.local_port}; + + // Construct the actual request to send + std::optional remaining_overall_timeout = + (request.overall_timeout.has_value() ? std::optional{request.time_remaining()} + : std::nullopt); + Request lokinet_request{ + request.request_id, + network_destination{loki_target}, // Send to local lokinet address + request.endpoint, // Send to onion request handling endpoint + request.body, + request.category, + request.time_remaining(), + remaining_overall_timeout}; + + transport->send_request(std::move(lokinet_request), std::move(callback)); +} + +} // namespace session::network diff --git a/src/network/routing/onion_request_router.cpp b/src/network/routing/onion_request_router.cpp new file mode 100644 index 00000000..a6f23d55 --- /dev/null +++ b/src/network/routing/onion_request_router.cpp @@ -0,0 +1,1218 @@ +#include "session/network/routing/onion_request_router.hpp" + +#include + +#include +#include + +#include "session/network/network_opt.hpp" +#include "session/onionreq/builder.hpp" +#include "session/onionreq/response_parser.hpp" +#include "session/random.hpp" + +using namespace oxen; +using namespace session; +using namespace session::network; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace session::network { + +namespace { + auto cat = oxen::log::Cat("network"); + + constexpr auto node_not_found_prefix = "502 Bad Gateway\n\nNext node not found: "sv; + constexpr auto node_not_found_prefix_no_status = "Next node not found: "sv; + + enum class PathSelectionBehaviour { + random, + new_or_least_busy, + }; + + inline std::string to_string(RequestCategory category, bool single_path_mode) { + if (single_path_mode) + return "single_path"; + + return to_string(category); + } + + PathSelectionBehaviour get_path_selection_behaviour(RequestCategory category) { + switch (category) { + case RequestCategory::standard: return PathSelectionBehaviour::random; + case RequestCategory::upload: return PathSelectionBehaviour::new_or_least_busy; + case RequestCategory::download: return PathSelectionBehaviour::new_or_least_busy; + } + return PathSelectionBehaviour::random; + } + + std::vector extract_nodes( + const std::unordered_map>& paths, + const std::unordered_map>& pending_paths) { + std::vector all_used_nodes; + + for (const auto& [pt, path_list] : paths) + for (const auto& p : path_list) + all_used_nodes.insert(all_used_nodes.end(), p.nodes.begin(), p.nodes.end()); + + for (const auto& [pid, nodes] : pending_paths) + all_used_nodes.insert(all_used_nodes.end(), nodes.begin(), nodes.end()); + + return all_used_nodes; + } +} // namespace + +std::string OnionPath::to_string() const { + std::vector node_descriptions; + std::transform( + nodes.begin(), + nodes.end(), + std::back_inserter(node_descriptions), + [](const service_node& node) { return node.to_string(); }); + + return "{}"_format(fmt::join(node_descriptions, ", ")); +} + +OnionRequestRouter::OnionRequestRouter( + config::OnionRequestRouterConfig config, + std::shared_ptr loop, + std::weak_ptr snode_pool, + std::weak_ptr transport) : + _config{std::move(config)}, _loop{loop}, _snode_pool{snode_pool}, _transport{transport} { + log::trace(cat, "[OnionRequestRouter] Initializing."); + + _request_queues[RequestCategory::standard] = + std::make_shared(loop, _config.request_timeout_check_frequency); + _request_queues[RequestCategory::upload] = + std::make_shared(loop, _config.request_timeout_check_frequency); + _request_queues[RequestCategory::download] = + std::make_shared(loop, _config.request_timeout_check_frequency); + + _loop->call_soon([this] { + auto snode_pool = _snode_pool.lock(); + if (!snode_pool) { + log::critical( + cat, "[OnionRequestRouter] SnodePool was destroyed, cannot setup router."); + return; + } + + if (snode_pool->size() == 0) + snode_pool->refresh_if_needed({}, [weak_self = weak_from_this()] { + if (auto self = weak_self.lock()) + self->_loop->call([weak_self] { + if (auto self = weak_self.lock()) + self->_finish_setup(); + }); + }); + else + _finish_setup(); + }); +} + +OnionRequestRouter::~OnionRequestRouter() { + // Use 'call_get' to force this to be synchronous + if (_loop) + _loop->call_get([this] { _close_connections(); }); + log::debug(cat, "[OnionRequestRouter] Destroyed."); +} + +// MARK: IRouter + +void OnionRequestRouter::suspend() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + _suspended = true; + _close_connections(); + log::info(cat, "[OnionRequestRouter] Suspended."); + }); +} + +void OnionRequestRouter::resume(bool automatically_reconnect) { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this, automatically_reconnect] { + if (!_suspended) + return; + + _suspended = false; + + if (automatically_reconnect) + _pre_build_paths_if_needed(); + + log::info(cat, "[OnionRequestRouter] Resumed."); + }); +} + +void OnionRequestRouter::close_connections() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { _close_connections(); }); +} + +std::vector OnionRequestRouter::get_active_paths() { + return _loop->call_get([this] { + std::vector result; + result.reserve(_paths.size()); + + for (const auto& [category, path_list] : _paths) + for (const auto& p : path_list) + result.push_back({p.nodes, OnionPathMetadata{category}}); + + return result; + }); +} + +std::vector OnionRequestRouter::get_all_used_nodes() { + return _loop->call_get([this] { return extract_nodes(_paths, _pending_paths); }); +} + +void OnionRequestRouter::send_request(Request request, network_response_callback_t callback) { + _loop->call([weak_self = weak_from_this(), req = std::move(request), cb = std::move(callback)] { + if (auto self = weak_self.lock()) + self->_send_request_internal(std::move(req), std::move(cb)); + }); +} + +// MARK: Internal Logic + +void OnionRequestRouter::_finish_setup() { + // Start processing requests + _ready = true; + log::debug(cat, "[OnionRequestRouter] Finishing setup, router is now ready."); + + // Pre-build paths if needed + _pre_build_paths_if_needed(); + + // Process any requests that were queued before we were ready + for (auto& [category, queue] : _request_queues) { + if (!queue->is_empty()) { + auto pending = queue->pop_all(); + log::debug( + cat, + "[OnionRequestRouter] Processing {} requests queued during initialization for " + "category '{}'.", + pending.size(), + to_string(category)); + + for (auto& [req, cb] : pending) + _send_request_internal(std::move(req), std::move(cb)); + } + } +} + +void OnionRequestRouter::_pre_build_paths_if_needed() { + if (!_config.disable_pre_build_paths) { + log::info(cat, "[OnionRequestRouter] Pre-building initial paths."); + + auto schedule_build = [this](RequestCategory category, int count) { + for (int i = 0; i < count; ++i) + _build_path( + category, + "pre-build-{}-{}"_format( + to_string(category, _config.single_path_mode), i + 1), + {}); + }; + + if (_config.single_path_mode) { + log::debug(cat, "[OnionRequestRouter] Pre-building 1 path for single_path_mode."); + schedule_build(RequestCategory::standard, 1); + } else { + for (const auto& [category, min_count] : _config.min_path_counts) { + if (min_count > 0) { + log::debug( + cat, + "[OnionRequestRouter] Pre-building {} path(s) for category '{}'.", + min_count, + to_string(category, _config.single_path_mode)); + schedule_build(category, min_count); + } + } + } + } else + log::debug(cat, "[OnionRequestRouter] Path pre-building is disabled."); +} + +void OnionRequestRouter::_close_connections() { + // Cancel any pending requests (they can't succeed once the connection is closed) + for (auto& [path_type, path_type_queue] : _request_queues) { + auto to_fail = path_type_queue->pop_all(); + + for (const auto& [req, callback] : to_fail) + callback( + false, + false, + ERROR_NETWORK_SUSPENDED, + {content_type_plain_text}, + "Network is suspended."); + } + + // Remove any failure listeners for the guard nodes of the current paths + if (auto transport = _transport.lock()) + for (const auto& [category, path_list] : _paths) + for (const auto& p : path_list) + if (!p.nodes.empty()) + transport->remove_failure_listeners( + ed25519_pubkey::from_bytes(p.nodes[0].view_remote_key())); + + // Clear all storage of requests, paths and connections so that we are in a fresh state on + // relaunch + // + // The connection status is recalculated based on these values so we need to call them + // before recalculation so it correctly detects the "disconnected" state + _paths.clear(); + _paths_pending_drop.clear(); + _in_progress_path_builds.clear(); + _path_build_retries.clear(); + _pending_paths.clear(); + _update_status(); + log::info(cat, "[OnionRequestRouter] Closed all connections."); +} + +void OnionRequestRouter::_update_status() { + ConnectionStatus new_status = ConnectionStatus::disconnected; + + // If we have at least one active "standard" path we are considered connected + auto paths_it = _paths.find(RequestCategory::standard); + if (paths_it != _paths.end() && !paths_it->second.empty()) + new_status = ConnectionStatus::connected; + // If we have at least one active non-standard path then considered connecting (not properly + // connected, but some requests may work) + else if (std::any_of( + _paths.begin(), _paths.end(), [](const auto& p) { return !p.second.empty(); })) + new_status = ConnectionStatus::connecting; + // Otherwise if we are building one then we are connecting + else if (std::any_of( + _in_progress_path_builds.begin(), + _in_progress_path_builds.end(), + [](const auto& p) { return p.second > 0; })) + new_status = ConnectionStatus::connecting; + + if (_status.load() != new_status) { + _status.store(new_status); + + if (on_status_changed) + on_status_changed(); + } +} + +void OnionRequestRouter::_send_request_internal( + Request request, network_response_callback_t callback) { + // If we are suspended then fail immediately + if (_suspended) + return callback( + false, + false, + ERROR_NETWORK_SUSPENDED, + {content_type_plain_text}, + "OnionRequestRouter is suspended."); + + auto initiating_req_category = + (_config.single_path_mode ? RequestCategory::standard : request.category); + + if (!_ready) { + log::debug( + cat, + "[OnionRequestRouter Request {}]: Router not ready, queueing request.", + request.request_id); + + try { + _request_queues.at(initiating_req_category) + ->add(std::move(request), std::move(callback)); + } catch (const std::exception& e) { + log::critical( + cat, + "[OnionRequestRouter] No request queue for category '{}', request {} is being " + "dropped.", + to_string(initiating_req_category, _config.single_path_mode), + request.request_id); + return callback( + false, false, -1, {content_type_plain_text}, "Unhandled request category"); + } + return; + } + + // Try to use an existing path if we have one + log::trace( + cat, + "[OnionRouter Request {}]: Received request for category '{}', searching for a path.", + request.request_id, + to_string(initiating_req_category, _config.single_path_mode)); + OnionPath* path = _find_valid_path(request); + + if (path) { + log::debug( + cat, + "[OnionRouter Request {}]: Found valid path {}, sending.", + request.request_id, + path->id); + _send_on_path(*path, std::move(request), std::move(callback)); + return; + } + + // No valid path, queue the request an build a path + log::debug( + cat, + "[OnionRouter Request {}]: No path available, queueing request.", + request.request_id); + + // Add the request to the queue for its category + auto initiating_req_id = request.request_id; + + try { + _request_queues.at(initiating_req_category)->add(std::move(request), std::move(callback)); + } catch (const std::exception& e) { + log::critical( + cat, + "[OnionRequestRouter] No request queue for category '{}', request {} is being " + "dropped.", + to_string(initiating_req_category, _config.single_path_mode), + request.request_id); + return callback(false, false, -1, {content_type_plain_text}, "Unhandled request category"); + } + + // Check if we need to build additional paths + const auto current = + _paths.count(initiating_req_category) ? _paths.at(initiating_req_category).size() : 0; + const auto in_progress = _in_progress_path_builds[initiating_req_category]; + bool should_build = false; + + // In single path mode, we only build if we have zero paths (current or in-progress) + if (_config.single_path_mode) + should_build = (current + in_progress == 0); + else { + // In multi-path mode, we build if we are below the min number + const auto needed = _config.min_path_counts.at(initiating_req_category); + should_build = (current + in_progress < needed); + } + + if (should_build) { + log::info( + cat, + "[OnionRouter Request {}]: Path count for '{}' is insufficient, building new path.", + initiating_req_id, + to_string(initiating_req_category, _config.single_path_mode)); + + _build_path(initiating_req_category, initiating_req_id, {}); + } +} + +void OnionRequestRouter::_build_path( + RequestCategory category, + std::optional initiating_req_id, + const std::vector& nodes_to_exclude_, + std::optional original_path_id) { + if (_suspended) { + log::info(cat, "Ignoring build_path call as network is suspended."); + return; + } + + const std::string req_id_log = (initiating_req_id ? *initiating_req_id : "internal"); + const std::string path_id = original_path_id.value_or("P-" + random::random_base32(4)); + log::info( + cat, + "[OnionRouter Request {} Path {}]: Starting build for {} path.", + req_id_log, + path_id, + to_string(category, _config.single_path_mode)); + + // If we were misconfigured to have a `path_length` of `0` then just fail all requests + if (_config.path_length == 0) { + log::error( + cat, + "[OnionRouter Request {} Path {}]: Cannot build path, path_size is configured to " + "0.", + req_id_log, + path_id); + + auto queue_it = _request_queues.find(category); + if (queue_it == _request_queues.end()) { + log::critical( + cat, + "[OnionRequestRouter] No request queue for category '{}'.", + to_string(category, _config.single_path_mode)); + return; + } + + if (!queue_it->second->is_empty()) { + auto to_fail = queue_it->second->pop_all(); + + for (const auto& [req, cb] : to_fail) + cb(false, + false, + -1, + {content_type_plain_text}, + "Router misconfigured: path_length is 0."); + } + return; + } + + _in_progress_path_builds[category]++; + _update_status(); + + auto nodes_to_exclude = extract_nodes(_paths, _pending_paths); + nodes_to_exclude.insert( + nodes_to_exclude.end(), nodes_to_exclude_.begin(), nodes_to_exclude_.end()); + + std::vector path_nodes; + + auto snode_pool = _snode_pool.lock(); + if (!snode_pool) { + log::critical(cat, "[OnionRequestRouter] SnodePool was destroyed, cannot build path."); + return; + } + + path_nodes = snode_pool->get_unused_nodes(_config.path_length, nodes_to_exclude); + + // If we don't have enough nodes to build a path then we should try to refresh the snode pool + if (path_nodes.size() < _config.path_length) { + log::warning( + cat, + "[OnionRouter Request {} Path {}]: Failed to get enough nodes from SnodePool (need " + "{}, got {}), queueing retry after pool refresh.", + req_id_log, + path_id, + _config.path_length, + path_nodes.size()); + _in_progress_path_builds[category]--; + + snode_pool->refresh_if_needed( + nodes_to_exclude, + [weak_self = weak_from_this(), category, initiating_req_id, nodes_to_exclude]() { + auto self = weak_self.lock(); + if (!self) + return; + + log::info( + cat, + "[OnionRouter Request {}]: SnodePool refresh complete, " + "retrying " + "path build.", + initiating_req_id.value_or("internal")); + self->_build_path(category, initiating_req_id, nodes_to_exclude); + }); + return; + } + + // Attempty to verify connectivity to the guard node + _pending_paths[path_id] = path_nodes; + auto guard_node = path_nodes.front(); + log::debug( + cat, + "[OnionRouter Request {} Path {}]: Testing connectivity to guard node {}.", + req_id_log, + path_id, + guard_node.to_string()); + + auto transport = _transport.lock(); + if (!transport) { + log::critical(cat, "[OnionRequestRouter] Transport was destroyed, cannot build path."); + return; + } + + transport->verify_connectivity( + guard_node, + 3s, + "{} - Path Build {}"_format(req_id_log, path_id), + [weak_self = weak_from_this(), path_id, category, initiating_req_id](bool success) { + if (auto self = weak_self.lock()) + self->_on_guard_connectivity_response( + path_id, category, initiating_req_id, success); + }); +} + +void OnionRequestRouter::_on_guard_connectivity_response( + const std::string& path_id, + RequestCategory category, + std::optional initiating_req_id, + bool success) { + const std::string req_id_log = initiating_req_id.value_or("internal"); + + auto pending_it = _pending_paths.find(path_id); + if (pending_it == _pending_paths.end()) { + log::warning( + cat, + "[OnionRouter Request {} Path {}]: Received connection callback for a path that is " + "no longer pending, ignoring.", + req_id_log, + path_id); + return; + } + + // Extract the pending path nodes and remove it from the pending list + auto path_nodes = std::move(pending_it->second); + _pending_paths.erase(pending_it); + + const auto& guard_node = path_nodes.front(); + + if (_in_progress_path_builds[category] > 0) + _in_progress_path_builds[category]--; + + if (!success) { + // The guard node failed so record the failure and try to build a new path to replace this + // failed one (excluding the failed guard node from the next attempt) + log::warning( + cat, + "[OnionRouter Request {} Path {}]: Failed to verify connectivity to guard node {}, " + "retrying path build.", + req_id_log, + path_id, + guard_node.to_string()); + if (auto snode_pool = _snode_pool.lock()) + snode_pool->record_node_failure(guard_node); + + int& retries = _path_build_retries[path_id]; + retries++; + + // If we tried, and failed, to build the path too many times then give up and fail all + // pending requests + if (retries > _config.path_build_retry_limit) { + log::critical( + cat, + "[OnionRouter Path {}]: Aborting build after {} failed attempts.", + path_id, + retries); + _path_build_retries.erase(path_id); + _update_status(); + + auto queue_it = _request_queues.find(category); + if (queue_it == _request_queues.end()) { + log::critical( + cat, + "[OnionRequestRouter] No request queue for category '{}'.", + to_string(category, _config.single_path_mode)); + return; + } + + if (!queue_it->second->is_empty()) { + auto to_fail = queue_it->second->pop_all(); + log::error( + cat, + "[OnionRequestRouter] Failing {} queued requests for '{}' paths due to " + "persistent path build failures.", + to_fail.size(), + to_string(category, _config.single_path_mode)); + + for (const auto& [req, cb] : to_fail) + cb(false, + false, + -1, + {content_type_plain_text}, + "Failed to build a required onion path after multiple retries."); + } + return; + } + + auto delay = _config.retry_delay.exponential(retries); + log::info( + cat, + "[OnionRouter Path {}]: Retrying path build in {}ms (attempt {}/{})", + path_id, + delay.count(), + retries, + _config.path_build_retry_limit); + _update_status(); + + _loop->call_later( + delay, + [weak_self = weak_from_this(), path_id, category, initiating_req_id, guard_node] { + if (auto self = weak_self.lock()) + self->_build_path(category, initiating_req_id, {guard_node}, path_id); + }); + return; + } + + OnionPath new_path{path_id, std::move(path_nodes)}; + log::info( + cat, + "[OnionRouter Request {} Path {}]: New {} path is active with nodes: [{}].", + req_id_log, + path_id, + to_string(category, _config.single_path_mode), + new_path.to_string()); + _paths[category].push_back(std::move(new_path)); + _path_build_retries.erase(path_id); + _update_status(); + + // Now, check the queue for any requests that were waiting for this path. + auto queue_it = _request_queues.find(category); + if (queue_it == _request_queues.end()) { + log::critical( + cat, + "[OnionRequestRouter] No request queue for category '{}'.", + to_string(category, _config.single_path_mode)); + return; + } + + auto pending_requests = queue_it->second->pop_all(); + + if (!pending_requests.empty()) { + std::deque> requeue; + log::debug( + cat, + "[OnionRouter Request {} Path {}]: Processing {} queued requests.", + req_id_log, + path_id, + pending_requests.size()); + + for (auto&& [req, cb] : std::move(pending_requests)) { + // Retrieve any path that is valid for the request + OnionPath* path_to_use = _find_valid_path(req); + + if (path_to_use) + _send_on_path(*path_to_use, std::move(req), std::move(cb)); + else + requeue.emplace_back(std::move(req), std::move(cb)); + } + + // Put any un-sendable requests back into the front of the queue (or fail in + // `single_path_mode`) + if (!requeue.empty()) { + if (_config.single_path_mode) { + log::warning( + cat, + "[OnionRouter Path {}]: {} requests could not be sent on the single " + "available path, failing them.", + path_id, + requeue.size()); + for (const auto& [req, cb] : requeue) + cb(false, + false, + -1, + {content_type_plain_text}, + "Request destination conflicts with the only available path in " + "single_path_mode"); + + return; + } + + log::debug( + cat, + "[OnionRouter Path {}]: Unable to process {} queued requests, requing them.", + path_id, + requeue.size()); + + while (!requeue.empty()) { + auto& req_pair = requeue.back(); + queue_it->second->add_front(std::move(req_pair)); + requeue.pop_back(); + } + + if (_in_progress_path_builds[category] == 0) { + log::info( + cat, + "[OnionRequestRouter] Building additional {} path for remaining requests.", + to_string(category, _config.single_path_mode)); + _build_path(category, "requeue-build", {}); + } + } + } + + // Now that we've established a path we need to start observing it in case the connection is + // lost + auto transport = _transport.lock(); + if (!transport) + return; + + transport->add_failure_listener( + ed25519_pubkey::from_bytes(guard_node.view_remote_key()), + [weak_self = weak_from_this(), pid = path_id, category] { + auto self = weak_self.lock(); + if (!self) + return; + + log::warning( + cat, + "[OnionRequestRouter Path {}]: Transport reported connection " + "failure, " + "retiring path.", + pid); + + // Set the failure_count of the path to the max value and report the error + // to trigger a rebuild + auto& active_paths = self->_paths[category]; + auto path_it = std::find_if( + active_paths.begin(), active_paths.end(), [&pid](const auto& p) { + return p.id == pid; + }); + + if (path_it != active_paths.end()) + path_it->failure_count = self->_config.path_failure_threshold; + + self->_handle_path_failure(pid, category, "Guard connection lost"); + }); +} + +OnionPath* OnionRequestRouter::_find_valid_path(const Request& request) { + auto it = _paths.find(request.category); + if (it == _paths.end() || it->second.empty()) + return nullptr; + + std::vector& candidate_paths = it->second; + std::vector suitable_paths; + suitable_paths.reserve(candidate_paths.size()); + + auto target_node = std::get_if(&request.destination); + + for (OnionPath& path : candidate_paths) { + // Ignore failed paths (these should have been removed from the list but better to be safe) + if (path.failure_count >= _config.path_failure_threshold) + continue; + + // Filter by destination conflict + if (target_node) { + bool conflict = false; + + for (const auto& path_node : path.nodes) { + if (path_node == *target_node) { + conflict = true; + break; + } + } + + if (conflict && _config.single_path_mode) + log::warning( + cat, + "[OnionRouter Request {}]: Path destination conflicts with the only " + "available path, but single_path_mode is enabled, proceeding.", + request.request_id); + else if (conflict) + continue; + } + + suitable_paths.push_back(&path); + } + + if (suitable_paths.empty()) + return nullptr; + + PathSelectionBehaviour behaviour = get_path_selection_behaviour(request.category); + + switch (behaviour) { + case PathSelectionBehaviour::new_or_least_busy: { + // Sort by the number of pending requests, ascending + std::sort( + suitable_paths.begin(), + suitable_paths.end(), + [](const OnionPath* a, const OnionPath* b) { + return a->pending_requests < b->pending_requests; + }); + + OnionPath* best_path = suitable_paths.front(); + const auto min_paths_for_type = _config.min_path_counts[request.category]; + + // Return the path with the fewest pending requests if we had one with no requets, or + // already have the minimum number of paths for this type + if (best_path->pending_requests == 0 || candidate_paths.size() >= min_paths_for_type) + return best_path; + + // Otherwise we want to build a new path (for this PathSelectionBehaviour the assuption + // is that it'd be faster to build a new path and send the request along that rather + // than use an existing path) + return nullptr; + } + + case PathSelectionBehaviour::random: + default: + // Shuffle the suitable paths to pick a random one. + std::shuffle(suitable_paths.begin(), suitable_paths.end(), csrng); + return suitable_paths.front(); + } +} + +void OnionRequestRouter::_send_on_path( + OnionPath& path, Request request, network_response_callback_t callback) { + log::trace(cat, "[OnionRouter Request {}]: Sending on path {}", request.request_id, path.id); + + std::vector encrypted_blob; + std::shared_ptr parser; + + try { + auto builder = + session::onionreq::Builder(request.destination, request.endpoint, path.nodes); + encrypted_blob = builder.generate_onion_blob(request.body); + parser = std::make_shared(builder); + } catch (const std::exception& e) { + log::warning( + cat, + "[OnionRouter Request {}]: Failed to prepare onion payload: {}", + request.request_id, + e.what()); + return callback( + false, + false, + -1, + {content_type_plain_text}, + "Failed to construct onion request payload"); + } + + // Construct the actual request to send + std::optional remaining_overall_timeout = + (request.overall_timeout.has_value() ? std::optional{request.time_remaining()} + : std::nullopt); + Request onion_request{ + request.request_id, + network_destination{path.nodes.front()}, // Send to guard node + std::string{"onion_req"}, // Send to onion request handling endpoint + std::move(encrypted_blob), // Encrypted payload + request.category, + request.time_remaining(), + remaining_overall_timeout}; + + // Increment the `pending_requests` and actually send the `onion_request` + path.pending_requests++; + + auto transport = _transport.lock(); + if (!transport) { + log::critical(cat, "[OnionRequestRouter] Transport was destroyed, cannot send request."); + return; + } + + auto decryption_callback = [weak_self = weak_from_this(), + parser = std::move(parser), + path_id = path.id, + original_request = std::move(request), + cb = std::move(callback)]( + bool success, + bool timeout, + int16_t status, + auto headers, + auto response) { + auto self = weak_self.lock(); + if (!self) + return; + + try { + if (!success) + throw std::runtime_error{response.value_or("Unknown request failure")}; + if (timeout) + throw std::runtime_error{response.value_or("Timed out")}; + if (!response) + throw std::runtime_error{"Unexpected empty response"}; + + onionreq::DecryptedResponse decrypted = parser->decrypted_response(*response); + self->_handle_transport_response( + path_id, + std::move(original_request), + true, + false, + decrypted.status_code, + std::move(decrypted.headers), + std::move(decrypted.body), + std::move(cb)); + } catch (const std::exception& e) { + self->_handle_transport_response( + path_id, + std::move(original_request), + false, + timeout, + status, + std::move(headers), + std::move("Failed to handle onion response due to error: {}"_format(e.what())), + std::move(cb)); + } + }; + + transport->send_request(std::move(onion_request), std::move(decryption_callback)); +} + +void OnionRequestRouter::_handle_transport_response( + std::string path_id, + Request original_request, + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional decrypted_body, + network_response_callback_t callback) { + auto final_success = success; + auto final_timeout = timeout; + auto final_status_code = status_code; + std::vector> final_headers = headers; + bool should_penalize_path = false; + bool is_server_dest = std::holds_alternative(original_request.destination); + + if (decrypted_body) + if (auto uniform_error = Response::find_uniform_batch_error(*decrypted_body)) + final_status_code = *uniform_error; + + if (final_success) + final_success = (final_status_code >= 200 && final_status_code <= 299); + + if (!final_success) { + switch (final_status_code) { + // These errors that are NEVER the path's fault + case 400: // Bad Request + case 403: // Forbidden + case 404: // Not Found + case 406: // Not Acceptable (clock skew) + case 425: // Too Early (also clock skew) + // These are application-level or client-side errors. Do nothing to + // the path. + log::trace( + cat, + "[OnionRouter Request {}]: Received benign error {}, path is considered " + "healthy.", + original_request.request_id, + final_status_code); + break; + + // These errors are only the path's fault if the destination is not a + // server + case 500: // Internal Server Error + if (!is_server_dest) + should_penalize_path = true; + break; + + case 504: // Gateway Timeout + final_timeout = true; + + if (!is_server_dest) + should_penalize_path = true; + break; + + // A status of -1 generally indicates either a timeout or some internal error + case -1: break; + + // Any other non-success code is treated as a potential path issue. + default: should_penalize_path = true; break; + } + } + + // If we got a timeout and the destination wasn't a server then we need to + // assume it was from a path node + if (!is_server_dest && timeout) + should_penalize_path = true; + + // Handle the failure if needed + if (should_penalize_path) { + log::debug( + cat, + "[OnionRouter Request {}]: Received error {} on path {}, handling " + "failure.", + original_request.request_id, + final_status_code, + path_id); + _handle_path_failure(path_id, original_request.category, decrypted_body); + } + + // Clean up paths if needed + _decrement_and_cleanup_path(path_id, original_request.category); + + // Now we can trigger the callback with the result + return callback( + final_success, + final_timeout, + final_status_code, + std::move(headers), + std::move(decrypted_body)); +} + +void OnionRequestRouter::_decrement_and_cleanup_path( + const std::string& path_id, RequestCategory category) { + // Check active paths first + auto& active_paths = _paths[category]; + + if (auto it = std::find_if( + active_paths.begin(), + active_paths.end(), + [&path_id](const auto& p) { return p.id == path_id; }); + it != active_paths.end()) { + if (it->pending_requests > 0) + it->pending_requests--; + + // The path is still active so we don't need to do anything else + return; + } + + // If we didn't find an active path then check paths pending drop + auto& dying_paths = _paths_pending_drop[category]; + if (auto it = std::find_if( + dying_paths.begin(), + dying_paths.end(), + [&path_id](const auto& p) { return p.id == path_id; }); + it != dying_paths.end()) { + if (it->pending_requests > 0) + it->pending_requests--; + + // If this was the last request, we can now safely delete the path + if (it->pending_requests == 0) { + log::debug( + cat, + "[OnionRequestRouter] Retiring path {} as it has no more pending requests.", + path_id); + dying_paths.erase(it); + } + + return; + } + + // This can happen if the path was already retired and removed, it's not an error + log::trace( + cat, + "[OnionRequestRouter] Request completed on path {}, which has already been removed.", + path_id); +} + +void OnionRequestRouter::_handle_path_failure( + const std::string& path_id, + const RequestCategory& request_category, + const std::optional& error_body) { + auto& active_paths = _paths[request_category]; + auto path_it = + std::find_if(active_paths.begin(), active_paths.end(), [&path_id](const auto& p) { + return p.id == path_id; + }); + + // If the path is no longer in the active list then no need to do anything + if (path_it == active_paths.end()) { + log::trace( + cat, + "[OnionRouter Path {}]: Failure on path, but path is no longer active.", + path_id); + return; + } + + // Increment the `failure_count` on the path + OnionPath& path = *path_it; + path.failure_count++; + + // If the path is still potentially valid then check if the response has one of the + // 'node_not_found' prefixes + if (path.failure_count < _config.path_failure_threshold) { + std::optional ed25519PublicKey; + + if (error_body) { + if (error_body->starts_with(node_not_found_prefix)) + ed25519PublicKey = {error_body->data() + node_not_found_prefix.size()}; + else if (error_body->starts_with(node_not_found_prefix_no_status)) + ed25519PublicKey = {error_body->data() + node_not_found_prefix_no_status.size()}; + } + + // If we found a result then try to extract the pubkey and replace that node in the path. We + // do still want to increment the `failure_count` on the path in this case to prevent a + // rogue relay from using this error as a mechanism to take full control of the path + if (ed25519PublicKey && ed25519PublicKey->size() == 64 && + oxenc::is_hex(*ed25519PublicKey)) { + try { + session::network::ed25519_pubkey bad_node_pk = + session::network::ed25519_pubkey::from_hex(*ed25519PublicKey); + auto edpk_view = to_span(bad_node_pk.view()); + + auto bad_node_it = std::find_if( + path.nodes.begin(), path.nodes.end(), [&edpk_view](const auto& node) { + return to_string_view(node.view_remote_key()) == + to_string_view(edpk_view); + }); + + if (bad_node_it != path.nodes.end()) { + log::debug( + cat, + "[OnionRouter Path {}]: Failure identified for specific node {}.", + path.id, + bad_node_pk.view()); + std::vector replacements; + + auto snode_pool = _snode_pool.lock(); + if (!snode_pool) { + log::critical( + cat, + "[OnionRequestRouter] Cannot repair path as SnodePool was " + "destroyed, dropping instead."); + path.failure_count = _config.path_failure_threshold; + return; + } + + // Flag the bad node as permanently failed until the next cache refresh + snode_pool->record_node_failure(*bad_node_it, true); + + auto used_nodes = extract_nodes(_paths, _pending_paths); + replacements = snode_pool->get_unused_nodes(1, used_nodes); + + // If we found a replacement node then swap out the bad one and reset the + // path failure count (assume the bad node was the cause of any failures), + // we can then stop here (the path is repaired so no need to continue) + if (!replacements.empty()) { + log::info( + cat, + "[OnionRouter Path {}]: Repairing path by replacing node {} " + "with {}.", + path.id, + bad_node_it->to_string(), + replacements[0].to_string()); + *bad_node_it = replacements[0]; + } else { + log::warning( + cat, + "[OnionRouter Path {}]: Cannot repair path due to lack of " + "replacement node, dropping instead.", + path.id); + path.failure_count = _config.path_failure_threshold; + } + } + } catch (...) { /* Invalid pubkey, fall through to general failure */ + } + } + } + + log::debug( + cat, + "[OnionRouter Path {}]: Recorded failure, total failures: {}/{}", + path.id, + path.failure_count, + _config.path_failure_threshold); + + // If the path has exceeded its failure threshold, retire it. + if (path.failure_count >= _config.path_failure_threshold) { + log::warning( + cat, "[OnionRouter Path {}]: Path has exceeded its failure threshold.", path.id); + + // Tell the SnodePool that all nodes on this path are now suspect + if (auto snode_pool = _snode_pool.lock()) + for (const auto& node : path.nodes) + snode_pool->record_node_failure(node); + + // Remove failure listeners for the path + if (auto transport = _transport.lock()) + if (!path.nodes.empty()) + transport->remove_failure_listeners( + ed25519_pubkey::from_bytes(path.nodes[0].view_remote_key())); + + // Store for subsequent path building + const auto old_path_id = path.id; + auto nodes_to_exclude = path.nodes; + + if (path.pending_requests == 0) { + log::debug(cat, "[OnionRouter Path {}]: Retiring idle path immediately.", old_path_id); + active_paths.erase(path_it); + _update_status(); + } else { + log::debug( + cat, + "[OnionRouter Path {}]: Retiring active path ({} pending requests), moving to " + "pending drop.", + old_path_id, + path.pending_requests); + _paths_pending_drop[request_category].push_back(std::move(path)); + active_paths.erase(path_it); + _update_status(); + } + + // Automatically rebuild if needed + RequestCategory category_to_rebuild = + (_config.single_path_mode ? RequestCategory::standard : request_category); + const auto min_paths = + (_config.single_path_mode ? 1 : _config.min_path_counts.at(category_to_rebuild)); + const auto current_active = + (_paths.count(category_to_rebuild) ? _paths.at(category_to_rebuild).size() : 0); + const auto in_progress = _in_progress_path_builds[category_to_rebuild]; + + if (current_active + in_progress < min_paths) { + log::info( + cat, + "[OnionRequestRouter] Path count for {} is below the minimum {}, building " + "replacement.", + to_string(request_category, _config.single_path_mode), + min_paths); + _build_path(request_category, "failure-replacement-" + old_path_id, nodes_to_exclude); + } + } +} + +} // namespace session::network diff --git a/src/network/service_node.cpp b/src/network/service_node.cpp new file mode 100644 index 00000000..d0f472ea --- /dev/null +++ b/src/network/service_node.cpp @@ -0,0 +1,328 @@ +#include "session/network/service_node.hpp" + +#include + +#include +#include +#include + +using namespace oxen; +using namespace oxen::log::literals; + +namespace session::network { + +session::network::x25519_pubkey service_node::swarm_pubkey() const { + return session::network::compute_x25519_pubkey(view_remote_key()); +} + +std::string service_node::to_string() const { + return oxenc::to_hex(_remote_pubkey); +} + +std::string service_node::to_https_string() const { + return "{}:{}"_format(host(), https_port); +} + +std::string service_node::to_omq_string() const { + return "{}:{}"_format(host(), omq_port); +} + +service_node service_node::from(const network_service_node& node) { + std::vector pubkey; + pubkey.reserve(32); + oxenc::from_hex( + node.ed25519_pubkey_hex, node.ed25519_pubkey_hex + 64, std::back_inserter(pubkey)); + + return {std::move(pubkey), + oxen::quic::ipv4{std::span(node.ip, 4)}, + node.https_port, + node.omq_port, + {node.version[0], node.version[1], node.version[2]}, + node.swarm_id}; +} + +void service_node::into(network_service_node& n) const { + auto ed25519_pubkey_hex = oxenc::to_hex(view_remote_key()); + strncpy(n.ed25519_pubkey_hex, ed25519_pubkey_hex.c_str(), 64); + n.ed25519_pubkey_hex[64] = '\0'; // Ensure null termination + n.ip[0] = (ip.addr >> 24) & 0xFF; + n.ip[1] = (ip.addr >> 16) & 0xFF; + n.ip[2] = (ip.addr >> 8) & 0xFF; + n.ip[3] = ip.addr & 0xFF; + n.https_port = https_port; + n.omq_port = omq_port; + std::memcpy(n.version, storage_server_version.data(), sizeof(storage_server_version)); + n.swarm_id = swarm_id; +} + +service_node service_node::legacy_from_json(nlohmann::json json) { + auto pk_ed = json["pubkey_ed25519"].get(); + if (pk_ed.size() != 64 || !oxenc::is_hex(pk_ed)) + throw std::invalid_argument{ + "Invalid service node json: pubkey_ed25519 is not a valid, hex pubkey"}; + + std::vector pubkey; + pubkey.reserve(32); + oxenc::from_hex(pk_ed.begin(), pk_ed.end(), std::back_inserter(pubkey)); + + // When parsing a node from JSON it'll generally be from the 'get_swarm` endpoint or a 421 + // error neither of which contain the `storage_server_version` - luckily we don't need the + // version for these two cases so can just default it to `0.0.0` + std::array storage_server_version = {0, 0, 0}; + if (json.contains("storage_server_version")) { + if (json["storage_server_version"].is_array()) { + if (json["storage_server_version"].size() > 0) { + // Convert the version to a string and parse it back into a version code to + // ensure the version formats remain consistent throughout + auto json_version = json["storage_server_version"].get>(); + + for (size_t i = 0; i < 3; ++i) + storage_server_version[i] = + (i < json_version.size() ? static_cast(json_version[i]) : 0); + } + } else { + auto json_version = json["storage_server_version"].get(); + auto split_version = session::split(json_version, "."); + + for (size_t i = 0; i < 3 && i < split_version.size(); ++i) { + int value; + + if (!quic::parse_int(split_version[i], value)) + throw std::invalid_argument{"Invalid version"}; + + storage_server_version[i] = static_cast(value); + } + } + } + + std::string ip; + if (json.contains("public_ip")) + ip = json["public_ip"].get(); + else + ip = json["ip"].get(); + + if (ip == "0.0.0.0") + throw std::runtime_error{"Invalid IP address"}; + + uint16_t https_port; + if (json.contains("storage_https_port")) + https_port = json["storage_https_port"].get(); + else if (json.contains("storage_port")) + https_port = json["storage_port"].get(); + else + https_port = json["port_https"].get(); + + uint16_t omq_port; + if (json.contains("storage_lmq_port")) + omq_port = json["storage_lmq_port"].get(); + else + omq_port = json["port_omq"].get(); + + if (omq_port == 0) + throw std::runtime_error{"Invalid omq port"}; + + swarm_id_t swarm_id = INVALID_SWARM_ID; + if (json.contains("swarm_id")) + swarm_id = json["swarm_id"].get(); + + return {pubkey, quic::ipv4{ip}, https_port, omq_port, storage_server_version, swarm_id}; +} + +service_node service_node::legacy_from_disk(std::string_view str) { + // Format is "{ip}|{port}|{version}|{ed_pubkey}|{swarm_id}" + auto parts = split(str, "|"); + if (parts.size() != 5) + throw std::invalid_argument("Invalid service node serialisation: {}"_format(str)); + if (parts[3].size() != 64 || !oxenc::is_hex(parts[3])) + throw std::invalid_argument{ + "Invalid service node serialisation: pubkey is not hex or has wrong size"}; + + uint16_t port; + if (!quic::parse_int(parts[1], port)) + throw std::invalid_argument{"Invalid service node serialization: invalid port"}; + + auto version_parts = split(parts[2], "."); + std::array version_array{0, 0, 0}; + for (size_t i = 0; i < std::min(size_t{3}, version_parts.size()); ++i) { + uint16_t v; + + if (quic::parse_int(version_parts[i], v)) + version_array[i] = v; + } + + if (version_array == std::array{0, 0, 0}) + throw std::invalid_argument{"Invalid service node serialization: invalid version"}; + + swarm_id_t swarm_id = INVALID_SWARM_ID; + quic::parse_int(parts[4], swarm_id); + + std::vector pubkey; + pubkey.reserve(32); + oxenc::from_hex(parts[3].begin(), parts[3].end(), std::back_inserter(pubkey)); + + return { + pubkey, // ed25519_pubkey + quic::ipv4{std::string{parts[0]}}, // ip + 0, // https_port + port, // omq_port + version_array, // storage_server_version + swarm_id // swarm_id + }; +} + +std::string service_node::legacy_to_disk() const { + // Format is "{ip}|{port}|{version}|{ed_pubkey}|{swarm_id}" + auto ed25519_pubkey_hex = oxenc::to_hex(view_remote_key()); + + return fmt::format( + "{}|{}|{}.{}.{}|{}|{}", + host(), + omq_port, + storage_server_version[0], + storage_server_version[1], + storage_server_version[2], + ed25519_pubkey_hex, + swarm_id); +} + +service_node service_node::from_disk(std::string_view str) { + // Format is "{ed_pubkey}|{ip}|{https_port}|{omq_port}|{version}|{swarm_id}" + auto parts = split(str, "|"); + if (parts.size() != 6) + throw std::invalid_argument("Invalid service node serialisation: {}"_format(str)); + if (parts[0].size() != 64 || !oxenc::is_hex(parts[0])) + throw std::invalid_argument{ + "Invalid service node serialisation: pubkey is not hex or has wrong size"}; + + std::vector pubkey; + pubkey.reserve(32); + oxenc::from_hex(parts[0].begin(), parts[0].end(), std::back_inserter(pubkey)); + + uint16_t https_port, omq_port; + if (!quic::parse_int(parts[2], https_port)) + throw std::invalid_argument{"Invalid service node serialization: invalid https_port"}; + if (!quic::parse_int(parts[3], omq_port)) + throw std::invalid_argument{"Invalid service node serialization: invalid omq_port"}; + + auto version_parts = split(parts[4], "."); + std::array version_array{0, 0, 0}; + for (size_t i = 0; i < std::min(size_t{3}, version_parts.size()); ++i) { + uint16_t v; + + if (quic::parse_int(version_parts[i], v)) + version_array[i] = v; + } + + if (version_array == std::array{0, 0, 0}) + throw std::invalid_argument{"Invalid service node serialization: invalid version"}; + + swarm_id_t swarm_id = INVALID_SWARM_ID; + quic::parse_int(parts[5], swarm_id); + + return {pubkey, + quic::ipv4{std::string{parts[1]}}, + https_port, + omq_port, + version_array, + swarm_id}; +} + +std::pair, int> service_node::process_snode_cache_bin( + std::vector cache_bin) { + constexpr size_t SNODE_SIZE = 51; + constexpr size_t PK_SIZE = 32; + constexpr size_t SWARM_ID_SIZE = 8; + constexpr size_t IP_SIZE = 4; + constexpr size_t HTTPS_PORT_SIZE = 2; + constexpr size_t OMQ_PORT_SIZE = 2; + constexpr size_t VERSION_SIZE = 3; + + // Sanity check field sizes + static_assert( + PK_SIZE + SWARM_ID_SIZE + IP_SIZE + HTTPS_PORT_SIZE + OMQ_PORT_SIZE + VERSION_SIZE == + SNODE_SIZE, + "Field sizes do not sum to snode size"); + + if (cache_bin.size() % SNODE_SIZE != 0) + throw std::runtime_error{ + "Snode cache size is not a multiple of snode size ({})."_format(SNODE_SIZE)}; + + // Parse the binary + int failed_nodes = 0; + std::vector nodes; + nodes.reserve(cache_bin.size() / SNODE_SIZE); + + const std::byte* current_ptr = cache_bin.data(); + const std::byte* const end_ptr = cache_bin.data() + cache_bin.size(); + + while (current_ptr < end_ptr) { + const std::byte* note_ptr = current_ptr; + + try { + // Pubkey + std::vector pubkey; + pubkey.assign( + reinterpret_cast(current_ptr), + reinterpret_cast(current_ptr) + PK_SIZE); + note_ptr += PK_SIZE; + + // Swarm ID + uint64_t swarm_id_u64 = 0; + for (int i = 0; i < SWARM_ID_SIZE; ++i) + swarm_id_u64 = (swarm_id_u64 << 8) | + static_cast(static_cast(note_ptr[i])); + + swarm_id_t swarm_id = static_cast(swarm_id_u64); + note_ptr += SWARM_ID_SIZE; + + // Public IP + std::span ip_bytes_span( + reinterpret_cast(note_ptr), IP_SIZE); + quic::ipv4 ip(ip_bytes_span); + note_ptr += IP_SIZE; + + // IP can be 0 (ie. node is not in a valid state for use yet) + if (ip.addr == 0) + throw std::runtime_error{"Invalid IP"}; + + // HTTPS port + uint16_t https_port = + (static_cast(static_cast(note_ptr[0])) << 8) | + (static_cast(static_cast(note_ptr[1]))); + note_ptr += HTTPS_PORT_SIZE; + + // QUIC port + uint16_t quic_port = + (static_cast(static_cast(note_ptr[0])) << 8) | + (static_cast(static_cast(note_ptr[1]))); + note_ptr += OMQ_PORT_SIZE; + + // quic_port can be 0 (ie. node is not in a valid state for use yet) + if (quic_port == 0) + throw std::runtime_error{"Invalid QUIC port"}; + + // Storage server version + std::array version_array{0, 0, 0}; + for (size_t i = 0; i < VERSION_SIZE; ++i) + version_array[i] = static_cast(static_cast(note_ptr[i])); + note_ptr += VERSION_SIZE; + + nodes.emplace_back( + std::move(pubkey), + ip, + https_port, + quic_port, + std::move(version_array), + swarm_id); + } catch (...) { + failed_nodes++; + } + + // Move the ptr to the start of the next node + current_ptr += SNODE_SIZE; + } + + return {nodes, failed_nodes}; +} + +} // namespace session::network \ No newline at end of file diff --git a/src/network/session_network.cpp b/src/network/session_network.cpp new file mode 100644 index 00000000..5b5d05cf --- /dev/null +++ b/src/network/session_network.cpp @@ -0,0 +1,1244 @@ +#include "session/network/session_network.hpp" + +#include + +#include +#include +#include +#include + +#include "session/blinding.hpp" +#include "session/network/network_config.hpp" +#include "session/network/network_opt.hpp" +#include "session/network/routing/direct_router.hpp" +#include "session/network/routing/lokinet_router.hpp" +#include "session/network/routing/onion_request_router.hpp" +#include "session/network/session_network.h" +#include "session/network/session_network_types.hpp" +#include "session/network/transport/quic_transport.hpp" +#include "session/random.hpp" + +using namespace oxen; +using namespace session::network; +using namespace session::network::config; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace session::network { + +namespace { + + inline auto cat = log::Cat("network"); + + constexpr auto file_server = "filev2.getsession.org"sv; + constexpr auto file_server_pubkey = + "da21e1d886c6fbaea313f75298bd64aab03a97ce985b46bb2dad9f2089c8ee59"sv; + + config::SnodePoolConfig build_snode_pool_config(const config::Config& main_config) { + return {main_config.cache_directory, + main_config.cache_expiration, + main_config.cache_min_lifetime, + main_config.enforce_subnet_diversity, + main_config.retry_delay, + main_config.netid, + main_config.seed_nodes, + main_config.cache_min_size, + main_config.cache_num_nodes_to_use_for_refresh, + main_config.cache_node_failure_threshold, + main_config.cache_refresh_using_legacy_endpoint}; + } + + config::QuicTransportConfig build_quic_transport_config(const config::Config& main_config) { + return {main_config.quic_handshake_timeout, + main_config.quic_keep_alive, + main_config.quic_disable_mtu_discovery}; + } + + config::LokinetRouterConfig build_lokinet_router_config(const config::Config& main_config) { + if (!main_config.cache_directory) + throw std::invalid_argument{"Lokinet requires a cache_directory to be configured."}; + + if (main_config.netid == opt::netid::Target::devnet) + throw std::invalid_argument{"Lokinet does not support devnet."}; + + return {main_config.netid, + *main_config.cache_directory, + main_config.request_timeout_check_frequency, + main_config.path_length}; + } + + config::OnionRequestRouterConfig build_onion_request_router_config( + const config::Config& main_config) { + return {main_config.retry_delay, + main_config.request_timeout_check_frequency, + main_config.path_length, + main_config.onionreq_path_failure_threshold, + main_config.onionreq_path_build_retry_limit, + main_config.onionreq_disable_pre_build_paths, + main_config.onionreq_single_path_mode, + main_config.onionreq_min_path_counts}; + } + +} // namespace + +namespace detail { + + std::vector convert_service_nodes( + std::vector nodes) { + std::vector converted_nodes; + for (auto& node : nodes) { + network_service_node converted_node; + node.into(converted_node); + converted_nodes.push_back(converted_node); + } + + return converted_nodes; + } + +} // namespace detail + +Network::Network(config::Config config) : config{config} { + // Start by validating the configuration + switch (config.transport) { + case opt::transport::Type::quic: break; + case opt::transport::Type::callbacks: + break; + if (!config.callbacks_callback) + throw std::invalid_argument{"Callbacks requires a callback to be provided."}; + break; + } + + // Now we can properly do any setup needed + _loop = std::make_shared(); + + // Setup the transport layer + switch (config.transport) { + case opt::transport::Type::quic: + _transport = std::make_shared( + std::move(build_quic_transport_config(config)), _loop); + break; + + case opt::transport::Type::callbacks: + // _transport = std::make_shared(_config, *_snode_pool, _loop); + break; + } + + // The SnodePool is needed regardless of the transport layer as it includes swarm information + // which is needed by the clients in order to send requests + auto bootstrap_fetcher = [bt = std::weak_ptr{_transport}]( + Request req, network_response_callback_t on_complete) { + if (auto transport = bt.lock()) + transport->send_request(std::move(req), std::move(on_complete)); + else + log::error( + cat, + "Transport provided to the SnodePool bootstrap fetcher has been destroyed."); + }; + _snode_pool = std::make_shared( + std::move(build_snode_pool_config(config)), _loop, bootstrap_fetcher); + + // Additional transport configuration + _transport->set_node_failure_reporter( + [pool = _snode_pool.get()](const ed25519_pubkey& pubkey, bool permanent) { + if (pool) + pool->record_node_failure(pubkey, permanent); + }); + + // Setup the router + switch (config.router) { + case opt::router::Type::onion_requests: + _router = std::make_unique( + std::move(build_onion_request_router_config(config)), + _loop, + _snode_pool, + _transport); + break; + + case opt::router::Type::lokinet: + _router = std::make_unique( + std::move(build_lokinet_router_config(config)), _loop, _snode_pool, _transport); + break; + + case opt::router::Type::direct: + _router = std::make_unique(_loop, _transport); + break; + } + + // Now that we have our router setup we need to setup the `standard_fetcher` on the `SnodePool` + auto routed_fetcher = [r = std::weak_ptr{_router}, loop = _loop]( + Request req, network_response_callback_t on_complete) { + loop->call([r, req = std::move(req), on_complete = std::move(on_complete)] { + if (auto router = r.lock()) + router->send_request(std::move(req), std::move(on_complete)); + else + log::error( + cat, "Router provided to the SnodePool routed_fetcher has been destroyed."); + }); + }; + auto routed_fetcher_connected = [r = std::weak_ptr{_router}, loop = _loop]() -> bool { + return loop->call_get([r] { + if (auto router = r.lock()) + return router->get_status() == ConnectionStatus::connected; + + return false; + }); + }; + _snode_pool->set_routed_fetcher(std::move(routed_fetcher), std::move(routed_fetcher_connected)); + + // Add hooks to update the connection status + _router->on_status_changed = [this] { _recalculate_status(); }; + _transport->on_status_changed = [this] { _recalculate_status(); }; +} + +Network::~Network() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { _update_status(ConnectionStatus::disconnected); }); + log::debug(cat, "[Network] Destroyed."); +} + +void Network::clear_cache() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + if (_snode_pool) + _snode_pool->clear_cache(); + if (_router) + _router->clear_cache(); + }); +} + +// MARK: Connection + +void Network::suspend() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + _suspended = true; + + if (_snode_pool) + _snode_pool->suspend(); + if (_transport) + _transport->suspend(); + if (_router) + _router->suspend(); + + _close_connections(); + log::info(cat, "Suspended."); + }); +} + +void Network::resume(bool automatically_reconnect) { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this, automatically_reconnect] { + if (!_suspended) + return; + + if (_snode_pool) + _snode_pool->resume(); + if (_transport) + _transport->resume(automatically_reconnect); + if (_router) + _router->resume(automatically_reconnect); + + _suspended = false; + log::info(cat, "Resumed."); + }); +} + +void Network::close_connections() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { _close_connections(); }); +} + +// MARK: Interface + +ConnectionStatus Network::get_status() { + return _status.load(); +} + +std::vector Network::get_active_paths() { + if (_router) + return _router->get_active_paths(); + + return {}; +} + +void Network::get_swarm( + session::network::x25519_pubkey swarm_pubkey, + std::function swarm)> callback) { + _snode_pool->get_swarm(std::move(swarm_pubkey), std::move(callback)); +} + +void Network::get_random_nodes( + uint16_t count, std::function nodes)> callback) { + _loop->call([this, count, cb = std::move(callback)] { + auto unused_nodes = _snode_pool->get_unused_nodes(count); + + // If we don't have sufficient nodes then we need to refresh the snode cache + if (unused_nodes.size() < count) { + std::vector nodes_to_exclude = _router->get_all_used_nodes(); + + return _snode_pool->refresh_if_needed( + nodes_to_exclude, + [this, count, cb = std::move(cb)] { get_random_nodes(count, cb); }); + } + cb(unused_nodes); + }); +} + +void Network::send_request(Request request, network_response_callback_t callback) { + if (!_transport) + return callback( + false, false, -1, {content_type_plain_text}, "No transport layer configured"); + if (!_router) + return callback(false, false, -1, {content_type_plain_text}, "No router configured"); + + try { + auto processed_request = _preprocess_request(std::move(request)); + auto router_callback = + [this, original_req = processed_request, cb = std::move(callback)]( + bool success, bool timeout, int16_t status_code, auto headers, auto body) { + // If we got a successful response (with a body) and the request was sent to a + // service node then we should update the network state based on the response + // (Note: we don't want to do this for server requests because they could + // include values in different formats, eg. the "Session Network" API returns + // `t` in seconds) + if (success && body && + std::holds_alternative(original_req.destination)) + _update_network_state(*body); + + int16_t final_status_code = status_code; + + if (body) + if (auto uniform_error = Response::find_uniform_batch_error(*body)) + final_status_code = *uniform_error; + + // If we got a 421 then our swarm info is out of data so we need to refresh our + // cache, the original request might succeed after this refresh so we should + // just automatically retry + if (final_status_code == 421) { + _handle_421_retry(std::move(original_req), std::move(cb)); + return; + } + + // For debugging purposes we want to add a log if this was a successful request + // after we did an automatic retry + if (original_req.retry_count > 0) + log::info( + cat, + "[Request {}] Received valid response after 421 retry.", + original_req.request_id); + + auto final_success = + (success && final_status_code >= 200 && final_status_code <= 299); + cb(final_success, timeout, status_code, std::move(headers), std::move(body)); + }; + + _router->send_request(std::move(processed_request), std::move(router_callback)); + } catch (const std::exception& e) { + return callback(false, false, -1, {content_type_plain_text}, e.what()); + } +} + +// MARK: Internal Logic + +void Network::_close_connections() { + if (_transport) + _transport->close_connections(); + if (_router) + _router->close_connections(); + + _recalculate_status(); + log::info(cat, "Closed all connections."); +} + +void Network::_recalculate_status() { + _loop->call([this] { + if (!_transport || !_router) + return _update_status(ConnectionStatus::disconnected); + + auto transport_status = _transport->get_status(); + auto router_status = _router->get_status(); + + // If both layers report being fully connected then we are connected + if (transport_status == ConnectionStatus::connected && + router_status == ConnectionStatus::connected) + _update_status(ConnectionStatus::connected); + // If either layer is disconnected, the whole system is disconnected + else if ( + transport_status == ConnectionStatus::disconnected || + router_status == ConnectionStatus::disconnected) + _update_status(ConnectionStatus::disconnected); + // If either layer is trying to connect, the whole system is connecting + else if ( + transport_status == ConnectionStatus::connecting || + router_status == ConnectionStatus::connecting) + _update_status(ConnectionStatus::connecting); + // Otherwise, we are in an unknown state + else + _update_status(ConnectionStatus::unknown); + }); +} + +void Network::_update_status(ConnectionStatus new_status) { + if (_status == new_status) + return; + + _status = new_status; + + if (on_status_changed) + on_status_changed(new_status); +} + +Request Network::_preprocess_request(Request request) { + std::visit( + [&](auto&& details) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + if (!request.body) + throw std::invalid_argument("Upload request must have a body."); + + if (request.category != RequestCategory::upload) { + log::warning( + cat, + "Request {} has UploadInfo but category is not 'upload', forcing " + "to 'upload'.", + request.request_id); + request.category = RequestCategory::upload; + } + + // Add the required headers if they weren't provided + if (auto* dest = std::get_if(&request.destination)) { + if (!dest->headers) + dest->headers.emplace(); + + std::unordered_set existing_keys; + if (dest->headers) + for (const auto& [key, val] : *dest->headers) + existing_keys.insert(key); + + if (existing_keys.find("Content-Type") == existing_keys.end()) + dest->headers->emplace_back("Content-Type", "application/octet-stream"); + + if (existing_keys.find("Content-Disposition") == existing_keys.end()) { + if (details.file_name) + dest->headers->emplace_back( + "Content-Disposition", + fmt::format( + "attachment; filename=\"{}\"", *details.file_name)); + else + dest->headers->emplace_back("Content-Disposition", "attachment"); + } + } + } else if constexpr (std::is_same_v) { /* No special handling */ + } + }, + request.details); + + return request; +} + +void Network::_update_network_state(const std::string& body) { + try { + auto json = nlohmann::json::parse(body); + const nlohmann::json* target_json = &json; + + // If it was a batch/sequence request then take the one with the highest "t" value as that + // would have been the one which was returned last + if (json.contains("results") && json["results"].is_array()) { + log::trace(cat, "Parsing batch response for latest network state."); + + int64_t max_t = -1; + const nlohmann::json* latest_body = nullptr; + + for (const auto& result : json["results"]) { + if (!result.is_object() || !result.contains("body") || !result["body"].is_object()) + continue; + + const auto& result_body = result["body"]; + + if (result_body.contains("t") && result_body["t"].is_number()) { + int64_t current_t = result_body["t"].get(); + + if (current_t > max_t) { + max_t = current_t; + latest_body = &result_body; + } + } + } + + if (latest_body) + target_json = latest_body; + } + + auto old_offset = _network_time_offset.load(); + auto old_versions = _fork_versions.load(); + + // Update time offset + if (target_json->contains("t") && (*target_json)["t"].is_number()) { + auto server_timestamp_ms = (*target_json)["t"].get(); + auto server_time = std::chrono::time_point( + std::chrono::milliseconds{server_timestamp_ms}); + auto now = std::chrono::system_clock::now(); + _network_time_offset = + std::chrono::duration_cast(server_time - now); + + log::trace(cat, "Network offset set to: {}", (server_time - now).count()); + } + + // Update hardfork/softfork versions + if (target_json->contains("hf") && (*target_json)["hf"].is_array() && + (*target_json)["hf"].size() >= 2) { + std::pair new_versions = { + (*target_json)["hf"][0].get(), + (*target_json)["hf"][1].get()}; + + auto current_versions = old_versions; + auto desired_next_versions = current_versions; + + if (new_versions.first > desired_next_versions.hardfork) + desired_next_versions = {new_versions.first, new_versions.second}; + else if ( + new_versions.first == desired_next_versions.hardfork && + new_versions.second > desired_next_versions.softfork) + desired_next_versions.softfork = new_versions.second; + + if (current_versions != desired_next_versions) + _fork_versions.compare_exchange_weak(current_versions, desired_next_versions); + log::trace( + cat, + "Fork version set to: {}.{}", + desired_next_versions.hardfork, + desired_next_versions.softfork); + } + + // If the network info changed then call the callback + if (on_network_info_changed) { + auto new_offset = _network_time_offset.load(); + auto new_versions = _fork_versions.load(); + + if (new_offset != old_offset || new_versions != old_versions) + on_network_info_changed(new_offset, new_versions.hardfork, new_versions.softfork); + } + } catch (const std::exception& e) { + log::warning(cat, "Failed to parse network state from response: {}", e.what()); + } +} + +void Network::_handle_421_retry( + Request original_request, network_response_callback_t final_callback) { + if (original_request.retry_count >= config.redirect_retry_count) { + log::error( + cat, + "Request {} received 421 but exceeded max retry count.", + original_request.request_id); + return final_callback( + false, false, 421, {content_type_plain_text}, "Exceeded retry limit for 421 error"); + } + + // Shouldn't automatically retry if the destination isn't a node (we on'y want to auto-retry due + // to a node being in the wrong swarm) + auto* original_dest_node = std::get_if(&original_request.destination); + if (!original_dest_node) + return final_callback( + false, + false, + 421, + {content_type_plain_text}, + "Received 421 from a non-service-node destination"); + + // If we got a 421 it means our snode cache is outdated (because the swarm the destination node + // belongs to doesn't match our cache anymore) + log::info( + cat, + "Request {} received 421 from node {}, refreshing swarm if stale.", + original_request.request_id, + original_dest_node->to_string()); + + auto failed_node_copy = *original_dest_node; + std::vector nodes_to_exclude = _router->get_all_used_nodes(); + _snode_pool->refresh_if_needed( + std::move(nodes_to_exclude), + [this, + req_to_retry = std::move(original_request), + cb = std::move(final_callback), + failed_node = failed_node_copy] { + auto swarm_pubkey = failed_node.swarm_pubkey(); + + _snode_pool->get_swarm( + swarm_pubkey, + [this, + req_to_retry = std::move(req_to_retry), + cb = std::move(cb), + failed_node](swarm::swarm_id_t, std::vector swarm_nodes) { + std::optional new_target; + std::shuffle(swarm_nodes.begin(), swarm_nodes.end(), csrng); + + for (const auto& node : swarm_nodes) { + if (node != failed_node) { + new_target = node; + break; + } + } + + if (!new_target) + return cb( + false, + false, + 421, + {content_type_plain_text}, + "421 Misdirected Request, but no other nodes in swarm to " + "retry"); + + log::info( + cat, + "Request {} retrying 421 error on new node {}.", + req_to_retry.request_id, + new_target->to_string()); + auto final_request = req_to_retry; + final_request.retry_count++; + final_request.destination = *new_target; + this->send_request(std::move(final_request), std::move(cb)); + }); + }); +} + +} // namespace session::network + +// MARK: C API + +struct session_response_handle_cpp_t { + session::network::network_response_callback_t cpp_callback; +}; + +namespace { + +inline session::network::Network& unbox(network_object* network_) { + assert(network_ && network_->internals); + return *static_cast(network_->internals); +} + +inline bool set_error(char* error, const std::exception& e) { + if (!error) + return false; + + std::string msg = e.what(); + if (msg.size() > 255) + msg.resize(255); + std::memcpy(error, msg.c_str(), msg.size() + 1); + return false; +} + +} // namespace + +extern "C" { + +using namespace session; +using namespace session::network; + +LIBSESSION_C_API session_network_config session_network_config_default() { + Config cpp_defaults{}; + session_network_config config = {}; + + switch (cpp_defaults.netid) { + case opt::netid::Target::mainnet: config.netid = SESSION_NETWORK_MAINNET; + case opt::netid::Target::testnet: config.netid = SESSION_NETWORK_TESTNET; + case opt::netid::Target::devnet: config.netid = SESSION_NETWORK_DEVNET; + default: config.netid = SESSION_NETWORK_MAINNET; + } + + switch (cpp_defaults.router) { + case opt::router::Type::onion_requests: + config.router = SESSION_NETWORK_ROUTER_ONION_REQUESTS; + case opt::router::Type::lokinet: config.router = SESSION_NETWORK_ROUTER_LOKINET; + case opt::router::Type::direct: config.router = SESSION_NETWORK_ROUTER_DIRECT; + default: config.router = SESSION_NETWORK_ROUTER_ONION_REQUESTS; + } + + switch (cpp_defaults.transport) { + case opt::transport::Type::quic: config.transport = SESSION_NETWORK_TRANSPORT_QUIC; + case opt::transport::Type::callbacks: + config.transport = SESSION_NETWORK_TRANSPORT_CALLBACKS; + default: config.transport = SESSION_NETWORK_TRANSPORT_QUIC; + } + + config.path_length = cpp_defaults.path_length; + config.enforce_subnet_diversity = cpp_defaults.enforce_subnet_diversity; + config.redirect_retry_count = cpp_defaults.redirect_retry_count; + config.min_retry_delay_ms = cpp_defaults.retry_delay.base_delay.count(); + config.max_retry_delay_ms = cpp_defaults.retry_delay.max_delay.count(); + config.request_timeout_check_frequency_ms = + cpp_defaults.request_timeout_check_frequency.count(); + + config.devnet_seed_nodes = nullptr; + config.devnet_seed_nodes_size = 0; + + config.cache_dir = nullptr; + config.cache_expiration_minutes = + std::chrono::duration_cast(cpp_defaults.cache_expiration).count(); + config.cache_min_lifetime_ms = + std::chrono::duration_cast(cpp_defaults.cache_min_lifetime) + .count(); + ; + config.cache_min_size = cpp_defaults.cache_min_size; + config.cache_num_nodes_to_use_for_refresh = cpp_defaults.cache_num_nodes_to_use_for_refresh; + config.cache_node_failure_threshold = cpp_defaults.cache_node_failure_threshold; + config.cache_refresh_using_legacy_endpoint = cpp_defaults.cache_refresh_using_legacy_endpoint; + + config.onionreq_path_failure_threshold = cpp_defaults.onionreq_path_failure_threshold; + config.onionreq_path_build_retry_limit = cpp_defaults.onionreq_path_build_retry_limit; + config.onionreq_min_path_count_standard = + cpp_defaults.onionreq_min_path_counts[RequestCategory::standard]; + config.onionreq_min_path_count_upload = + cpp_defaults.onionreq_min_path_counts[RequestCategory::upload]; + config.onionreq_min_path_count_download = + cpp_defaults.onionreq_min_path_counts[RequestCategory::download]; + config.onionreq_single_path_mode = cpp_defaults.onionreq_single_path_mode; + config.onionreq_disable_pre_build_paths = cpp_defaults.onionreq_disable_pre_build_paths; + + config.quic_handshake_timeout_seconds = + std::chrono::duration_cast(cpp_defaults.quic_handshake_timeout) + .count(); + config.quic_keep_alive_seconds = + std::chrono::duration_cast(cpp_defaults.quic_keep_alive).count(); + config.quic_disable_mtu_discovery = cpp_defaults.quic_disable_mtu_discovery; + + config.transport_callback = nullptr; + config.transport_callback_ctx = nullptr; + + return config; +} + +LIBSESSION_C_API bool session_network_init( + network_object** network, const session_network_config* config, char* error) { + if (!network || !config) + return set_error(error, std::invalid_argument{"network or config were null."}); + + try { + // Build the configuration options (ordered this way for the debug logs to make the most + // sense) + std::vector cpp_opts; + + // Network ID + switch (config->netid) { + case SESSION_NETWORK_MAINNET: cpp_opts.emplace_back(opt::netid::mainnet()); break; + case SESSION_NETWORK_TESTNET: cpp_opts.emplace_back(opt::netid::testnet()); break; + case SESSION_NETWORK_DEVNET: + if (!config->devnet_seed_nodes || config->devnet_seed_nodes_size == 0) + throw std::runtime_error( + "SESSION_NETWORK_DEVNET requires at least one seed node."); + + std::vector seed_nodes; + seed_nodes.reserve(config->devnet_seed_nodes_size); + + for (size_t i = 0; i < config->devnet_seed_nodes_size; ++i) + seed_nodes.push_back(service_node::from(config->devnet_seed_nodes[i])); + + cpp_opts.emplace_back(opt::netid::devnet(std::move(seed_nodes))); + break; + } + + // Router + switch (config->router) { + case SESSION_NETWORK_ROUTER_ONION_REQUESTS: + cpp_opts.emplace_back(opt::router::onion_requests()); + break; + case SESSION_NETWORK_ROUTER_LOKINET: + cpp_opts.emplace_back(opt::router::lokinet()); + break; + case SESSION_NETWORK_ROUTER_DIRECT: cpp_opts.emplace_back(opt::router::direct()); break; + } + + // Transport + switch (config->transport) { + case SESSION_NETWORK_TRANSPORT_QUIC: + cpp_opts.emplace_back(opt::transport::quic()); + break; + + case SESSION_NETWORK_TRANSPORT_CALLBACKS: + if (!config->transport_callback) + throw std::runtime_error( + "transport_callback must be set when using the CALLBACKS for sending " + "requests."); + + auto c_callback_ptr = config->transport_callback; + auto ctx = config->transport_callback_ctx; + + opt::transport::network_callback_t cpp_callback = + [c_callback_ptr, ctx]( + std::string url, + std::string body, + session::network::network_response_callback_t handle_response) { + auto* c_response_handle = + new session_response_handle_t{std::move(handle_response)}; + + c_callback_ptr( + url.c_str(), body.data(), body.size(), c_response_handle, ctx); + }; + + cpp_opts.emplace_back(opt::transport::callbacks(std::move(cpp_callback))); + break; + } + + if (!config->enforce_subnet_diversity) + cpp_opts.emplace_back(opt::disable_subnet_diversity{}); + + if (config->min_retry_delay_ms > 0 || config->max_retry_delay_ms > 0) + cpp_opts.emplace_back(opt::retry_delay{ + std::chrono::milliseconds{config->min_retry_delay_ms}, + std::chrono::milliseconds{config->max_retry_delay_ms}}); + + // A `0` value is valid for this option + cpp_opts.emplace_back(opt::redirect_retry_count{config->redirect_retry_count}); + + if (config->request_timeout_check_frequency_ms > 0) + cpp_opts.emplace_back(opt::request_timeout_check_frequency{ + std::chrono::milliseconds{config->request_timeout_check_frequency_ms}}); + + // Snode cache + if (config->cache_dir) + cpp_opts.emplace_back(opt::cache_directory{std::filesystem::path{config->cache_dir}}); + + if (config->cache_expiration_minutes > 0) + cpp_opts.emplace_back( + opt::cache_expiration{std::chrono::minutes{config->cache_expiration_minutes}}); + + if (config->cache_min_lifetime_ms > 0) + cpp_opts.emplace_back(opt::cache_min_lifetime{ + std::chrono::milliseconds{config->cache_min_lifetime_ms}}); + + if (config->cache_min_size > 0) + cpp_opts.emplace_back(opt::cache_min_size{config->cache_min_size}); + + // A `0` value is valid for this option + cpp_opts.emplace_back(opt::cache_num_nodes_to_use_for_refresh{ + config->cache_num_nodes_to_use_for_refresh}); + + if (config->cache_node_failure_threshold > 0) + cpp_opts.emplace_back( + opt::cache_node_failure_threshold{config->cache_node_failure_threshold}); + + if (config->cache_refresh_using_legacy_endpoint) + cpp_opts.emplace_back(opt::cache_refresh_using_legacy_endpoint{}); + + // Router-specific settings + switch (config->router) { + case SESSION_NETWORK_ROUTER_ONION_REQUESTS: + // Process the Onion Request options since we are using them + if (config->path_length > 0) + cpp_opts.emplace_back(opt::path_length{config->path_length}); + + if (config->onionreq_path_failure_threshold > 0) + cpp_opts.emplace_back(opt::onionreq_path_failure_threshold{ + config->onionreq_path_failure_threshold}); + + if (config->onionreq_path_build_retry_limit > 0) + cpp_opts.emplace_back(opt::onionreq_path_build_retry_limit{ + config->onionreq_path_build_retry_limit}); + + if (config->onionreq_min_path_count_standard > 0) + cpp_opts.emplace_back(opt::onionreq_min_path_count{ + RequestCategory::standard, config->onionreq_min_path_count_standard}); + + if (config->onionreq_min_path_count_upload > 0) + cpp_opts.emplace_back(opt::onionreq_min_path_count{ + RequestCategory::upload, config->onionreq_min_path_count_upload}); + + if (config->onionreq_min_path_count_download > 0) + cpp_opts.emplace_back(opt::onionreq_min_path_count{ + RequestCategory::download, config->onionreq_min_path_count_download}); + + if (config->onionreq_single_path_mode) + cpp_opts.emplace_back(opt::onionreq_single_path_mode{}); + + if (config->onionreq_disable_pre_build_paths) + cpp_opts.emplace_back(opt::onionreq_disable_pre_build_paths{}); + break; + + case SESSION_NETWORK_ROUTER_LOKINET: + // Process the Lokinet options since we are using them + if (config->path_length > 0) + cpp_opts.emplace_back(opt::path_length{config->path_length}); + break; + + case SESSION_NETWORK_ROUTER_DIRECT: break; + } + + // Transport-specific settings + switch (config->transport) { + case SESSION_NETWORK_TRANSPORT_QUIC: + if (config->quic_handshake_timeout_seconds > 0) + cpp_opts.emplace_back(opt::quic_handshake_timeout{ + std::chrono::seconds{config->quic_handshake_timeout_seconds}}); + + if (config->quic_keep_alive_seconds > 0) + cpp_opts.emplace_back(opt::quic_keep_alive{ + std::chrono::seconds{config->quic_keep_alive_seconds}}); + + if (config->quic_disable_mtu_discovery) + cpp_opts.emplace_back(opt::quic_disable_mtu_discovery{}); + + break; + + case SESSION_NETWORK_TRANSPORT_CALLBACKS: break; + } + + // Construct the Network instance + Config final_config(cpp_opts); + auto n = std::make_unique(std::move(final_config)); + auto n_object = std::make_unique(); + n_object->internals = n.release(); + *network = n_object.release(); + return true; + } catch (const std::exception& e) { + return set_error(error, e); + } +} + +LIBSESSION_C_API void session_network_free(network_object* network) { + delete static_cast(network->internals); + delete network; +} + +LIBSESSION_C_API void session_request_params_free(session_request_params* params) { + if (params) + std::free(params); +} + +LIBSESSION_C_API void session_network_suspend(network_object* network) { + unbox(network).suspend(); +} + +LIBSESSION_C_API void session_network_resume( + network_object* network, bool automatically_reconnect) { + unbox(network).resume(automatically_reconnect); +} + +LIBSESSION_C_API void session_network_close_connections(network_object* network) { + unbox(network).close_connections(); +} + +LIBSESSION_C_API void session_network_clear_cache(network_object* network) { + unbox(network).clear_cache(); +} + +LIBSESSION_C_API int64_t session_network_time_offset(network_object* network) { + return unbox(network).network_time_offset().count(); +} + +LIBSESSION_C_API uint16_t session_network_hardfork(network_object* network) { + return unbox(network).hardfork(); +} + +LIBSESSION_C_API uint16_t session_network_softfork(network_object* network) { + return unbox(network).softfork(); +} + +LIBSESSION_C_API void session_network_set_status_changed_callback( + network_object* network, void (*callback)(CONNECTION_STATUS status, void* ctx), void* ctx) { + if (!callback) + unbox(network).on_status_changed = nullptr; + else + unbox(network).on_status_changed = [cb = std::move(callback), + ctx](ConnectionStatus status) { + cb(static_cast(status), ctx); + }; +} + +LIBSESSION_C_API void session_network_set_network_info_changed_callback( + network_object* network, + void (*callback)( + int64_t network_time_offset, uint16_t hardfork, uint16_t softfork, void* ctx), + void* ctx) { + if (!callback) + unbox(network).on_network_info_changed = nullptr; + else + unbox(network).on_network_info_changed = + [cb = std::move(callback), ctx]( + std::chrono::milliseconds network_time_offset, + uint16_t hardfork, + uint16_t softfork) { + cb(network_time_offset.count(), hardfork, softfork, ctx); + }; +} + +LIBSESSION_C_API void session_network_callbacks_respond( + network_object* network, + session_response_handle_t* response_handle, + bool success, + bool timeout, + int16_t status_code, + const char* const* headers_, + const char* const* header_values, + size_t headers_size, + const char* body_, + size_t body_len) { + if (!response_handle) + return; + + std::unique_ptr handle_guard(response_handle); + std::vector> headers; + headers.reserve(headers_size); + + if (headers_size > 0) + for (size_t i = 0; i < headers_size; i++) + headers.emplace_back(headers_[i], header_values[i]); + + std::optional body; + if (body_len > 0) + body.emplace(body_, body_len); + + handle_guard->cpp_callback(success, timeout, status_code, std::move(headers), std::move(body)); +} + +LIBSESSION_C_API CONNECTION_STATUS session_network_get_status(network_object* network) { + if (!network) + return CONNECTION_STATUS_UNKNOWN; + + return static_cast(unbox(network).get_status()); +} + +LIBSESSION_C_API void session_network_get_active_paths( + network_object* network, session_path_info** out_paths, size_t* out_paths_len) { + if (!network || !out_paths || !out_paths_len) + return; + + *out_paths = nullptr; + *out_paths_len = 0; + + try { + std::vector cpp_paths = unbox(network).get_active_paths(); + if (cpp_paths.empty()) + return; + + // Calculate the size of the data + size_t total_size = cpp_paths.size() * sizeof(session_path_info); + size_t total_nodes = 0; + for (const auto& path : cpp_paths) + total_nodes += path.nodes.size(); + total_size += total_nodes * sizeof(network_service_node); + + size_t total_metadata_size = 0; + for (const auto& p : cpp_paths) { + std::visit( + [&](auto&& md) { + using T = std::decay_t; + if constexpr (std::is_same_v) + total_metadata_size += sizeof(session_onion_path_metadata); + else if constexpr (std::is_same_v) + total_metadata_size += sizeof(session_lokinet_tunnel_metadata); + }, + p.metadata); + } + total_size += total_metadata_size; + + // Allocate and assign the memory + unsigned char* buffer = static_cast(std::malloc(total_size)); + if (!buffer) + return; + + auto* c_paths_array = reinterpret_cast(buffer); + auto* current_node_ptr = + reinterpret_cast(c_paths_array + cpp_paths.size()); + unsigned char* current_metadata_ptr = + reinterpret_cast(current_node_ptr + total_nodes); + + for (size_t i = 0; i < cpp_paths.size(); ++i) { + const auto& cpp_path = cpp_paths[i]; + auto& c_path = c_paths_array[i]; + + new (&c_path) session_path_info{}; + + c_path.nodes = current_node_ptr; + c_path.nodes_count = cpp_path.nodes.size(); + for (const auto& cpp_node : cpp_path.nodes) { + new (current_node_ptr) network_service_node{}; + cpp_node.into(*current_node_ptr); + current_node_ptr++; + } + + // Copy metadata + std::visit( + [&](auto&& m) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + auto* meta = reinterpret_cast( + current_metadata_ptr); + new (meta) session_onion_path_metadata{}; + meta->category = + static_cast(m.category); + c_path.onion_metadata = meta; + current_metadata_ptr += sizeof(session_onion_path_metadata); + } else if constexpr (std::is_same_v) { + auto* meta = reinterpret_cast( + current_metadata_ptr); + new (meta) session_lokinet_tunnel_metadata{}; + strncpy(meta->destination_pubkey, + m.destination_pubkey.c_str(), + sizeof(meta->destination_pubkey) - 1); + meta->destination_pubkey[sizeof(meta->destination_pubkey) - 1] = '\0'; + strncpy(meta->destination_snode_address, + m.destination_snode_address.c_str(), + sizeof(meta->destination_snode_address) - 1); + meta->destination_snode_address + [sizeof(meta->destination_snode_address) - 1] = '\0'; + c_path.lokinet_metadata = meta; + current_metadata_ptr += sizeof(session_lokinet_tunnel_metadata); + } + }, + cpp_path.metadata); + } + + *out_paths = c_paths_array; + *out_paths_len = cpp_paths.size(); + } catch (...) { + *out_paths = nullptr; + *out_paths_len = 0; + } +} + +LIBSESSION_C_API void session_network_paths_free(session_path_info* paths) { + if (paths) + std::free(paths); +} + +LIBSESSION_C_API void session_network_get_swarm( + network_object* network, + const char* swarm_pubkey_hex, + void (*callback)(network_service_node* nodes, size_t nodes_len, void*), + void* ctx) { + assert(swarm_pubkey_hex && callback); + unbox(network).get_swarm( + x25519_pubkey::from_hex({swarm_pubkey_hex, 64}), + [cb = std::move(callback), ctx](swarm_id_t, std::vector nodes) { + auto c_nodes = network::detail::convert_service_nodes(nodes); + cb(c_nodes.data(), c_nodes.size(), ctx); + }); +} + +LIBSESSION_C_API void session_network_get_random_nodes( + network_object* network, + uint16_t count, + void (*callback)(network_service_node*, size_t, void*), + void* ctx) { + assert(callback); + unbox(network).get_random_nodes( + count, [cb = std::move(callback), ctx](std::vector nodes) { + auto c_nodes = network::detail::convert_service_nodes(nodes); + cb(c_nodes.data(), c_nodes.size(), ctx); + }); +} + +LIBSESSION_C_API void session_network_send_request( + network_object* network, + const session_request_params* params, + session_network_response_t callback, + void* ctx) { + assert(callback); + + try { + if (!network) + throw std::invalid_argument("Invalid request: 'network' cannot be null."); + if (!params) + throw std::invalid_argument("Invalid request: 'params' cannot be null."); + + network_destination dest; + + if (params->snode_dest && params->server_dest) + throw std::invalid_argument( + "Invalid request: Cannot have both 'snode_dest' and 'server_dest' set."); + + if (params->snode_dest) { + dest = service_node::from(*params->snode_dest); + } else if (params->server_dest) { + const auto& c_server = *params->server_dest; + + std::optional>> headers; + if (c_server.headers_kv_pairs && c_server.headers_kv_pairs_len > 0) { + if (c_server.headers_kv_pairs_len % 2 != 0) + throw std::invalid_argument( + "Invalid request: Header must have an even number of key-value " + "strings."); + + headers.emplace(); + headers->reserve(c_server.headers_kv_pairs_len / 2); + for (int i = 0; i < c_server.headers_kv_pairs_len; i += 2) { + const char* key = c_server.headers_kv_pairs[i]; + const char* val = c_server.headers_kv_pairs[i + 1]; + + if (!key || !val) + throw std::invalid_argument( + "Invalid request: Header list contains a null key or value."); + + headers->emplace_back(key, val); + } + } + + dest = ServerDestination{ + c_server.protocol, + c_server.host, + x25519_pubkey::from_hex(c_server.x25519_pubkey_hex), + (c_server.port > 0 ? std::optional{c_server.port} : std::nullopt), + headers, + c_server.method}; + } else + throw std::invalid_argument( + "Invalid request: Must have either 'snode_dest' or 'server_dest' set."); + + std::optional> body; + if (params->body && params->body_size > 0) + body.emplace(params->body, params->body + params->body_size); + + std::optional request_id; + if (params->request_id) + request_id = params->request_id; + + auto request = Request{ + dest, + std::string{params->endpoint}, + body, + static_cast(params->category), + std::chrono::milliseconds{params->request_timeout_ms}, + (params->overall_timeout_ms > 0 + ? std::optional{std::chrono::milliseconds{params->overall_timeout_ms}} + : std::nullopt), + request_id}; + auto cpp_callback = [c_cb = callback, c_ctx = ctx]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional body) { + std::vector c_headers; + c_headers.reserve(headers.size() * 2 + 1); + for (const auto& [key, val] : headers) { + c_headers.push_back(key.c_str()); + c_headers.push_back(val.c_str()); + } + c_headers.push_back(nullptr); // NULL terminator + + c_cb(success, + timeout, + status_code, + c_headers.data(), + (headers.size() * 2), + body ? reinterpret_cast(body->data()) : nullptr, + body ? body->size() : 0, + c_ctx); + }; + + unbox(network).send_request(std::move(request), std::move(cpp_callback)); + } catch (const std::exception& e) { + callback( + false, + false, + -1, + nullptr, + 0, + reinterpret_cast(e.what()), + strlen(e.what()), + ctx); + } +} + +} // extern "C" diff --git a/src/network/session_network_internal.cpp b/src/network/session_network_internal.cpp new file mode 100644 index 00000000..4df7d6a1 --- /dev/null +++ b/src/network/session_network_internal.cpp @@ -0,0 +1,158 @@ +#include "session_network_internal.hpp" + +#include +#include +#include +#include + +#include "session/network/service_node.hpp" + +namespace session::network::detail { + +session_request_params* convert_cpp_request_to_c(const session::network::Request& req) { + size_t total_size = sizeof(session_request_params); + size_t string_data_size = 0; + + // Calculate the expected size + auto add_string_size = [&](const std::string& s) { + if (!s.empty()) + string_data_size += s.length() + 1; + }; + + add_string_size(req.request_id); + add_string_size(req.endpoint); + + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + total_size += sizeof(network_service_node); + } else if constexpr (std::is_same_v) { + total_size += sizeof(network_server_destination); + add_string_size(arg.protocol); + add_string_size(arg.host); + add_string_size(arg.method); + add_string_size(arg.x25519_pubkey.hex()); + + if (arg.headers) { + // key pointers + value pointers + NULL terminator + string_data_size += (arg.headers->size() * 2 + 1) * sizeof(const char*); + add_string_size(arg.x25519_pubkey.hex()); + + for (const auto& [k, v] : *arg.headers) { + add_string_size(k); + add_string_size(v); + } + } + } else if constexpr (std::is_same_v) { + total_size += sizeof(session_remote_address); + } + }, + req.destination); + + size_t body_size = req.body ? req.body->size() : 0; + total_size += body_size; + + // Allocate the data and assign values + unsigned char* buffer = static_cast(std::malloc(total_size + string_data_size)); + if (!buffer) + return nullptr; + + auto* c_params = reinterpret_cast(buffer); + unsigned char* current_ptr = buffer + sizeof(session_request_params); + + auto copy_string = [&](const std::string& s) -> const char* { + if (s.empty()) + return nullptr; + char* dest = reinterpret_cast(current_ptr); + std::memcpy(dest, s.c_str(), s.length() + 1); + current_ptr += s.length() + 1; + return dest; + }; + + new (c_params) session_request_params{}; + c_params->request_id = copy_string(req.request_id); + c_params->endpoint = copy_string(req.endpoint); + + c_params->category = static_cast(req.category); + c_params->request_timeout_ms = req.request_timeout.count(); + c_params->overall_timeout_ms = (req.overall_timeout ? req.overall_timeout->count() : 0); + + if (body_size > 0) { + std::memcpy(current_ptr, req.body->data(), body_size); + c_params->body = current_ptr; + c_params->body_size = body_size; + current_ptr += body_size; + } + + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + auto* c_snode = reinterpret_cast(current_ptr); + arg.into(*c_snode); + c_params->snode_dest = c_snode; + current_ptr += sizeof(network_service_node); + } else if constexpr (std::is_same_v) { + auto* c_server_dest = + reinterpret_cast(current_ptr); + new (c_server_dest) network_server_destination{}; + c_params->server_dest = c_server_dest; + current_ptr += sizeof(network_server_destination); + + c_server_dest->protocol = copy_string(arg.protocol); + c_server_dest->host = copy_string(arg.host); + c_server_dest->method = copy_string(arg.method); + c_server_dest->x25519_pubkey_hex = copy_string(arg.x25519_pubkey.hex()); + c_server_dest->port = arg.port.value_or(0); + + if (arg.headers) { + auto** c_headers_array = reinterpret_cast(current_ptr); + c_server_dest->headers_kv_pairs = c_headers_array; + c_server_dest->headers_kv_pairs_len = arg.headers->size() * 2; + current_ptr += (arg.headers->size() * 2 + 1) * sizeof(const char*); + + int i = 0; + for (const auto& [k, v] : *arg.headers) { + c_headers_array[i++] = copy_string(k); + c_headers_array[i++] = copy_string(v); + } + c_headers_array[i] = nullptr; // Null terminator for safety + } + } else if constexpr (std::is_same_v) { + auto* c_remote = reinterpret_cast(current_ptr); + new (c_remote) session_remote_address{}; + c_params->remote_addr_dest = c_remote; + current_ptr += sizeof(session_remote_address); + + auto ed25519_pubkey_hex = oxenc::to_hex(arg.view_remote_key()); + oxen::quic::ipv4 ip = arg.to_ipv4(); + + strncpy(c_remote->ed25519_pubkey_hex, ed25519_pubkey_hex.c_str(), 64); + c_remote->ed25519_pubkey_hex[64] = '\0'; // Ensure null termination + c_remote->ip[0] = (ip.addr >> 24) & 0xFF; + c_remote->ip[1] = (ip.addr >> 16) & 0xFF; + c_remote->ip[2] = (ip.addr >> 8) & 0xFF; + c_remote->ip[3] = ip.addr & 0xFF; + c_remote->port = arg.port(); + } + }, + req.destination); + + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + if (arg.file_name) + c_params->upload_file_name = copy_string(*arg.file_name); + } + }, + req.details); + + return c_params; +} + +} // namespace session::network::detail \ No newline at end of file diff --git a/src/network/session_network_internal.hpp b/src/network/session_network_internal.hpp new file mode 100644 index 00000000..59d95c75 --- /dev/null +++ b/src/network/session_network_internal.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "session/network/session_network_types.h" +#include "session/network/session_network_types.hpp" + +namespace session::network::detail { +session_request_params* convert_cpp_request_to_c(const session::network::Request& req); +} \ No newline at end of file diff --git a/src/network/session_network_types.cpp b/src/network/session_network_types.cpp new file mode 100644 index 00000000..7b90f6a0 --- /dev/null +++ b/src/network/session_network_types.cpp @@ -0,0 +1,106 @@ +#include "session/network/session_network_types.hpp" + +#include +#include + +#include "session/random.hpp" + +using namespace oxen; +using namespace oxen::log::literals; + +namespace session::network { + +Request::Request( + std::string request_id, + network_destination destination, + std::string endpoint, + std::optional> body, + RequestCategory category, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout, + RequestDetails details, + bool ephemeral_connection) : + request_id{std::move(request_id)}, + destination{std::move(destination)}, + endpoint{std::move(endpoint)}, + body{std::move(body)}, + category{std::move(category)}, + request_timeout{std::move(request_timeout)}, + overall_timeout{std::move(overall_timeout)}, + details{details}, + ephemeral_connection{ephemeral_connection} {} + +Request::Request( + network_destination destination, + std::string endpoint, + std::optional> body, + RequestCategory category, + std::chrono::milliseconds request_timeout, + std::optional overall_timeout, + std::optional request_id, + RequestDetails details, + bool ephemeral_connection) : + request_id{std::move(request_id.value_or("R-{}"_format(random::random_base32(4))))}, + destination{std::move(destination)}, + endpoint{std::move(endpoint)}, + body{std::move(body)}, + category{std::move(category)}, + request_timeout{std::move(request_timeout)}, + overall_timeout{std::move(overall_timeout)}, + details{details}, + ephemeral_connection{ephemeral_connection} {} + +std::optional> Response::parse_text_error(const std::string& body) { + static const std::unordered_map> error_map = { + {"400 Bad Request", {400, false}}, + {"401 Unauthorized", {401, false}}, + {"403 Forbidden", {403, false}}, + {"404 Not Found", {404, false}}, + {"405 Method Not Allowed", {405, false}}, + {"406 Not Acceptable", {406, false}}, + {"408 Request Timeout", {408, false}}, + {"500 Internal Server Error", {500, false}}, + {"502 Bad Gateway", {502, false}}, + {"503 Service Unavailable", {503, false}}, + {"504 Gateway Timeout", {504, true}}, + }; + + for (const auto& [prefix, result] : error_map) + if (body.starts_with(prefix)) + return result; + + return std::nullopt; +} + +std::optional Response::find_uniform_batch_error(const std::string& body) { + try { + auto json = nlohmann::json::parse(body); + + // If it wasn't a batch response then just handle the non-batch status code + if (json.contains("results") && json["results"].is_array() && !json["results"].empty()) { + int16_t first_status_code = -1; + + for (const auto& result : json["results"]) { + if (!result.contains("code") || !result["code"].is_number()) + return std::nullopt; + + // If we got a success then we can just use the original status code + int16_t code = result["code"].get(); + if (code >= 200 && code <= 299) + return std::nullopt; + + if (first_status_code == -1) + first_status_code = code; + else if (first_status_code != code) + return std::nullopt; + } + + return first_status_code; + } + } catch (...) { /* Do nothing */ + } + + return std::nullopt; +} + +} // namespace session::network \ No newline at end of file diff --git a/src/network/snode_pool.cpp b/src/network/snode_pool.cpp new file mode 100644 index 00000000..44c3f140 --- /dev/null +++ b/src/network/snode_pool.cpp @@ -0,0 +1,958 @@ +#include "session/network/snode_pool.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "session/file.hpp" +#include "session/hash.hpp" +#include "session/random.hpp" + +using namespace oxen; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace std { + +template <> +struct hash { + size_t operator()(const oxen::quic::ipv4& ip) const noexcept { + return std::hash{}(ip.addr); + } +}; + +} // namespace std + +namespace session::network { + +namespace fs = std::filesystem; + +namespace { + inline auto cat = log::Cat("snode_pool"); +} // namespace + +SnodePool::SnodePool( + config::SnodePoolConfig config, + std::shared_ptr loop, + network_fetcher_t direct_fetcher) : + _config{config}, _loop{loop}, _direct_fetcher{std::move(direct_fetcher)} { + if (_config.cache_directory) { + std::string cache_file_name; + + switch (_config.netid) { + case opt::netid::Target::mainnet: cache_file_name = "snode_pool"; break; + case opt::netid::Target::testnet: cache_file_name = "snode_pool_testnet"; break; + case opt::netid::Target::devnet: + std::string seed_node_data; + + for (const auto& node : _config.seed_nodes) + node.to_disk(std::back_inserter(seed_node_data)); + + auto hash_bytes = session::hash::hash(32, session::to_span(seed_node_data)); + cache_file_name = "snode_pool_devnet_" + oxenc::to_hex(hash_bytes); + break; + } + + _snode_cache_file_path = *_config.cache_directory / cache_file_name; + _load_from_disk(); + _disk_write_thread = std::thread{&SnodePool::_disk_write_loop, this}; + } +} + +SnodePool::~SnodePool() { + if (_disk_write_thread.joinable()) { + { + std::unique_lock lock{_cache_mutex}; + _shut_down_disk_thread = true; + } + + _disk_write_cv.notify_one(); + _disk_write_thread.join(); + } +} + +// MARK: Disk I/O Functions + +void SnodePool::_load_from_disk() { + if (_snode_cache_file_path.empty()) { + log::error(cat, "Tried to load cache from disk without a cache file path."); + return; + } + + // Load the cache if present + try { + if (!fs::exists(_snode_cache_file_path)) { + log::info(cat, "No existing snode cache, will rebuild."); + return; + } + + auto ftime = fs::last_write_time(_snode_cache_file_path); + _last_snode_cache_update = + std::chrono::time_point_cast( + ftime - fs::file_time_type::clock::now() + + std::chrono::system_clock::now()); + + std::vector loaded_cache_data = read_whole_file(_snode_cache_file_path); + std::vector loaded_cache; + auto invalid_entries = 0; + + std::string_view data_view( + reinterpret_cast(loaded_cache_data.data()), loaded_cache_data.size()); + loaded_cache.reserve( + (data_view.size() / service_node_disk_format::MAX_LINE_SIZE) + 1); // +1 for safety + + size_t start = 0; + while (start < data_view.size()) { + // Find either \n or \r + size_t end = data_view.find_first_of("\n\r", start); + if (end == std::string_view::npos) + end = data_view.size(); + + if (end > start) { // Skip empty lines + std::string_view line = data_view.substr(start, end - start); + + try { + loaded_cache.push_back(service_node::from_disk(line)); + } catch (...) { + ++invalid_entries; + } + } + + // Skip past any line ending characters (\n, \r, or both in any order) + start = end; + while (start < data_view.size() && + (data_view[start] == '\n' || data_view[start] == '\r')) { + ++start; + } + } + + if (loaded_cache_data.size() > 0 && loaded_cache.size() == 0) + throw std::runtime_error{"Snode cache has invalid format"}; + + if (invalid_entries > 0) + log::warning(cat, "Skipped {} invalid entries in snode cache.", invalid_entries); + + std::shuffle(loaded_cache.begin(), loaded_cache.end(), csrng); + _snode_cache = std::move(loaded_cache); + _all_swarms = swarm::generate_swarms(_snode_cache); + + log::info( + cat, + "Loaded cache of {} snodes, {} swarms.", + _snode_cache.size(), + _all_swarms.size()); + } catch (const std::exception& e) { + log::error(cat, "Failed to load snode cache, will rebuild ({}).", e.what()); + + if (fs::exists(_snode_cache_file_path)) + fs::remove_all(_snode_cache_file_path); + } +} + +void SnodePool::_disk_write_loop() { + std::unique_lock lock{_cache_mutex}; + + while (!_shut_down_disk_thread) { + _disk_write_cv.wait(lock, [this] { + return _need_write || _need_clear_cache || _shut_down_disk_thread; + }); + + // Shutdown if needed + if (_shut_down_disk_thread) + break; + + // Clear cache if needed + if (_need_clear_cache) { + _snode_cache = {}; + _all_swarms = {}; + _swarm_cache = {}; + + auto path_to_clear = _snode_cache_file_path; + lock.unlock(); + try { + if (!path_to_clear.empty() && fs::exists(path_to_clear)) + fs::remove_all(path_to_clear); + log::info(cat, "Cleared snode cache from disk."); + } catch (const std::exception& e) { + log::error(cat, "Failed to clear snode cache file: {}", e.what()); + } + lock.lock(); + _need_clear_cache = false; + } + + if (_need_write) { + // Just in case + if (_snode_cache_file_path.empty()) { + _need_write = false; + continue; + } + + // Make a local copy so that we can release the lock and not + // worry about other threads wanting to change things + auto path_to_write = _snode_cache_file_path; + auto snode_cache_write = _snode_cache; + + lock.unlock(); + { + try { + if (snode_cache_write.empty()) + throw std::runtime_error{"cache was empty."}; + + // Create the cache directories if needed + fs::create_directories(path_to_write.parent_path()); + + // Save the snode pool to disk + auto tmp_path = path_to_write; + tmp_path += u8"_new"; + + { + std::string output_buffer; + output_buffer.reserve( + snode_cache_write.size() * service_node_disk_format::MAX_LINE_SIZE); + + for (const auto& snode : snode_cache_write) + snode.to_disk(std::back_inserter(output_buffer)); + + std::ofstream file(tmp_path, std::ios::binary); + file.write(output_buffer.data(), output_buffer.size()); + } + + fs::rename(tmp_path, path_to_write); + log::debug(cat, "Finished writing snode cache to disk."); + } catch (const std::exception& e) { + log::error(cat, "Failed to write snode cache: {}", e.what()); + } + } + lock.lock(); + _need_write = false; + } + } +} + +// MARK: Refresh Functions + +void SnodePool::_refresh_snode_cache(std::optional request_id_opt) { + const auto request_id = request_id_opt.value_or("RSC-" + random::random_base32(4)); + bool use_routed_fetcher = true; + uint8_t num_nodes_for_refresh = 0; + + { + std::unique_lock lock{_cache_mutex}; + + if (_suspended) { + log::info(cat, "Ignoring refresh as pool is suspended."); + return; + } + + // Only allow a single cache refresh at a time + if (_current_snode_cache_refresh_id) { + log::debug( + cat, + "[Request {}] Ignoring refresh snode cache attempt; a refresh is already in " + "progress ({}).", + request_id, + *_current_snode_cache_refresh_id); + return; + } + + log::info(cat, "[Request {}] Starting cache refresh.", request_id); + _current_snode_cache_refresh_id = request_id; + _snode_refresh_results.clear(); + _refresh_candidate_nodes.clear(); + + // We should only use the routed_fetcher if it exists, passes a connectivity check, and + // there are enough cached nodes + const auto cache_insufficient = + (_config.cache_num_nodes_to_use_for_refresh > 0 && + _snode_cache.size() < _config.cache_num_nodes_to_use_for_refresh); + use_routed_fetcher = + (cache_insufficient && _routed_fetcher && _routed_fetcher_connectivity_check && + (*_routed_fetcher_connectivity_check)()); + + // We should only refresh using seed nodes if using cached nodes is disabled, or there + // aren't enough cached nodes to refresh from + const auto use_seed_nodes = + (_config.cache_num_nodes_to_use_for_refresh == 0 || cache_insufficient); + + // Seed nodes are trusted so we only need to use a single node when refreshing from them + num_nodes_for_refresh = (use_seed_nodes ? 1 : _config.cache_num_nodes_to_use_for_refresh); + _refresh_candidate_nodes = (use_seed_nodes ? _config.seed_nodes : _snode_cache); + std::shuffle(_refresh_candidate_nodes.begin(), _refresh_candidate_nodes.end(), csrng); + + if (!use_routed_fetcher && use_seed_nodes) + log::debug( + cat, + "[Request {}] Refreshing using seed nodes{}.", + request_id, + (cache_insufficient ? " (cache is insufficient)" : "")); + else if (!use_routed_fetcher && !use_seed_nodes) + log::warning( + cat, + "[Request {}] {}, using direct fetcher to fetch from {} nodes for cache " + "refresh.", + request_id, + (!_routed_fetcher ? "No routed fetcher set" : "Routed fetcher not ready"), + num_nodes_for_refresh); + else if (use_routed_fetcher && use_seed_nodes) + log::debug( + cat, + "[Request {}] Refreshing using seed nodes (cache is insufficient).", + request_id); + else + log::debug( + cat, + "[Request {}] Refrshing via routed fetcher using {} nodes.", + request_id, + num_nodes_for_refresh); + + // If we (somehow) have no candidate nodes then error and reset the state so we can try + // again later + if (_refresh_candidate_nodes.empty()) { + log::critical( + cat, + "Cannot refresh cache: {}", + (use_seed_nodes ? "No seed nodes are configured!" + : "Found no nodes and decided not to use seed nodes!")); + _current_snode_cache_refresh_id.reset(); + return; + } + } + + // Kick off the concurrent requests (if there are any) + for (uint8_t i = 0; i < num_nodes_for_refresh; ++i) + _launch_next_refresh_request(request_id, !use_routed_fetcher, num_nodes_for_refresh); +} + +void SnodePool::_launch_next_refresh_request( + const std::string& request_id, + const bool use_direct_fetcher, + const uint8_t total_requests) { + service_node target_node; + session::network::SnodePool::network_fetcher_t fetcher_to_use; + bool use_legacy_endpoint = false; + + { + std::unique_lock lock{_cache_mutex}; + + if (!_current_snode_cache_refresh_id) + return; + + if (_refresh_candidate_nodes.empty()) { + // If we run out of candidate nodes then we should fail this refresh request and start a + // new one with a new id (the `_refresh_snode_cache` will decide which nodes and fetcher + // should be used) + _snode_cache_refresh_failure_count++; + auto delay = _config.retry_delay.exponential(_snode_cache_refresh_failure_count); + log::warning( + cat, + "[Request {}] Ran out of nodes for refresh, discarding partial results and " + "trying again in {}ms.", + request_id, + delay.count()); + _loop->call_later(delay, [weak_self = weak_from_this()] { + // We need to wait until after the `call_later` to reset the `refresh_id` (and clear + // previous results) as if we don't then additional refreshes could be triggered + // during the delay + auto self = weak_self.lock(); + if (!self) + return; + + { + std::unique_lock lock{self->_cache_mutex}; + self->_current_snode_cache_refresh_id.reset(); + self->_snode_refresh_results.clear(); + } + + self->_refresh_snode_cache(); + }); + return; + } + + target_node = _refresh_candidate_nodes.back(); + _refresh_candidate_nodes.pop_back(); + fetcher_to_use = (use_direct_fetcher ? _direct_fetcher : *_routed_fetcher); + use_legacy_endpoint = (!use_direct_fetcher && _config.cache_refresh_using_legacy_endpoint); + } + + // If we somehow got into '_launch_next_refresh_request' for a routed request then we need to + // make sure '_routed_fetcher' was set before we try to use it + if (!fetcher_to_use) { + log::critical(cat, "[Request {}] No fetcher available, aborting refresh.", request_id); + std::unique_lock lock{_cache_mutex}; + _current_snode_cache_refresh_id.reset(); + _refresh_candidate_nodes.clear(); + return; + } + + log::debug( + cat, + "[Request {}] Launching {} refresh request to {}", + request_id, + (use_direct_fetcher ? "direct" : "routed"), + target_node.to_string()); + const Request request = + [this, &request_id, &target_node, use_direct_fetcher, use_legacy_endpoint]() { + // A mandatory service node upgrade needs to go out to support calling + // `active_nodes_bin` via onion requests so if the `use_legacy_endpoint` setting is + // set then we should use the legacy endpoint to refresh the cache + if (use_legacy_endpoint) { + nlohmann::json body{ + {"endpoint", "get_service_nodes"}, + {"params", + {{"active_only", true}, + {"fields", + {{"pubkey_ed25519", true}, + {"public_ip", true}, + {"storage_port", true}, + {"storage_lmq_port", true}, + {"storage_server_version", true}, + {"swarm_id", true}}}}}, + }; + + return Request{ + request_id, + network_destination{target_node}, + std::string{"oxend_request"}, + to_vector(body.dump()), + RequestCategory::standard, + 10s, + std::nullopt, // overall_timeout + std::monostate{}, // details + true // ephemeral_connection + }; + } + + return Request{ + request_id, + network_destination{target_node}, + std::string{"active_nodes_bin"}, + std::nullopt, + RequestCategory::standard, + 10s, + std::nullopt, // overall_timeout + std::monostate{}, // details + true // ephemeral_connection + }; + }(); + + fetcher_to_use( + request, + [this, request_id, use_direct_fetcher, total_requests, use_legacy_endpoint]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + // This callback runs on the network loop so acquire a lock + std::unique_lock lock{_cache_mutex}; + + // If the refresh was cancelled or completed while we were in-flight, do nothing + if (!_current_snode_cache_refresh_id || + *_current_snode_cache_refresh_id != request_id) { + log::debug(cat, "[Request {}] Ignoring stale refresh response.", request_id); + return; + } + + std::vector result; + + try { + if (!success || timeout || !response) + throw std::runtime_error{response.value_or("Unknown error.")}; + + if (status_code < 200 || status_code > 299) + throw status_code_exception{ + status_code, + {content_type_plain_text}, + "Request failed with status code: {}, error: {}"_format( + status_code, response.value_or("Unknown error."))}; + + result.assign( + reinterpret_cast(response->data()), + reinterpret_cast( + response->data() + response->length())); + } catch (const std::exception& e) { + _snode_cache_refresh_failure_count++; + auto delay = + _config.retry_delay.exponential(_snode_cache_refresh_failure_count); + + log::warning( + cat, + "Failed to refresh cache from one node: {}. Trying another in {}ms.", + e.what(), + delay.count()); + _loop->call_later( + delay, + [weak_self = weak_from_this(), + request_id, + use_direct_fetcher, + total_requests] { + if (auto self = weak_self.lock()) + self->_retry_refresh_request( + request_id, use_direct_fetcher, total_requests); + }); + return; + } + + _snode_refresh_results.push_back(std::move(result)); + log::info( + cat, + "[Request {}] Received refresh response {}/{}.", + request_id, + _snode_refresh_results.size(), + total_requests); + + // If we've received all the results then we need to process them and complete the + // refresh + if (_snode_refresh_results.size() >= total_requests) { + auto final_results = std::move(_snode_refresh_results); + auto refresh_id = *_current_snode_cache_refresh_id; + lock.unlock(); // Unlock so `_on_refresh_complete` can get it's own lock + _on_refresh_complete( + refresh_id, + final_results, + use_direct_fetcher, + total_requests, + use_legacy_endpoint); + } + }); +} + +void SnodePool::_retry_refresh_request( + const std::string& request_id, + const bool use_direct_fetcher, + const uint8_t total_requests) { + _launch_next_refresh_request(request_id, use_direct_fetcher, total_requests); +} + +void SnodePool::_on_refresh_complete( + std::string refresh_id, + std::vector> raw_results, + const bool use_direct_fetcher, + const uint8_t total_requests, + const bool from_legacy_endpoint) { + log::info( + cat, + "[Request {}] Have {} responses, processing and finalizing cache refresh.", + refresh_id, + raw_results.size()); + + // Sort the vectors (so make it easier to find the intersection) + std::vector> processed_nodes; + processed_nodes.reserve(raw_results.size()); + for (size_t i = 0; i < raw_results.size(); ++i) { + try { + auto& nodes_bin = raw_results[i]; + std::pair, int> result; + auto& [nodes, invalid_count] = result; + + // Due to how onion requests work they need to return JSON data which means the data + // could be base64-encoded, so handle that case if needed + if (from_legacy_endpoint) { + nlohmann::json response_json = nlohmann::json::parse(to_string_view(nodes_bin)); + + if (!response_json.contains("result") || !response_json["result"].is_object()) + throw std::runtime_error{"JSON missing result field."}; + + nlohmann::json result_json = response_json["result"]; + if (!result_json.contains("service_node_states") || + !result_json["service_node_states"].is_array()) + throw std::runtime_error{"JSON missing service_node_states field."}; + + for (auto& snode : result_json["service_node_states"]) + try { + nodes.emplace_back(service_node::legacy_from_json(snode)); + } catch (...) { + invalid_count++; + } + } else if (!use_direct_fetcher && oxenc::is_base64(nodes_bin)) { + std::vector converted_nodes; + oxenc::from_base64( + nodes_bin.begin(), nodes_bin.end(), std::back_inserter(converted_nodes)); + result = service_node::process_snode_cache_bin(converted_nodes); + } else + result = service_node::process_snode_cache_bin(nodes_bin); + + log::info( + cat, + "[Request {}] Refresh response #{} included {} nodes, {} invalid.", + refresh_id, + (i + 1), + nodes.size(), + invalid_count); + std::stable_sort(nodes.begin(), nodes.end()); + processed_nodes.emplace_back(std::move(nodes)); + } catch (const std::exception& e) { + std::chrono::milliseconds delay; + + { + std::unique_lock lock{_cache_mutex}; + _snode_refresh_results.clear(); + _snode_cache_refresh_failure_count++; + delay = _config.retry_delay.exponential(_snode_cache_refresh_failure_count); + } + + log::error( + cat, + "[Request {}] Discarding responses and retrying after {}ms due to invalid " + "response #{}: {}.", + refresh_id, + delay.count(), + (i + 1), + e.what()); + _loop->call_later( + delay, + [weak_self = weak_from_this(), refresh_id, use_direct_fetcher, total_requests] { + if (auto self = weak_self.lock()) + for (uint8_t i = 0; i < total_requests; ++i) + self->_launch_next_refresh_request( + refresh_id, use_direct_fetcher, total_requests); + }); + return; + } + } + + auto nodes = processed_nodes[0]; + + // If we triggered multiple requests then get the intersection of all vectors + for (size_t i = 1; i < processed_nodes.size(); ++i) { + std::vector intersection; + std::set_intersection( + nodes.begin(), + nodes.end(), + processed_nodes[i].begin(), + processed_nodes[i].end(), + std::back_inserter(intersection)); + nodes = std::move(intersection); + } + + // Shuffle the nodes so we don't have a specific order + std::shuffle(nodes.begin(), nodes.end(), csrng); + log::info(cat, "[Request {}] Cache refresh complete with {} nodes.", refresh_id, nodes.size()); + + std::vector> after_refresh; + + { + std::unique_lock lock{_cache_mutex}; + + // Update the in-memory caches and, since the swarm cache could now be invalid, clear it and + // re-generate `_all_swarms` + _snode_cache = std::move(nodes); + _all_swarms = swarm::generate_swarms(_snode_cache); + _swarm_cache.clear(); + _last_snode_cache_update = std::chrono::system_clock::now(); + + // Reset all failure and refresh-in-progress state + _snode_failure_counts.clear(); + _current_snode_cache_refresh_id.reset(); + _snode_refresh_results.clear(); + _refresh_candidate_nodes.clear(); + _snode_cache_refresh_failure_count = 0; + + // Move any callbacks (so they can be called after the lock is freed) + after_refresh = std::move(_after_snode_cache_refresh); + + // Flag that we need to write the updated cache to disk + _need_write = true; + } + + _disk_write_cv.notify_one(); + + // Trigger any callbacks + if (!after_refresh.empty()) { + log::debug(cat, "Executing {} post-refresh callbacks.", after_refresh.size()); + + for (const auto& cb : after_refresh) { + try { + cb(); + } catch (const std::exception& e) { + log::error(cat, "Exception thrown in a post-refresh callback: {}", e.what()); + } + } + } +} + +// MARK: Public Functions + +void SnodePool::suspend() { + std::unique_lock lock{_cache_mutex}; + _suspended = true; + log::info(cat, "Suspended."); +} + +void SnodePool::resume() { + std::unique_lock lock{_cache_mutex}; + if (!_suspended) + return; + + _suspended = false; + log::info(cat, "Resumed."); +} + +void SnodePool::set_routed_fetcher( + network_fetcher_t routed_fetcher, fetcher_connectivity_check_t connectivity_check) { + std::unique_lock lock{_cache_mutex}; + _routed_fetcher = std::move(routed_fetcher); + _routed_fetcher_connectivity_check = std::move(connectivity_check); +} + +size_t SnodePool::size() { + std::lock_guard lock{_cache_mutex}; + return _snode_cache.size(); +} + +void SnodePool::clear_cache() { + { + std::lock_guard lock{_cache_mutex}; + _need_clear_cache = true; + } + _disk_write_cv.notify_one(); +} + +void SnodePool::record_node_failure(const service_node& node, bool permanent) { + record_node_failure(ed25519_pubkey::from_bytes(node.view_remote_key()), permanent); +} + +void SnodePool::record_node_failure(const ed25519_pubkey& key, bool permanent) { + std::lock_guard lock{_cache_mutex}; + _snode_failure_counts[key] = + (permanent ? _config.cache_node_failure_threshold : _snode_failure_counts[key] += 1); + log::trace( + cat, + "Recorded failure for node {}, total failures: {}", + key.hex(), + _snode_failure_counts[key]); +} + +uint16_t SnodePool::node_failure_count(const service_node& node) { + return node_failure_count(ed25519_pubkey::from_bytes(node.view_remote_key())); +} + +uint16_t SnodePool::node_failure_count(const ed25519_pubkey& key) { + std::lock_guard lock{_cache_mutex}; + if (_snode_failure_counts.contains(key)) + return _snode_failure_counts.at(key); + + return 0; +} + +void SnodePool::clear_node_failure_counts() { + std::lock_guard lock{_cache_mutex}; + _snode_failure_counts.clear(); +} + +void SnodePool::refresh_if_needed( + const std::vector& in_use_nodes, std::function on_refresh_complete) { + bool needs_to_start_refresh = false; + bool already_running = false; + std::optional delay; + + { + std::lock_guard lock{_cache_mutex}; + + if (_suspended) { + log::info(cat, "Ignoring refresh as pool is suspended."); + return; + } + + // Don't bother if we are alread doing a refresh + if (_current_snode_cache_refresh_id) + already_running = true; + else { + auto cache_lifetime = std::chrono::system_clock::now() - _last_snode_cache_update; + needs_to_start_refresh = + (_snode_cache.empty() || cache_lifetime > _config.cache_expiration); + + // Also need to refresh if there are not enough non-failed nodes in the cache + if (!needs_to_start_refresh) { + size_t usable_nodes_count = 0; + + std::unordered_set in_use_keys; + for (const auto& node : in_use_nodes) + in_use_keys.insert(ed25519_pubkey::from_bytes(node.view_remote_key())); + + for (const auto& node : _snode_cache) { + auto pubkey = ed25519_pubkey::from_bytes(node.view_remote_key()); + auto it = _snode_failure_counts.find(pubkey); + if (it != _snode_failure_counts.end() && + it->second >= _config.cache_node_failure_threshold) + continue; + + // If the caller considers the node as already in use then it wouldn't be + // considered usable so ignore it for the purpose of determining whether we have + // enough nodes to avoid a refresh + if (in_use_keys.count(pubkey)) + continue; + + usable_nodes_count++; + + if (usable_nodes_count >= _config.cache_min_size) + break; + } + + if (usable_nodes_count < _config.cache_min_size) + needs_to_start_refresh = true; + } + + if (needs_to_start_refresh && cache_lifetime < _config.cache_min_lifetime) + delay.emplace(std::chrono::duration_cast( + _config.cache_min_lifetime - cache_lifetime)); + } + + // If a refresh is needed or already running, queue the callback + if ((needs_to_start_refresh || already_running) && on_refresh_complete) + _after_snode_cache_refresh.push_back(std::move(on_refresh_complete)); + } + + // Kick off a refresh if needed (if none was needed then we should trigger the + // on_refresh_complete callback immediately) + if (needs_to_start_refresh) + if (delay) { + _loop->call_later(*delay, [weak_self = weak_from_this()] { + if (auto self = weak_self.lock()) + self->_refresh_snode_cache(); + }); + } else + _refresh_snode_cache(); + else if (!already_running && on_refresh_complete) + on_refresh_complete(); +} + +std::vector SnodePool::get_unused_nodes( + size_t count, const std::vector& exclude_nodes) { + // Kick of a cache refresh in the background if needed (call_soon to ensure it is scheduled + // after whatever called `get_unused_nodes` which may be something trying to make it's own + // request that we would want to run first) + _loop->call_soon([weak_self = weak_from_this(), exclude_nodes] { + if (auto self = weak_self.lock()) + self->refresh_if_needed(exclude_nodes); + }); + + // Then try to get the desired number of nodes from the current cache + std::vector result; + result.reserve(count); + + std::unordered_set exclusion_keys; + exclusion_keys.reserve(exclude_nodes.size()); + for (const auto& node : exclude_nodes) + exclusion_keys.insert(ed25519_pubkey::from_bytes(node.view_remote_key())); + + std::unordered_set used_subnets; + if (_config.enforce_subnet_diversity) + for (const auto& node : exclude_nodes) + used_subnets.insert(node.ip.to_base(24)); + + std::lock_guard lock{_cache_mutex}; + + if (_snode_cache.empty()) { + log::warning(cat, "Cannot get unused nodes: snode cache is empty."); + return result; + } + + // Pick a random starting index to start checking for unused nodes + size_t start_index = random::get_uniform_distribution(0, _snode_cache.size() - 1); + + for (size_t i = 0; i < _snode_cache.size(); ++i) { + if (result.size() >= count) + break; + + const size_t current_index = (start_index + i) % _snode_cache.size(); + const auto& node = _snode_cache[current_index]; + auto current_key = ed25519_pubkey::from_bytes(node.view_remote_key()); + + // Skip nodes explicitly excluded (needed in case subnet diversity is disabled) + if (exclusion_keys.count(current_key)) + continue; + + // Skip nodes with too many failures + auto it = _snode_failure_counts.find(current_key); + if (it != _snode_failure_counts.end() && it->second >= _config.cache_node_failure_threshold) + continue; + + // Skip nodes whos IP addresses are in the exclusion list + if (_config.enforce_subnet_diversity) { + auto subnet = node.ip.to_base(24); + if (used_subnets.count(subnet)) + continue; + } + + result.push_back(node); + + if (_config.enforce_subnet_diversity) + used_subnets.insert(node.ip.to_base(24)); + } + + if (result.size() < count) + log::warning(cat, "Could only find {}/{} suitable unused nodes.", result.size(), count); + + return result; +} + +void SnodePool::get_swarm( + session::network::x25519_pubkey swarm_pubkey, + std::function swarm)> callback) { + log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, swarm_pubkey.hex()); + + std::unique_lock lock{_cache_mutex}; + + // Check the in-memory swarm cache first + if (auto it = _swarm_cache.find(swarm_pubkey); it != _swarm_cache.end()) + return callback(it->second.first, it->second.second); + + // If we have no snode cache or no swarms then we need to rebuild the cache (which will also + // rebuild the swarms) and run this request again + if (_snode_cache.empty() || _all_swarms.empty()) { + log::debug(cat, "Cache is empty, deferring get_swarm until refresh is complete."); + + // Queue this entire function call to be re-run after the refresh. + _after_snode_cache_refresh.push_back([this, swarm_pubkey, cb = std::move(callback)]() { + this->get_swarm(swarm_pubkey, std::move(cb)); + }); + + // Check if a refresh is already running. If not, we need to start one. + bool needs_to_start_refresh = !_current_snode_cache_refresh_id; + + // We MUST unlock before calling '_refresh_snode_cache' as it acquires a lock itself + lock.unlock(); + + // Start the refresh if we're the ones who decided it was needed + if (needs_to_start_refresh) + _refresh_snode_cache(); + + return; + } + + // Copy the required data and release the lock so we don't hold it during calculation + auto all_swarms_copy = _all_swarms; + lock.unlock(); + + // Trigger a non-blocking background refresh if the data is stale + _loop->call_soon([weak_self = weak_from_this()] { + if (auto self = weak_self.lock()) + self->refresh_if_needed({}); + }); + + // Perform the swarm calculation using our local copy of the data + auto swarm = swarm::get_swarm(swarm_pubkey, all_swarms_copy); + log::info( + cat, + "Found swarm with {} nodes for {}, adding to cache.", + swarm.second.size(), + swarm_pubkey.hex()); + + // Update our in-memory cache (need to re-acquire the lock to do so) + { + std::lock_guard write_lock{_cache_mutex}; + _swarm_cache[swarm_pubkey] = swarm; + } + + // Trigger the callback with the swarm we found + callback(swarm.first, swarm.second); +} + +} // namespace session::network \ No newline at end of file diff --git a/src/network/swarm.cpp b/src/network/swarm.cpp new file mode 100644 index 00000000..3f54bebe --- /dev/null +++ b/src/network/swarm.cpp @@ -0,0 +1,70 @@ +#include "session/network/swarm.hpp" + +#include + +#include "session/network/service_node.hpp" +#include "session/network/session_network.hpp" + +namespace session::network::swarm { + +swarm_id_t pubkey_to_swarm_space(const session::network::x25519_pubkey& pk) { + swarm_id_t res = 0; + for (size_t i = 0; i < 4; i++) { + swarm_id_t buf; + std::memcpy(&buf, pk.data() + i * 8, 8); + res ^= buf; + } + oxenc::big_to_host_inplace(res); + + return res; +} + +std::vector>> generate_swarms( + const std::vector nodes) { + std::vector>> result; + std::unordered_map> _grouped_nodes; + + for (const auto& node : nodes) + _grouped_nodes[node.swarm_id].push_back(node); + + for (auto& [swarm_id, nodes] : _grouped_nodes) + result.emplace_back(swarm_id, std::move(nodes)); + + std::sort(result.begin(), result.end(), [](const auto& a, const auto& b) { + return a.first < b.first; + }); + return result; +} + +std::pair> get_swarm( + const session::network::x25519_pubkey swarm_pubkey, + const std::vector>> all_swarms) { + // If there is only a single swarm then return it + if (all_swarms.size() == 1) + return all_swarms.front(); + + // Generate a swarm_id for the pubkey + const swarm_id_t swarm_id = pubkey_to_swarm_space(swarm_pubkey); + + // Find the right boundary, i.e. first swarm with swarm_id >= res + auto right_it = std::lower_bound( + all_swarms.begin(), all_swarms.end(), swarm_id, [](const auto& s, uint64_t v) { + return s.first < v; + }); + + if (right_it == all_swarms.end()) + // res is > the top swarm_id, meaning it is big and in the wrapping space between last + // and first elements. + right_it = all_swarms.begin(); + + // Our "left" is the one just before that (with wraparound, if right is the first swarm) + auto left_it = std::prev(right_it == all_swarms.begin() ? all_swarms.end() : right_it); + + uint64_t dright = right_it->first - swarm_id; + uint64_t dleft = swarm_id - left_it->first; + auto swarm = &*(dright < dleft ? right_it : left_it); + + return *swarm; +} + +} // namespace session::network::swarm diff --git a/src/network/transport/quic_transport.cpp b/src/network/transport/quic_transport.cpp new file mode 100644 index 00000000..ececa392 --- /dev/null +++ b/src/network/transport/quic_transport.cpp @@ -0,0 +1,634 @@ +#include "session/network/transport/quic_transport.hpp" + +#include +#include +#include + +#include "session/ed25519.hpp" +#include "session/network/session_network_types.hpp" + +using namespace oxen; +using namespace session; +using namespace session::network; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace session::network { + +namespace { + inline auto cat = log::Cat("network"); +} + +// TODO: Should the `ALPN` be changed to an argument passed into the `connect` function? +constexpr auto ALPN = "oxenstorage"; + +QuicTransport::QuicTransport( + config::QuicTransportConfig config, std::shared_ptr loop) : + _config{std::move(config)}, _loop{loop} { + log::trace(cat, "[QuicTransport] Initializing."); + _recreate_endpoint(); +} + +QuicTransport::~QuicTransport() { + // Use 'call_get' to force this to be synchronous + if (_loop) + _loop->call_get([this] { _close_connections(); }); + log::debug(cat, "[QuicTransport] Destroyed."); +} + +// MARK: ITransport + +void QuicTransport::suspend() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + if (!_suspended) + return; + + _suspended = true; + _close_connections(); + log::info(cat, "[QuicTransport] Suspended."); + }); +} + +void QuicTransport::resume(bool automatically_reconnect) { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { + // Recreate the endpoint before updating the `_suspended` flag to avoid the chance that + // something will try to use it before we are ready + _recreate_endpoint(); + _suspended = false; + log::info(cat, "[QuicTransport] Resumed."); + }); +} + +void QuicTransport::close_connections() { + // Use 'call_get' to force this to be synchronous + _loop->call_get([this] { _close_connections(); }); +} + +void QuicTransport::set_node_failure_reporter(node_failure_reporter_t reporter) { + _loop->call([weak_self = weak_from_this(), r = std::move(reporter)] { + if (auto self = weak_self.lock()) + self->_report_node_failure.emplace(std::move(r)); + }); +} + +void QuicTransport::verify_connectivity( + service_node node, + std::chrono::milliseconds timeout, + const std::string& context_id, + std::function callback) { + // For Quic, a successful connection IS a successful ping so we can just check for an existing + // connection and, if one doesn't exist, try to establish one + _loop->call([weak_self = weak_from_this(), + node = std::move(node), + cb = std::move(callback), + context_id]() { + auto self = weak_self.lock(); + if (!self) + return; + + const auto pubkey_hex = oxenc::to_hex(node.view_remote_key()); + + // If we already have a connection we can stop here + if (self->_active_connection_ids.count(pubkey_hex) || + self->_pending_requests.count(pubkey_hex)) + return cb(true); + + self->_pending_verification_callbacks[pubkey_hex].push_back(std::move(cb)); + + // Only try to establish a connection if we are the first to ask for one + if (self->_pending_requests.count(pubkey_hex) == 0 && + self->_pending_verification_callbacks.at(pubkey_hex).size() == 1) + self->_establish_connection( + {node.view_remote_key(), node.host(), node.omq_port}, context_id); + }); +} + +void QuicTransport::add_failure_listener( + const ed25519_pubkey& pubkey, std::function listener) { + _loop->call([weak_self = weak_from_this(), + pk_hex = pubkey.hex(), + l = std::move(listener)]() mutable { + if (auto self = weak_self.lock()) + self->_failure_listeners[pk_hex].push_back(std::move(l)); + }); +} + +void QuicTransport::remove_failure_listeners(const ed25519_pubkey& pubkey) { + _loop->call([weak_self = weak_from_this(), pk_hex = pubkey.hex()] { + if (auto self = weak_self.lock()) + self->_failure_listeners.erase(pk_hex); + }); +} + +void QuicTransport::send_request(Request request, network_response_callback_t callback) { + log::trace(cat, "[QuicTransport] Dispatching request {} to loop.", request.request_id); + _loop->call([weak_self = weak_from_this(), req = std::move(request), cb = std::move(callback)] { + if (auto self = weak_self.lock()) + self->_send_request_internal(std::move(req), std::move(cb)); + }); +} + +// MARK: Internal Logic + +void QuicTransport::_recreate_endpoint() { + _endpoint = quic::Endpoint::endpoint( + *_loop, + quic::Address{"0.0.0.0", 0}, + quic::opt::alpns{ALPN}, + (_config.disable_mtu_discovery ? std::optional{} + : std::nullopt)); +} + +void QuicTransport::_close_connections() { + // Explicitly close all connections then reset the endpoint + if (_endpoint) + _endpoint->close_conns(); + _endpoint.reset(); + + // Cancel any pending verifications (they can't succeed once the connection is closed) + for (const auto& [pubkey, callbacks] : _pending_verification_callbacks) + for (const auto& callback : callbacks) + callback(false); + + // Cancel any pending requests (they can't succeed once the connection is closed) + for (const auto& [pubkey, pupkey_requests] : _pending_requests) + for (const auto& [info, callback] : pupkey_requests) + callback( + false, + false, + ERROR_NETWORK_SUSPENDED, + {content_type_plain_text}, + "QuickTransport is suspended."); + + // Clear all storage of requests, paths and connections so that we are in a fresh state on + // relaunch + _ephemeral_connection_ids.clear(); + _active_connection_ids.clear(); + _active_stream_ids.clear(); + _pending_verification_callbacks.clear(); + _pending_requests.clear(); + + _update_status(ConnectionStatus::disconnected); + log::info(cat, "[QuicTransport] Closed all connections."); +} + +void QuicTransport::_update_status(ConnectionStatus new_status) { + ConnectionStatus old_status = _status.load(); + if (old_status == new_status) + return; + + // Prevent swapping from "connected" back to "connecting" if a background connection is being + // established while we are already connected + if (old_status == ConnectionStatus::connected && new_status == ConnectionStatus::connecting) + return; + + // If we already tried to reconnect but failed, then we want to prevent swapping between + // "disconnected" and "connecting" + if (old_status == ConnectionStatus::disconnected && + new_status == ConnectionStatus::connecting && _has_attempted_reconnect) + return; + + _status.store(new_status); + + if (old_status == ConnectionStatus::disconnected && new_status == ConnectionStatus::connecting) + _has_attempted_reconnect = true; + + if (new_status == ConnectionStatus::connected) + _has_attempted_reconnect = false; + + if (on_status_changed) + on_status_changed(); +} + +void QuicTransport::_send_request_internal(Request request, network_response_callback_t callback) { + // If we are suspended then fail immediately + if (_suspended) + return callback( + false, + false, + ERROR_NETWORK_SUSPENDED, + {content_type_plain_text}, + "QuickTransport is suspended."); + + std::optional remote; + + std::visit( + [&remote, request_id = request.request_id](auto&& arg) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + log::trace( + cat, + "[QuicTransport Request {}]: Using pre-resolved RemoteAddress.", + request_id); + remote = arg; + } else if constexpr (std::is_same_v) { + log::trace( + cat, + "[QuicTransport Request {}]: Resolving service_node to RemoteAddress.", + request_id); + remote.emplace(arg.view_remote_key(), arg.host(), arg.omq_port); + } + }, + request.destination); + + if (!remote) { + log::critical( + cat, "[QuicTransport Request {}] Invalid destination type!", request.request_id); + return callback( + false, + false, + -1, + {content_type_plain_text}, + "Internal error: invalid destination for QuicTransport"); + } + + const auto remote_pubkey_hex = oxenc::to_hex(remote->view_remote_key()); + + // If an active connection exists then we can send the request over that + if (auto it = _active_connection_ids.find(remote_pubkey_hex); + it != _active_connection_ids.end()) { + log::trace( + cat, "[QuicTransport Request {}] Found active connection ID.", request.request_id); + _send_on_connection(it->second, std::move(request), std::move(callback)); + return; + } + + // If we should already be establishing a connection then we can just add this as a pending + // request and it'll be picked up once the connection is made + if (_pending_requests.count(remote_pubkey_hex)) { + log::debug( + cat, + "[QuicTransport Request {}] Connection to {} is pending, queueing request.", + request.request_id, + remote_pubkey_hex); + _pending_requests[remote_pubkey_hex].emplace_back(std::move(request), std::move(callback)); + return; + } + + // No connection exists so we need to start a new one and queue the request + log::info( + cat, + "[QuicTransport Request {}] No connection to {}, initiating new connection.", + request.request_id, + remote_pubkey_hex); + std::string initiating_req_id = request.request_id; + _pending_requests[remote_pubkey_hex].emplace_back(std::move(request), std::move(callback)); + _establish_connection(*remote, initiating_req_id); +} + +void QuicTransport::_establish_connection( + const oxen::quic::RemoteAddress& address, const std::string& initiating_req_id) { + const auto address_pubkey_hex = oxenc::to_hex(address.view_remote_key()); + auto conn_key_pair = ed25519::ed25519_key_pair(); + auto creds = quic::GNUTLSCreds::make_from_ed_seckey(to_string_view(conn_key_pair.second)); + + // If we are starting a connection attempt then transition to the "connecting" state + if (_status.load() == ConnectionStatus::unknown || + _status.load() == ConnectionStatus::disconnected) + _update_status(ConnectionStatus::connecting); + + log::debug( + cat, + "[QuicTransport Request {}] Establishing new connection to {}.", + initiating_req_id, + address_pubkey_hex); + try { + _endpoint->connect( + address, + creds, + oxen::quic::opt::handshake_timeout{_config.handshake_timeout}, + oxen::quic::opt::keep_alive{_config.keep_alive}, + [weak_self = weak_from_this(), address_pubkey_hex, initiating_req_id]( + oxen::quic::Connection& conn) { + auto self = weak_self.lock(); + if (!self) + return; + + log::info( + cat, + "[QuicTransport Request {}] Successfully established connection to " + "{}.", + initiating_req_id, + address_pubkey_hex); + + auto stream = conn.open_stream(); + auto conn_id = conn.reference_id(); + auto stream_id = stream->stream_id(); + auto verification_callbacks = + std::move(self->_pending_verification_callbacks[address_pubkey_hex]); + self->_pending_verification_callbacks.erase(address_pubkey_hex); + + auto requests_to_process = + std::move(self->_pending_requests[address_pubkey_hex]); + self->_pending_requests.erase(address_pubkey_hex); + + // Only persistent requests verify connectivity so if there is a + // verification callback then it should be persistent, otherwise if ANY of + // the requests require persistence then we should store the connection (if + // we don't store it then the connection will timeout and be closed) + bool is_persistent = !verification_callbacks.empty(); + if (!is_persistent) + is_persistent = std::any_of( + requests_to_process.begin(), + requests_to_process.end(), + [](const auto& req_pair) { + return !req_pair.first.ephemeral_connection; + }); + + if (is_persistent) { + self->_ephemeral_connection_ids.erase(conn_id); // Just in case + self->_active_connection_ids.insert_or_assign(address_pubkey_hex, conn_id); + } else + self->_ephemeral_connection_ids.insert(conn_id); + + self->_active_stream_ids.insert_or_assign(conn_id, stream_id); + + // We had a successful connection so update the status to connected + self->_update_status(ConnectionStatus::connected); + + for (const auto& pending_cb : verification_callbacks) + pending_cb(true); + + if (!requests_to_process.empty()) { + log::debug( + cat, + "[QuicTransport] Processing {} pending requests on new stream " + "{} " + "with " + "conn {}.", + requests_to_process.size(), + stream_id, + conn_id.to_string()); + + for (auto&& [req, cb] : std::move(requests_to_process)) + self->_send_on_connection(conn_id, std::move(req), std::move(cb)); + } + }, + [weak_self = weak_from_this(), address_pubkey_hex, initiating_req_id]( + oxen::quic::Connection& conn, uint64_t error_code) { + if (auto self = weak_self.lock()) + self->_fail_connection( + address_pubkey_hex, + initiating_req_id, + conn.reference_id(), + error_code, + std::nullopt); + }); + } catch (const std::exception& e) { + _fail_connection( + address_pubkey_hex, initiating_req_id, std::nullopt, std::nullopt, e.what()); + } +} + +void QuicTransport::_send_on_connection( + oxen::quic::ConnectionID conn_id, Request request, network_response_callback_t callback) { + // Try to retrieve the active connection first + auto conn = _endpoint->get_conn(conn_id); + if (!conn) { + log::warning( + cat, + "[QuicTransport Request {}] Attempted to send on a connection (ID {}) that no " + "longer exists.", + request.request_id, + conn_id.to_string()); + + // Since the connection is dead we should remove it from our active list and fail the + // request (the client can retry if they want) + for (auto it = _active_connection_ids.begin(); it != _active_connection_ids.end(); ++it) { + if (it->second == conn_id) { + _active_connection_ids.erase(it); + break; + } + } + _active_stream_ids.erase(conn_id); + + return callback( + false, + false, + -1, + {content_type_plain_text}, + "Connection died before request could be sent"); + } + + // Then try to get an active stream for this connection + auto stream_it = _active_stream_ids.find(conn_id); + if (stream_it == _active_stream_ids.end()) { + // Something has gone horribly wrong, lets close the connection and the client can retry + log::critical( + cat, + "[QuicTransport Request {}] No stream ID found for active connection {}, closing " + "connection.", + request.request_id, + conn_id.to_string()); + conn->close_connection(); + return callback( + false, + false, + -1, + {content_type_plain_text}, + "Internal error: Stream state missing for active connection"); + } + + auto stream_id = stream_it->second; + auto stream = conn->get_stream(stream_id); + if (!stream) { + // Similar to the above, if the stream is gone then the connection ir probably in a bad + // state so we should just close it + log::warning( + cat, + "[QuicTransport Request {}] Stream {} on connection {} has died, closing " + "connection.", + request.request_id, + stream_id, + conn_id.to_string()); + conn->close_connection(); + return callback( + false, false, -1, {content_type_plain_text}, "Connection stream was closed"); + } + + // If the request has already timedout at this point then just fail it immediately + auto timeout = request.time_remaining(); + if (timeout <= std::chrono::milliseconds::zero()) + return callback(false, true, 408, {content_type_plain_text}, "Request already timed out"); + + // We have a valid connection and stream so we can send the request + log::debug( + cat, + "[QuicTransport Request {}] Sending on stream {} with conn {}", + request.request_id, + stream_id, + conn_id.to_string()); + + std::span payload{}; + + if (request.body) + payload = to_span(*request.body); + + stream->command( + request.endpoint, + payload, + timeout, + [weak_self = weak_from_this(), + cb = std::move(callback), + conn_id, + stream_id, + req_id = request.request_id](quic::message resp) { + auto self = weak_self.lock(); + if (!self) + return; + + log::trace(cat, "[QuicTransport Request {}] Received response.", req_id); + + // If this connection was an ephemeral connection then we should close it (don't + // want to keep it alive longer than needed) + if (self->_ephemeral_connection_ids.count(conn_id)) { + self->_ephemeral_connection_ids.erase(conn_id); + self->_active_stream_ids.erase(conn_id); + + if (auto conn = self->_endpoint->get_conn(conn_id)) + conn->close_connection(); + } + + // Trigger the callback based on the response we got + if (resp.timed_out) { + log::debug(cat, "[QuicTransport Request {}] Timed out.", req_id); + return cb(false, true, 408, {content_type_plain_text}, "Request timed out"); + } + + if (resp.is_error()) { + auto final_timeout = resp.timed_out; + auto final_status_code = -1; + std::string err_body = + (resp.body().empty() ? "Unknown QUIC layer error" + : std::string{resp.body()}); + + // The response doesn't provide a status code but the body can include it, + // in which case we should try to extract it from the body so we can perform + // any status code related logic + if (auto result = Response::parse_text_error(err_body)) { + final_status_code = result->first; + final_timeout = result->second; + } + + log::debug( + cat, + "[QuicTransport Request {}] Failed with QUIC error: {}.", + req_id, + err_body); + return cb( + false, + final_timeout, + final_status_code, + {content_type_plain_text}, + err_body); + } + + log::debug( + cat, "[QuicTransport Request {}] Received raw success response.", req_id); + cb(true, false, 200, {}, std::string{resp.body()}); + }); +} + +void QuicTransport::_fail_connection( + const std::string& address_pubkey_hex, + const std::string& initiating_req_id, + std::optional conn_id, + std::optional error_code, + std::optional custom_error) { + if (error_code == NGTCP2_NO_ERROR) + log::info( + cat, + "[QuicTransport Request {}] Connection to {} closed gracefully.", + initiating_req_id, + address_pubkey_hex); + else if (error_code == static_cast(NGTCP2_ERR_HANDSHAKE_TIMEOUT)) { + log::warning( + cat, + "[QuicTransport Request {}] Handshake timeout when connecting to {}. " + "The node is likely unreachable.", + initiating_req_id, + address_pubkey_hex); + + // If the connection failed with a handshake timeout then the node is + // unreachable, either due to a device network issue or because the node is down + // so permanently fail the node so it won't be used for subsequent requests + // (until the next cache refresh) + if (_report_node_failure) + (*_report_node_failure)(ed25519_pubkey::from_hex(address_pubkey_hex), true); + } else if (error_code == quic::CONN_SEND_FAIL) { + log::warning( + cat, + "[QuicTransport Request {}] Connection to {} failed as we were unable to send it a " + "packet (error: {})", + initiating_req_id, + address_pubkey_hex, + *error_code); + } else if (error_code) + log::warning( + cat, + "[QuicTransport Request {}] Connection to {} failed or was closed with " + "error code: {}", + initiating_req_id, + address_pubkey_hex, + *error_code); + else + log::error( + cat, + "[QuicTransport Request {}] Connection to {} failed or was closed due to error: " + "{}.", + initiating_req_id, + address_pubkey_hex, + custom_error.value_or("Unknown error")); + + _active_connection_ids.erase(address_pubkey_hex); + + if (conn_id) { + _ephemeral_connection_ids.erase(*conn_id); + _active_stream_ids.erase(*conn_id); + } + + // Process any waiting verification requests + if (auto it = _pending_verification_callbacks.find(address_pubkey_hex); + it != _pending_verification_callbacks.end()) { + for (const auto& pending_cb : it->second) + pending_cb(false); + _pending_verification_callbacks.erase(it); + } + + // Fail all the pending requests for this connection + if (auto it = _pending_requests.find(address_pubkey_hex); it != _pending_requests.end()) { + auto to_fail = std::move(it->second); + _pending_requests.erase(it); + + std::string failure_reason = "Failed to establish connection to service node"; + if (error_code == static_cast(NGTCP2_ERR_HANDSHAKE_TIMEOUT)) + failure_reason += " (handshake timeout)"; + + log::error( + cat, + "[QuicTransport] Failing {} pending request(s) due to connection " + "failure.", + to_fail.size()); + + for (auto& [req, cb] : to_fail) + cb(false, false, -1, {content_type_plain_text}, failure_reason); + } + + // Notify any failure listeners that the connection has been closed + if (auto it = _failure_listeners.find(address_pubkey_hex); it != _failure_listeners.end()) { + auto to_fail = std::move(it->second); + _failure_listeners.erase(it); + + for (const auto& listener : it->second) + listener(); + } + + // If we have no longer have any active connections then we are disconnected + if (_active_connection_ids.empty()) + _update_status(ConnectionStatus::disconnected); +} + +} // namespace session::network diff --git a/src/onionreq/builder.cpp b/src/onionreq/builder.cpp index bea6688c..57b49534 100644 --- a/src/onionreq/builder.cpp +++ b/src/onionreq/builder.cpp @@ -22,22 +22,24 @@ #include #include "session/export.h" +#include "session/network/key_types.hpp" +#include "session/network/service_node.hpp" +#include "session/network/session_network_types.hpp" #include "session/onionreq/builder.h" #include "session/onionreq/hop_encryption.hpp" -#include "session/onionreq/key_types.hpp" -#include "session/session_network.hpp" #include "session/util.hpp" #include "session/xed25519.hpp" using namespace std::literals; using namespace oxen::log::literals; +using namespace session::network; namespace session::onionreq { namespace detail { - session::onionreq::x25519_pubkey pubkey_for_destination(network_destination destination) { + session::network::x25519_pubkey pubkey_for_destination(network_destination destination) { if (auto* dest = std::get_if(&destination)) - return compute_x25519_pubkey(dest->view_remote_key()); + return network::compute_x25519_pubkey(dest->view_remote_key()); if (auto* dest = std::get_if(&destination)) return dest->x25519_pubkey; @@ -66,34 +68,41 @@ EncryptType parse_enc_type(std::string_view enc_type) { Builder Builder::make( const network_destination& destination, + const std::string& endpoint, const std::vector& nodes, const EncryptType enc_type_) { - return Builder{destination, nodes, enc_type_}; + return Builder{destination, endpoint, nodes, enc_type_}; } Builder::Builder( const network_destination& destination, + const std::string& endpoint, const std::vector& nodes, const EncryptType enc_type_) : + endpoint_{endpoint}, enc_type{enc_type_}, - destination_x25519_public_key{detail::pubkey_for_destination(destination)} { + is_v4_request{std::holds_alternative(destination)}, + destination_x25519_public_key_{detail::pubkey_for_destination(destination)} { set_destination(destination); for (auto& n : nodes) add_hop(n.view_remote_key()); } void Builder::add_hop(std::span remote_key) { - hops_.push_back({ed25519_pubkey::from_bytes(remote_key), compute_x25519_pubkey(remote_key)}); + hops_.push_back( + {network::ed25519_pubkey::from_bytes(remote_key), + network::compute_x25519_pubkey(remote_key)}); } void Builder::set_destination(network_destination destination) { ed25519_public_key_.reset(); - if (auto* dest = std::get_if(&destination)) - ed25519_public_key_.emplace(ed25519_pubkey::from_bytes(dest->view_remote_key())); - else if (auto* dest = std::get_if(&destination)) { + if (auto* dest = std::get_if(&destination)) { + is_v4_request = false; + ed25519_public_key_.emplace(network::ed25519_pubkey::from_bytes(dest->view_remote_key())); + } else if (auto* dest = std::get_if(&destination)) { + is_v4_request = true; host_.emplace(dest->host); - endpoint_.emplace(dest->endpoint); method_.emplace(dest->method); // Remove the '://' from the protocol if it was given @@ -112,21 +121,29 @@ void Builder::set_destination(network_destination destination) { throw std::invalid_argument{"Invalid destination type."}; } -void Builder::set_destination_pubkey(session::onionreq::x25519_pubkey x25519_pubkey) { - destination_x25519_public_key.reset(); - destination_x25519_public_key.emplace(x25519_pubkey); -} - -void Builder::generate(network::request_info& info) { - info.body = build(_generate_payload(info.original_body)); +std::vector Builder::generate_onion_blob( + const std::optional>& plaintext_body) { + return build(_generate_payload(plaintext_body)); } std::vector Builder::_generate_payload( std::optional> body) const { // If we don't have the data required for a server request, then assume it's targeting a - // service node and, therefore, the `body` is the payload - if (!host_ || !endpoint_ || !protocol_ || !method_ || !destination_x25519_public_key) - return body.value_or(std::vector{}); + // service node which has a different structure (`method` is the endpoint and the body is + // `params`) + if (!host_ || !protocol_ || !method_ || !destination_x25519_public_key_) { + nlohmann::json params_json; + + if (body) + params_json = nlohmann::json::parse(*body); + else + params_json = nlohmann::json::object(); + + nlohmann::json wrapped_payload = {{"method", endpoint_}, {"params", params_json}}; + + std::string payload_str = wrapped_payload.dump(); + return {payload_str.begin(), payload_str.end()}; + } // Otherwise generate the payload for a server request auto headers_json = nlohmann::json::object(); @@ -141,13 +158,19 @@ std::vector Builder::_generate_payload( if (body && !headers_json.contains("Content-Type")) headers_json["Content-Type"] = "application/json"; + // When making a server request we need a leading forward-slash on the `endpoint` + auto final_endpoint = endpoint_; + + if (!final_endpoint.empty() && final_endpoint[0] != '/') + final_endpoint = '/' + final_endpoint; + // Structure the request information nlohmann::json request_info{ - {"method", *method_}, {"endpoint", *endpoint_}, {"headers", headers_json}}; + {"method", *method_}, {"endpoint", final_endpoint}, {"headers", headers_json}}; std::vector payload{request_info.dump()}; // If we were given a body, add it to the payload - if (body.has_value()) + if (body) payload.emplace_back(session::to_string(*body)); auto result = oxenc::bt_serialize(payload); @@ -195,8 +218,8 @@ std::vector Builder::build(std::vector payload) { // any onion encryption at all all the way back to the client. // Ephemeral keypair: - x25519_pubkey A; - x25519_seckey a; + network::x25519_pubkey A; + network::x25519_seckey a; nlohmann::json final_route; { @@ -205,7 +228,7 @@ std::vector Builder::build(std::vector payload) { // The data we send to the destination differs depending on whether the destination is a // server or a service node - if (host_ && protocol_ && destination_x25519_public_key) { + if (host_ && protocol_ && destination_x25519_public_key_) { final_route = { {"host", *host_}, {"target", "/oxen/v4/lsrpc"}, // All servers support V4 onion requests @@ -217,8 +240,8 @@ std::vector Builder::build(std::vector payload) { {"enc_type", to_string(enc_type)}, }; - blob = e.encrypt(enc_type, payload, *destination_x25519_public_key); - } else if (ed25519_public_key_ && destination_x25519_public_key) { + blob = e.encrypt(enc_type, payload, *destination_x25519_public_key_); + } else if (ed25519_public_key_ && destination_x25519_public_key_) { nlohmann::json control{{"headers", ""}}; final_route = { {"destination", ed25519_public_key_.value().hex()}, // Next hop's ed25519 key @@ -232,11 +255,11 @@ std::vector Builder::build(std::vector payload) { auto data = encode_size(payload.size()); data.insert(data.end(), payload.begin(), payload.end()); data.insert(data.end(), control_span.begin(), control_span.end()); - blob = e.encrypt(enc_type, data, *destination_x25519_public_key); + blob = e.encrypt(enc_type, data, *destination_x25519_public_key_); } else { - if (!destination_x25519_public_key.has_value()) + if (!destination_x25519_public_key_) throw std::runtime_error{"Destination not set: No destination x25519 public key"}; - if (!ed25519_public_key_.has_value()) + if (!ed25519_public_key_) throw std::runtime_error{"Destination not set: No destination ed25519 public key"}; throw std::runtime_error{ "Destination not set: " + host_.value_or("N/A") + ", " + @@ -334,45 +357,37 @@ LIBSESSION_C_API void onion_request_builder_set_snode_destination( const char* ed25519_pubkey) { assert(builder && ip && ed25519_pubkey); - std::array target_ip; - std::memcpy(target_ip.data(), ip, target_ip.size()); - - unbox(builder).set_destination(session::network::service_node( - oxenc::from_hex({ed25519_pubkey, 64}), - {0}, - session::network::INVALID_SWARM_ID, - "{}"_format(fmt::join(target_ip, ".")), - quic_port)); + std::vector pubkey; + pubkey.reserve(32); + oxenc::from_hex(ed25519_pubkey, ed25519_pubkey + 64, std::back_inserter(pubkey)); + + unbox(builder).set_destination(session::network::service_node{ + pubkey, + oxen::quic::ipv4{std::span(ip, 4)}, + 0, + quic_port, + {0, 0, 0}, + session::network::INVALID_SWARM_ID}); } LIBSESSION_C_API void onion_request_builder_set_server_destination( onion_request_builder_object* builder, const char* protocol, const char* host, - const char* endpoint, const char* method, uint16_t port, const char* x25519_pubkey) { - assert(builder && protocol && host && endpoint && protocol && x25519_pubkey); + assert(builder && protocol && host && protocol && x25519_pubkey); - unbox(builder).set_destination(session::onionreq::ServerDestination{ + unbox(builder).set_destination(session::network::ServerDestination{ protocol, host, - endpoint, - session::onionreq::x25519_pubkey::from_hex({x25519_pubkey, 64}), + session::network::x25519_pubkey::from_hex({x25519_pubkey, 64}), port, std::nullopt, method}); } -LIBSESSION_C_API void onion_request_builder_set_destination_pubkey( - onion_request_builder_object* builder, const char* x25519_pubkey) { - assert(builder && x25519_pubkey); - - unbox(builder).set_destination_pubkey( - session::onionreq::x25519_pubkey::from_hex({x25519_pubkey, 64})); -} - LIBSESSION_C_API void onion_request_builder_add_hop( onion_request_builder_object* builder, const char* ed25519_pubkey, @@ -380,8 +395,8 @@ LIBSESSION_C_API void onion_request_builder_add_hop( assert(builder && ed25519_pubkey && x25519_pubkey); unbox(builder).add_hop( - {session::onionreq::ed25519_pubkey::from_hex({ed25519_pubkey, 64}), - session::onionreq::x25519_pubkey::from_hex({x25519_pubkey, 64})}); + {session::network::ed25519_pubkey::from_hex({ed25519_pubkey, 64}), + session::network::x25519_pubkey::from_hex({x25519_pubkey, 64})}); } LIBSESSION_C_API bool onion_request_builder_build( diff --git a/src/onionreq/hop_encryption.cpp b/src/onionreq/hop_encryption.cpp index 611ce0bd..69a491c6 100644 --- a/src/onionreq/hop_encryption.cpp +++ b/src/onionreq/hop_encryption.cpp @@ -17,8 +17,8 @@ #include #include "session/export.h" +#include "session/network/key_types.hpp" #include "session/onionreq/builder.hpp" -#include "session/onionreq/key_types.hpp" #include "session/util.hpp" #include "session/xed25519.hpp" @@ -28,7 +28,7 @@ namespace { // Derive shared secret from our (ephemeral) `seckey` and the other party's `pubkey` std::array calculate_shared_secret( - const x25519_seckey& seckey, const x25519_pubkey& pubkey) { + const network::x25519_seckey& seckey, const network::x25519_pubkey& pubkey) { std::array secret; if (crypto_scalarmult(secret.data(), seckey.data(), pubkey.data()) != 0) throw std::runtime_error("Shared key derivation failed (crypto_scalarmult)"); @@ -38,7 +38,7 @@ namespace { constexpr std::string_view salt{"LOKI"}; std::array derive_symmetric_key( - const x25519_seckey& seckey, const x25519_pubkey& pubkey) { + const network::x25519_seckey& seckey, const network::x25519_pubkey& pubkey) { auto key = calculate_shared_secret(seckey, pubkey); auto usalt = to_span(salt); @@ -56,9 +56,9 @@ namespace { // could be used for AES-GCM as well, but would break backwards compatibility with existing // Session clients). std::array xchacha20_shared_key( - const x25519_pubkey& local_pub, - const x25519_seckey& local_sec, - const x25519_pubkey& remote_pub, + const network::x25519_pubkey& local_pub, + const network::x25519_seckey& local_sec, + const network::x25519_pubkey& remote_pub, bool local_first) { std::array key; static_assert(crypto_aead_xchacha20poly1305_ietf_KEYBYTES >= crypto_scalarmult_BYTES); @@ -90,7 +90,9 @@ bool HopEncryption::response_long_enough(EncryptType type, size_t response_size) } std::vector HopEncryption::encrypt( - EncryptType type, std::vector plaintext, const x25519_pubkey& pubkey) const { + EncryptType type, + std::vector plaintext, + const network::x25519_pubkey& pubkey) const { switch (type) { case EncryptType::xchacha20: return encrypt_xchacha20(plaintext, pubkey); case EncryptType::aes_gcm: return encrypt_aesgcm(plaintext, pubkey); @@ -101,7 +103,7 @@ std::vector HopEncryption::encrypt( std::vector HopEncryption::decrypt( EncryptType type, std::vector ciphertext, - const x25519_pubkey& pubkey) const { + const network::x25519_pubkey& pubkey) const { switch (type) { case EncryptType::xchacha20: return decrypt_xchacha20(ciphertext, pubkey); case EncryptType::aes_gcm: return decrypt_aesgcm(ciphertext, pubkey); @@ -110,7 +112,7 @@ std::vector HopEncryption::decrypt( } std::vector HopEncryption::encrypt_aesgcm( - std::vector plaintext, const x25519_pubkey& pubKey) const { + std::vector plaintext, const network::x25519_pubkey& pubKey) const { auto key = derive_symmetric_key(private_key_, pubKey); // Initialise cipher context with the key @@ -141,7 +143,7 @@ std::vector HopEncryption::encrypt_aesgcm( } std::vector HopEncryption::decrypt_aesgcm( - std::vector ciphertext_, const x25519_pubkey& pubKey) const { + std::vector ciphertext_, const network::x25519_pubkey& pubKey) const { std::span ciphertext = to_span(ciphertext_); if (!response_long_enough(EncryptType::aes_gcm, ciphertext_.size())) @@ -176,7 +178,7 @@ std::vector HopEncryption::decrypt_aesgcm( } std::vector HopEncryption::encrypt_xchacha20( - std::vector plaintext, const x25519_pubkey& pubKey) const { + std::vector plaintext, const network::x25519_pubkey& pubKey) const { std::vector ciphertext; ciphertext.resize( @@ -208,7 +210,7 @@ std::vector HopEncryption::encrypt_xchacha20( } std::vector HopEncryption::decrypt_xchacha20( - std::vector ciphertext_, const x25519_pubkey& pubKey) const { + std::vector ciphertext_, const network::x25519_pubkey& pubKey) const { std::span ciphertext = to_span(ciphertext_); // Extract nonce from the beginning of the ciphertext: diff --git a/src/onionreq/parser.cpp b/src/onionreq/parser.cpp index 766150db..e83640d6 100644 --- a/src/onionreq/parser.cpp +++ b/src/onionreq/parser.cpp @@ -13,7 +13,8 @@ OnionReqParser::OnionReqParser( std::span x25519_sk, std::span req, size_t max_size) : - keys{x25519_pubkey::from_bytes(x25519_pk), x25519_seckey::from_bytes(x25519_sk)}, + keys{network::x25519_pubkey::from_bytes(x25519_pk), + network::x25519_seckey::from_bytes(x25519_sk)}, enc{keys.second, keys.first} { if (sodium_init() == -1) throw std::runtime_error{"Failed to initialize libsodium!"}; @@ -35,7 +36,7 @@ OnionReqParser::OnionReqParser( // else leave it at the backwards-compat AES-GCM default if (auto itr = metadata.find("ephemeral_key"); itr != metadata.end()) - remote_pk = parse_x25519_pubkey(itr->get()); + remote_pk = network::parse_x25519_pubkey(itr->get()); else throw std::invalid_argument{"metadata does not have 'ephemeral_key' entry"}; diff --git a/src/onionreq/response_parser.cpp b/src/onionreq/response_parser.cpp index 64c0d3b1..89230aee 100644 --- a/src/onionreq/response_parser.cpp +++ b/src/onionreq/response_parser.cpp @@ -1,26 +1,33 @@ #include "session/onionreq/response_parser.hpp" +#include #include #include #include #include "session/export.h" +#include "session/network/service_node.hpp" #include "session/onionreq/builder.h" #include "session/onionreq/builder.hpp" #include "session/onionreq/hop_encryption.hpp" +using namespace session; + namespace session::onionreq { ResponseParser::ResponseParser(session::onionreq::Builder builder) { - if (!builder.destination_x25519_public_key.has_value()) + auto dest_x25519_pubkey = builder.get_destination_x25519_public_key(); + + if (!dest_x25519_pubkey) throw std::runtime_error{"Builder does not contain destination x25519 public key"}; - if (!builder.final_hop_x25519_keypair.has_value()) + if (!builder.final_hop_x25519_keypair) throw std::runtime_error{"Builder does not contain final keypair"}; enc_type_ = builder.enc_type; - destination_x25519_public_key_ = builder.destination_x25519_public_key.value(); + destination_x25519_public_key_ = *dest_x25519_pubkey; x25519_keypair_ = builder.final_hop_x25519_keypair.value(); + v4_request_ = builder.is_v4_request; } bool ResponseParser::response_long_enough(EncryptType enc_type, size_t response_size) { @@ -50,6 +57,100 @@ std::vector ResponseParser::decrypt(std::vector ci } } +DecryptedResponse ResponseParser::decrypted_response(const std::string& encrypted_response) { + // Ensure the response is long enough to be processed, if not then handle it as an error + if (!response_long_enough(enc_type_, encrypted_response.size())) + throw std::runtime_error{ + "Response is too short to be an onion request response: " + encrypted_response}; + + if (v4_request_) + return _decrypt_v4_response(encrypted_response); + else + return _decrypt_v3_response(encrypted_response); +} + +DecryptedResponse ResponseParser::_decrypt_v3_response(const std::string& response) { + std::string base64_iv_and_ciphertext; + try { + nlohmann::json response_json = nlohmann::json::parse(response); + + if (!response_json.contains("result") || !response_json["result"].is_string()) + throw std::runtime_error{"JSON missing result field."}; + + base64_iv_and_ciphertext = response_json["result"].get(); + } catch (...) { + base64_iv_and_ciphertext = response; + } + + if (!oxenc::is_base64(base64_iv_and_ciphertext)) + throw std::runtime_error{"Invalid base64 encoded IV and ciphertext."}; + + std::vector iv_and_ciphertext; + oxenc::from_base64( + base64_iv_and_ciphertext.begin(), + base64_iv_and_ciphertext.end(), + std::back_inserter(iv_and_ciphertext)); + auto result = decrypt(iv_and_ciphertext); + auto result_json = nlohmann::json::parse(result); + int16_t status_code; + std::vector> headers; + std::string body; + + if (result_json.contains("status_code") && result_json["status_code"].is_number()) + status_code = result_json["status_code"].get(); + else if (result_json.contains("status") && result_json["status"].is_number()) + status_code = result_json["status"].get(); + else + throw std::runtime_error{"Invalid JSON response, missing required status_code field."}; + + if (result_json.contains("headers")) { + auto header_vals = result_json["headers"]; + + for (auto it = header_vals.begin(); it != header_vals.end(); ++it) + headers.emplace_back(it.key(), it.value()); + } + + if (result_json.contains("body") && result_json["body"].is_string()) + body = result_json["body"].get(); + else + body = result_json.dump(); + + return {status_code, headers, body}; +} + +DecryptedResponse ResponseParser::_decrypt_v4_response(const std::string& response) { + auto response_data = to_vector(response); + auto result = decrypt(response_data); + + // Process the bencoded response + oxenc::bt_list_consumer result_bencode{to_span(result)}; + + if (result_bencode.is_finished() || !result_bencode.is_string()) + throw std::runtime_error{"Invalid bencoded response"}; + + auto response_info_string = result_bencode.consume_string(); + int16_t status_code; + std::vector> headers; + nlohmann::json response_info_json = nlohmann::json::parse(response_info_string); + + if (response_info_json.contains("code") && response_info_json["code"].is_number()) + status_code = response_info_json["code"].get(); + else + throw std::runtime_error{"Invalid JSON response, missing required code field."}; + + if (response_info_json.contains("headers")) { + auto header_vals = response_info_json["headers"]; + + for (auto it = header_vals.begin(); it != header_vals.end(); ++it) + headers.emplace_back(it.key(), it.value()); + } + + if (result_bencode.is_finished()) + return {status_code, headers, std::nullopt}; + + return {status_code, headers, result_bencode.consume_string()}; +} + } // namespace session::onionreq extern "C" { @@ -83,8 +184,8 @@ LIBSESSION_C_API bool onion_request_decrypt( } session::onionreq::HopEncryption d{ - session::onionreq::x25519_seckey::from_bytes({final_x25519_seckey, 32}), - session::onionreq::x25519_pubkey::from_bytes({final_x25519_pubkey, 32}), + session::network::x25519_seckey::from_bytes({final_x25519_seckey, 32}), + session::network::x25519_pubkey::from_bytes({final_x25519_pubkey, 32}), false}; std::vector result; @@ -99,13 +200,13 @@ LIBSESSION_C_API bool onion_request_decrypt( result = d.decrypt( enc_type, ciphertext, - session::onionreq::x25519_pubkey::from_bytes({destination_x25519_pubkey, 32})); + session::network::x25519_pubkey::from_bytes({destination_x25519_pubkey, 32})); } catch (...) { if (enc_type == session::onionreq::EncryptType::xchacha20) result = d.decrypt( session::onionreq::EncryptType::aes_gcm, ciphertext, - session::onionreq::x25519_pubkey::from_bytes( + session::network::x25519_pubkey::from_bytes( {destination_x25519_pubkey, 32})); else return false; diff --git a/src/session_network.cpp b/src/session_network.cpp deleted file mode 100644 index eaa85de3..00000000 --- a/src/session_network.cpp +++ /dev/null @@ -1,3310 +0,0 @@ -#include "session/session_network.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "session/blinding.hpp" -#include "session/ed25519.hpp" -#include "session/export.h" -#include "session/file.hpp" -#include "session/onionreq/builder.h" -#include "session/onionreq/builder.hpp" -#include "session/onionreq/key_types.hpp" -#include "session/onionreq/response_parser.hpp" -#include "session/session_network.h" -#include "session/util.hpp" - -using namespace oxen; -using namespace session::onionreq; -using namespace std::literals; -using namespace oxen::log::literals; - -namespace session::network { - -namespace { - - inline auto cat = log::Cat("network"); - - class load_cache_exception : public std::runtime_error { - public: - load_cache_exception(std::string message) : std::runtime_error(message) {} - }; - class status_code_exception : public std::runtime_error { - public: - int16_t status_code; - std::vector> headers; - - status_code_exception( - int16_t status_code, - std::vector> headers, - std::string message) : - std::runtime_error(message), status_code{status_code}, headers{headers} {} - }; - - constexpr int16_t error_network_suspended = -10001; - constexpr int16_t error_building_onion_request = -10002; - constexpr int16_t error_path_build_timeout = -10003; - - const std::pair content_type_plain_text = { - "Content-Type", "text/plain; charset=UTF-8"}; - const std::pair content_type_json = { - "Content-Type", "application/json"}; - - // The amount of time the snode cache can be used before it needs to be refreshed/ - constexpr auto snode_cache_expiration_duration = 2h; - - // The smallest size the snode cache can get to before we need to fetch more. - constexpr size_t min_snode_cache_count = 12; - - // The number of snodes to use to refresh the cache. - constexpr int num_snodes_to_refresh_cache_from = 3; - - // The number of snodes (including the guard snode) in a path. - constexpr uint8_t path_size = 3; - - // The number of times a path can fail before it's replaced. - constexpr uint16_t path_failure_threshold = 3; - - // The number of times a snode can fail before it's replaced. - constexpr uint16_t snode_failure_threshold = 3; - - // The frequency to check if queued requests have timed out due to a pending path build - constexpr auto queued_request_path_build_timeout_frequency = 250ms; - - const fs::path default_cache_path{u8"."}, file_testnet{u8"testnet"}, - file_snode_pool{u8"snode_pool"}; - const std::vector legacy_files{ - u8"snode_pool_updated", u8"swarm", u8"snode_failure_counts"}; - - constexpr auto node_not_found_prefix = "502 Bad Gateway\n\nNext node not found: "sv; - constexpr auto node_not_found_prefix_no_status = "Next node not found: "sv; - constexpr auto ALPN = "oxenstorage"; - constexpr auto ONION = "onion_req"; - - enum class PathSelectionBehaviour { - random, - new_or_least_busy, - }; - - std::string path_type_name(PathType path_type, bool single_path_mode) { - if (single_path_mode) - return "single_path"; - - switch (path_type) { - case PathType::standard: return "standard"; - case PathType::upload: return "upload"; - case PathType::download: return "download"; - } - return "standard"; // Default - } - - // The mininum number of paths we want to maintain - uint8_t min_path_count(PathType path_type, bool single_path_mode) { - if (single_path_mode) - return 1; - - switch (path_type) { - case PathType::standard: return 2; - case PathType::upload: return 2; - case PathType::download: return 2; - } - return 2; // Default - } - - PathSelectionBehaviour path_selection_behaviour(PathType path_type) { - switch (path_type) { - case PathType::standard: return PathSelectionBehaviour::random; - case PathType::upload: return PathSelectionBehaviour::new_or_least_busy; - case PathType::download: return PathSelectionBehaviour::new_or_least_busy; - } - return PathSelectionBehaviour::random; // Default - } - - /// Converts a string such as "1.2.3" to a vector of ints {1,2,3}. Throws if something - /// in/around the .'s isn't parseable as an integer. - std::vector parse_version(std::string_view vers, bool trim_trailing_zero = true) { - auto v_s = session::split(vers, "."); - std::vector result; - for (const auto& piece : v_s) - if (!quic::parse_int(piece, result.emplace_back())) - throw std::invalid_argument{"Invalid version"}; - - // Remove any trailing `0` values (but ensure we at least end up with a "0" version) - if (trim_trailing_zero) - while (result.size() > 1 && result.back() == 0) - result.pop_back(); - - return result; - } - - service_node node_from_json(nlohmann::json json) { - auto pk_ed = json["pubkey_ed25519"].get(); - if (pk_ed.size() != 64 || !oxenc::is_hex(pk_ed)) - throw std::invalid_argument{ - "Invalid service node json: pubkey_ed25519 is not a valid, hex pubkey"}; - - // When parsing a node from JSON it'll generally be from the 'get_swarm` endpoint or a 421 - // error neither of which contain the `storage_server_version` - luckily we don't need the - // version for these two cases so can just default it to `0` - std::vector storage_server_version = {0}; - if (json.contains("storage_server_version")) { - if (json["storage_server_version"].is_array()) { - if (json["storage_server_version"].size() > 0) { - // Convert the version to a string and parse it back into a version code to - // ensure the version formats remain consistent throughout - storage_server_version = json["storage_server_version"].get>(); - storage_server_version = - parse_version("{}"_format(fmt::join(storage_server_version, "."))); - } - } else - storage_server_version = - parse_version(json["storage_server_version"].get()); - } - - std::string ip; - if (json.contains("public_ip")) - ip = json["public_ip"].get(); - else - ip = json["ip"].get(); - - if (ip == "0.0.0.0") - throw std::runtime_error{"Invalid IP address"}; - - uint16_t port; - if (json.contains("storage_lmq_port")) - port = json["storage_lmq_port"].get(); - else - port = json["port_omq"].get(); - - if (port == 0) - throw std::runtime_error{"Invalid lmq port"}; - - swarm_id_t swarm_id = INVALID_SWARM_ID; - if (json.contains("swarm_id")) - swarm_id = json["swarm_id"].get(); - - return {oxenc::from_hex(pk_ed), storage_server_version, swarm_id, ip, port}; - } - - service_node node_from_disk(std::string_view str, bool can_ignore_version = false) { - // Format is "{ip}|{port}|{version}|{ed_pubkey}|{swarm_id}" - auto parts = split(str, "|"); - if (parts.size() != 5) - throw std::invalid_argument("Invalid service node serialisation: {}"_format(str)); - if (parts[3].size() != 64 || !oxenc::is_hex(parts[3])) - throw std::invalid_argument{ - "Invalid service node serialisation: pubkey is not hex or has wrong size"}; - - uint16_t port; - if (!quic::parse_int(parts[1], port)) - throw std::invalid_argument{"Invalid service node serialization: invalid port"}; - - std::vector storage_server_version = parse_version(parts[2]); - if (!can_ignore_version && storage_server_version == std::vector{0}) - throw std::invalid_argument{"Invalid service node serialization: invalid version"}; - - swarm_id_t swarm_id = INVALID_SWARM_ID; - quic::parse_int(parts[4], swarm_id); - - return { - oxenc::from_hex(parts[3]), // ed25519_pubkey - storage_server_version, // storage_server_version - swarm_id, // swarm_id - std::string(parts[0]), // ip - port, // port - }; - } - - const std::vector seed_nodes_testnet{ - node_from_disk("95.216.33.113|35400|2.8.0|" - "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9|"sv)}; - const std::vector seed_nodes_mainnet{ - node_from_disk("95.216.33.113|20200|2.8.0|" - "1f000f09a7b07828dcb72af7cd16857050c10c02bd58afb0e38111fb6cda1fef|"sv), - node_from_disk("37.27.236.229|20201|2.8.0|" - "1f101f0acee4db6f31aaa8b4df134e85ca8a4878efaef7f971e88ab144c1a7ce|"sv), - node_from_disk("172.96.140.124|20202|2.8.0|" - "1f202f00f4d2d4acc01e20773999a291cf3e3136c325474d159814e06199919f|"sv), - node_from_disk("208.73.207.54|20203|2.8.0|" - "1f303f1d7523c46fa5398826740d13282d26b5de90fbae5749442f66afb6d78b|"sv), - node_from_disk("104.194.8.115|20204|2.8.0|" - "1f604f1c858a121a681d8f9b470ef72e6946ee1b9c5ad15a35e16b50c28db7b0|"sv)}; - constexpr auto file_server = "filev2.getsession.org"sv; - constexpr auto file_server_pubkey = - "da21e1d886c6fbaea313f75298bd64aab03a97ce985b46bb2dad9f2089c8ee59"sv; - - std::string node_to_disk(service_node node) { - // Format is "{ip}|{port}|{version}|{ed_pubkey}|{swarm_id}" - auto ed25519_pubkey_hex = oxenc::to_hex(node.view_remote_key()); - - return fmt::format( - "{}|{}|{}|{}|{}", - node.host(), - node.port(), - "{}"_format(fmt::join(node.storage_server_version, ".")), - ed25519_pubkey_hex, - node.swarm_id); - } - - session::onionreq::x25519_pubkey compute_xpk(std::span ed25519_pk) { - std::array xpk; - if (0 != crypto_sign_ed25519_pk_to_curve25519(xpk.data(), ed25519_pk.data())) - throw std::runtime_error{ - "An error occured while attempting to convert Ed25519 pubkey to X25519; " - "is the pubkey valid?"}; - return session::onionreq::x25519_pubkey::from_bytes({xpk.data(), 32}); - } - - std::string consume_string(oxenc::bt_dict_consumer dict, std::string_view key) { - if (!dict.skip_until(key)) - throw std::invalid_argument{ - "Unable to find entry in dict for key '" + std::string(key) + "'"}; - return dict.consume_string(); - } - - template - auto consume_integer(oxenc::bt_dict_consumer dict, std::string_view key) { - if (!dict.skip_until(key)) - throw std::invalid_argument{ - "Unable to find entry in dict for key '" + std::string(key) + "'"}; - return dict.next_integer().second; - } -} // namespace - -namespace detail { - swarm_id_t pubkey_to_swarm_space(const session::onionreq::x25519_pubkey& pk) { - swarm_id_t res = 0; - for (size_t i = 0; i < 4; i++) { - swarm_id_t buf; - std::memcpy(&buf, pk.data() + i * 8, 8); - res ^= buf; - } - oxenc::big_to_host_inplace(res); - - return res; - } - - std::vector>> generate_swarms( - std::vector nodes) { - std::vector>> result; - std::unordered_map> _grouped_nodes; - - for (const auto& node : nodes) - _grouped_nodes[node.swarm_id].push_back(node); - - for (auto& [swarm_id, nodes] : _grouped_nodes) - result.emplace_back(swarm_id, std::move(nodes)); - - std::sort(result.begin(), result.end(), [](const auto& a, const auto& b) { - return a.first < b.first; - }); - return result; - } - - std::optional node_for_destination(network_destination destination) { - if (auto* dest = std::get_if(&destination)) - return *dest; - - return std::nullopt; - } - - nlohmann::json get_service_nodes_params(std::optional limit) { - nlohmann::json params{ - {"active_only", true}, - {"fields", - {{"public_ip", true}, - {"pubkey_ed25519", true}, - {"storage_lmq_port", true}, - {"storage_server_version", true}, - {"swarm_id", true}}}}; - - if (limit) - params["limit"] = *limit; - - return params; - } - - std::vector process_get_service_nodes_response( - oxenc::bt_list_consumer result_bencode) { - std::vector result; - result_bencode.skip_value(); // Skip the status code (already validated) - auto response_dict = result_bencode.consume_dict_consumer(); - response_dict.skip_until("result"); - - auto result_dict = response_dict.consume_dict_consumer(); - result_dict.skip_until("service_node_states"); - - // Process the node list - auto node = result_dict.consume_list_consumer(); - - while (!node.is_finished()) { - try { - auto node_consumer = node.consume_dict_consumer(); - auto pubkey_ed25519 = - oxenc::from_hex(consume_string(node_consumer, "pubkey_ed25519")); - auto public_ip = consume_string(node_consumer, "public_ip"); - auto storage_lmq_port = - consume_integer(node_consumer, "storage_lmq_port"); - - if (public_ip == "0.0.0.0") - throw std::runtime_error{"Invalid IP address"}; - - if (storage_lmq_port == 0) - throw std::runtime_error{"Invalid lmq port"}; - - std::vector storage_server_version; - node_consumer.skip_until("storage_server_version"); - auto version_consumer = node_consumer.consume_list_consumer(); - auto swarm_id = consume_integer(node_consumer, "swarm_id"); - - while (!version_consumer.is_finished()) { - storage_server_version.emplace_back(version_consumer.consume_integer()); - } - - result.emplace_back( - pubkey_ed25519, - storage_server_version, - swarm_id, - public_ip, - storage_lmq_port); - } catch (const std::exception& e) { - log::warning(cat, "Ignoring invalid snode: {}.", e.what()); - } - } - - return result; - } - - std::vector process_get_service_nodes_response(nlohmann::json response_json) { - if (!response_json.contains("result") || !response_json["result"].is_object()) - throw std::runtime_error{"JSON missing result field."}; - - nlohmann::json result_json = response_json["result"]; - if (!result_json.contains("service_node_states") || - !result_json["service_node_states"].is_array()) - throw std::runtime_error{"JSON missing service_node_states field."}; - - std::vector result; - for (auto& snode : result_json["service_node_states"]) - try { - result.emplace_back(node_from_json(snode)); - } catch (const std::exception& e) { - log::warning(cat, "Ignoring invalid snode: {}.", e.what()); - } - - return result; - } - - void log_retry_result_if_needed(request_info info, bool single_path_mode) { - if (!info.retry_reason) - return; - - // For debugging purposes if the error was a redirect retry then - // we want to log that the retry was successful as this will - // help identify how often we are receiving incorrect errors - auto reason = "unknown retry"; - - switch (*info.retry_reason) { - case request_info::RetryReason::none: reason = "unknown retry"; break; - case request_info::RetryReason::redirect: reason = "421 retry"; break; - case request_info::RetryReason::decryption_failure: reason = "decryption error"; break; - case request_info::RetryReason::redirect_swarm_refresh: - reason = "421 swarm refresh retry"; - break; - } - - log::info( - cat, - "Received valid response after {} in request {} for {}.", - reason, - info.request_id, - path_type_name(info.path_type, single_path_mode)); - } - - std::vector convert_service_nodes( - std::vector nodes) { - std::vector converted_nodes; - for (auto& node : nodes) { - auto ed25519_pubkey_hex = oxenc::to_hex(node.view_remote_key()); - auto ipv4 = node.to_ipv4(); - network_service_node converted_node; - converted_node.ip[0] = (ipv4.addr >> 24) & 0xFF; - converted_node.ip[1] = (ipv4.addr >> 16) & 0xFF; - converted_node.ip[2] = (ipv4.addr >> 8) & 0xFF; - converted_node.ip[3] = ipv4.addr & 0xFF; - strncpy(converted_node.ed25519_pubkey_hex, ed25519_pubkey_hex.c_str(), 64); - converted_node.ed25519_pubkey_hex[64] = '\0'; // Ensure null termination - converted_node.quic_port = node.port(); - converted_nodes.push_back(converted_node); - } - - return converted_nodes; - } - - ServerDestination convert_server_destination(const network_server_destination server) { - std::optional>> headers; - if (server.headers_size > 0) { - headers = std::vector>{}; - - for (size_t i = 0; i < server.headers_size; i++) - headers->emplace_back(server.headers[i], server.header_values[i]); - } - - return ServerDestination{ - server.protocol, - server.host, - server.endpoint, - x25519_pubkey::from_hex({server.x25519_pubkey, 64}), - server.port, - headers, - server.method}; - } -} // namespace detail - -request_info request_info::make( - onionreq::network_destination _dest, - std::optional> _original_body, - std::optional _swarm_pk, - std::chrono::milliseconds _request_timeout, - std::optional _request_and_path_build_timeout, - PathType _type, - std::optional _req_id, - std::optional _ep, - std::optional> _body) { - return request_info{ - _req_id.value_or("R-{}"_format(random::random_base32(4))), - std::move(_dest), - _ep.value_or(ONION), - std::move(_body), - std::move(_original_body), - std::move(_swarm_pk), - _type, - _request_timeout, - _request_and_path_build_timeout}; -} - -std::string onion_path::to_string() const { - std::vector node_descriptions; - std::transform( - nodes.begin(), - nodes.end(), - std::back_inserter(node_descriptions), - [](const service_node& node) { return node.to_string(); }); - - return "{}"_format(fmt::join(node_descriptions, ", ")); -} - -bool onion_path::contains_node(const service_node& sn) const { - for (auto& n : nodes) { - if (n == sn) - return true; - } - - return false; -} - -// MARK: Initialization - -Network::Network( - std::optional cache_path, - bool use_testnet, - bool single_path_mode, - bool pre_build_paths) : - use_testnet{use_testnet}, - should_cache_to_disk{cache_path}, - single_path_mode{single_path_mode}, - cache_path{cache_path.value_or(default_cache_path)} { - loop = std::make_shared(); - - // Load the cache from disk and start the disk write thread - if (should_cache_to_disk) { - load_cache_from_disk(); - disk_write_thread = std::thread{&Network::disk_write_thread_loop, this}; - } - - // Kick off a separate thread to build paths (may as well kick this off early) - if (pre_build_paths) - for (int i = 0; i < min_path_count(PathType::standard, single_path_mode); ++i) { - auto path_id = "P-{}"_format(random::random_base32(4)); - in_progress_path_builds[path_id] = PathType::standard; - loop->call_soon([this, path_id] { build_path(path_id, PathType::standard); }); - } -} - -Network::~Network() { - // Flag the network as suspended when we start destroying to ensure no new requests get started - // (which could result in additional calls being added to the `loop` incorrectly and cause bad - // memory crashes) - suspended = true; - - // Trigger a 'call_get' to block until the endpoint has been destroyed - loop->call_get([this]() mutable { _close_connections(); }); - - { - std::lock_guard lock{snode_cache_mutex}; - shut_down_disk_thread = true; - } - update_disk_cache_throttled(true); - if (disk_write_thread.joinable()) - disk_write_thread.join(); -} - -// MARK: Cache Management - -void Network::load_cache_from_disk() { - try { - // If the cache is for the wrong network then delete everything - auto testnet_stub = cache_path / file_testnet; - if (use_testnet != fs::exists(testnet_stub) && fs::exists(testnet_stub)) - fs::remove_all(cache_path); - - // Remove any legacy files (don't want to leave old data around) - for (const auto& path : legacy_files) { - auto path_to_remove = cache_path / path; - fs::remove_all(path_to_remove); - } - - // If we are using testnet then create a file to indicate that - if (use_testnet) - write_whole_file(testnet_stub); - - // Load the snode pool - if (auto pool_path = cache_path / file_snode_pool; fs::exists(pool_path)) { - auto ftime = fs::last_write_time(pool_path); - last_snode_cache_update = - std::chrono::time_point_cast( - ftime - fs::file_time_type::clock::now() + - std::chrono::system_clock::now()); - - auto file = open_for_reading(pool_path); - std::vector loaded_cache; - std::string line; - auto invalid_entries = 0; - - while (std::getline(file, line)) { - try { - loaded_cache.push_back(node_from_disk(line)); - } catch (...) { - ++invalid_entries; - } - } - - if (invalid_entries > 0) - log::warning(cat, "Skipped {} invalid entries in snode cache.", invalid_entries); - - snode_cache = loaded_cache; - all_swarms = detail::generate_swarms(loaded_cache); - } - - log::info( - cat, - "Loaded cache of {} snodes, {} swarms.", - snode_cache.size(), - all_swarms.size()); - } catch (const std::exception& e) { - log::error(cat, "Failed to load snode cache, will rebuild ({}).", e.what()); - - if (fs::exists(cache_path)) - fs::remove_all(cache_path); - } -} - -void Network::update_disk_cache_throttled(bool force_immediate_write) { - // If we are forcing an immediate write then just notify the disk write thread and reset the - // pending write flag - if (force_immediate_write) { - snode_cache_cv.notify_one(); - has_pending_disk_write = false; - return; - } - - if (has_pending_disk_write) - return; - - has_pending_disk_write = true; - loop->call_later(1s, [this]() { - snode_cache_cv.notify_one(); - has_pending_disk_write = false; - }); -} - -void Network::disk_write_thread_loop() { - std::unique_lock lock{snode_cache_mutex}; - while (true) { - snode_cache_cv.wait( - lock, [this] { return need_write || need_clear_cache || shut_down_disk_thread; }); - - if (need_write) { - // Make a local copy so that we can release the lock and not - // worry about other threads wanting to change things - auto snode_cache_write = snode_cache; - - lock.unlock(); - { - try { - // Create the cache directories if needed - fs::create_directories(cache_path); - - // If we are using testnet then create a file to indicate that - if (use_testnet) { - auto testnet_stub = cache_path / file_testnet; - write_whole_file(testnet_stub); - } - - // Save the snode pool to disk - auto pool_path = cache_path / file_snode_pool, pool_tmp = pool_path; - pool_tmp += u8"_new"; - - { - std::stringstream ss; - for (auto& snode : snode_cache_write) - ss << node_to_disk(snode) << '\n'; - - std::ofstream file(pool_tmp, std::ios::binary); - file << ss.rdbuf(); - } - - fs::rename(pool_tmp, pool_path); - need_write = false; - - log::debug(cat, "Finished writing snode cache to disk."); - } catch (const std::exception& e) { - log::error(cat, "Failed to write snode cache: {}", e.what()); - } - } - lock.lock(); - } - if (need_clear_cache) { - snode_cache = {}; - - lock.unlock(); - if (fs::exists(cache_path)) - fs::remove_all(cache_path); - lock.lock(); - need_clear_cache = false; - } - if (shut_down_disk_thread) - return; - } -} - -void Network::clear_cache() { - loop->call([this] { - { - std::lock_guard lock{snode_cache_mutex}; - need_clear_cache = true; - } - update_disk_cache_throttled(true); - }); -} - -size_t Network::snode_cache_size() { - return loop->call_get([this]() -> size_t { return snode_cache.size(); }); -} - -// MARK: Connection - -void Network::suspend() { - loop->call([this] { - suspended = true; - close_connections(); - log::info(cat, "Suspended."); - }); -} - -void Network::resume() { - loop->call([this] { - suspended = false; - log::info(cat, "Resumed."); - }); -} - -void Network::close_connections() { - loop->call([this] { _close_connections(); }); -} - -void Network::_close_connections() { - // Explicitly close all connections then reset the endpoint - if (endpoint) - endpoint->close_conns(); - endpoint.reset(); - - // Cancel any pending requests (they can't succeed once the connection is closed) - for (const auto& [path_type, path_type_requests] : request_queue) - for (const auto& [info, callback] : path_type_requests) - callback( - false, - false, - error_network_suspended, - {content_type_plain_text}, - "Network is suspended."); - - // Clear all storage of requests, paths and connections so that we are in a fresh state on - // relaunch - request_queue.clear(); - paths.clear(); - path_build_queue.clear(); - paths_pending_drop.clear(); - unused_connections.clear(); - in_progress_connections.clear(); - snode_refresh_results.reset(); - current_snode_cache_refresh_request_id = std::nullopt; - - update_status(ConnectionStatus::disconnected); - log::info(cat, "Closed all connections."); -} - -void Network::update_status(ConnectionStatus updated_status) { - // Ignore updates which don't change the status - if (status == updated_status) - return; - - // If we are already 'connected' then ignore 'connecting' status changes (if we drop one path - // and build another in the background this can happen) - if (status == ConnectionStatus::connected && updated_status == ConnectionStatus::connecting) - return; - - // Store the updated status - status = updated_status; - - if (!status_changed) - return; - - status_changed(updated_status); -} - -std::chrono::milliseconds Network::retry_delay( - int num_failures, std::chrono::milliseconds max_delay) { - return std::chrono::milliseconds(std::min( - max_delay.count(), - static_cast(100 * std::pow(2, num_failures)))); -} - -std::shared_ptr Network::get_endpoint() { - return loop->call_get([this]() mutable { - if (!endpoint) - endpoint = quic::Endpoint::endpoint( - *loop, - quic::Address{"0.0.0.0", 0}, - quic::opt::alpns{ALPN}, - quic::opt::disable_mtu_discovery{}); - - return endpoint; - }); -} - -// MARK: Request Queues and Path Building - -size_t Network::min_snode_cache_size() const { - if (!seed_node_cache_size) - return min_snode_cache_count; - - // If the seed node cache size is somehow smaller than `min_snode_cache_count` (ie. Testnet - // having issues) then the minimum size should be the full cache size (minus enough to build a - // path) or at least the size of a path - auto min_path_size = static_cast(path_size); - return std::min( - std::max(min_path_size, *seed_node_cache_size - min_path_size), min_snode_cache_count); -} - -std::vector Network::get_unused_nodes() { - if (snode_cache.size() < min_snode_cache_size()) - return {}; - - // Exclude any IPs that are already in use from existing paths - std::vector node_ips_to_exlude = all_path_ips(); - - // Exclude unused connections - for (const auto& conn_info : unused_connections) - node_ips_to_exlude.emplace_back(conn_info.node.to_ipv4()); - - // Exclude in progress connections - for (const auto& [request_id, node] : in_progress_connections) - node_ips_to_exlude.emplace_back(node.to_ipv4()); - - // Exclude pending requests - for (const auto& [path_type, path_type_requests] : request_queue) - for (const auto& [info, callback] : path_type_requests) - if (auto* dest = std::get_if(&info.destination)) - node_ips_to_exlude.emplace_back(dest->to_ipv4()); - - // Exclude any nodes which have surpassed the failure threshold - for (const auto& [node_string, failure_count] : snode_failure_counts) - if (failure_count >= snode_failure_threshold) { - size_t colon_pos = node_string.find(':'); - - if (colon_pos != std::string::npos) - node_ips_to_exlude.emplace_back(quic::ipv4{node_string.substr(0, colon_pos)}); - else - node_ips_to_exlude.emplace_back(quic::ipv4{node_string}); - } - - // Populate the unused nodes with any nodes in the cache which shouldn't be excluded - std::vector result; - - if (node_ips_to_exlude.empty()) - result = snode_cache; - else - std::copy_if( - snode_cache.begin(), - snode_cache.end(), - std::back_inserter(result), - [&node_ips_to_exlude](const auto& node) { - return std::find( - node_ips_to_exlude.begin(), - node_ips_to_exlude.end(), - node.to_ipv4()) == node_ips_to_exlude.end(); - }); - - // Shuffle the `result` so anything that uses it would get random nodes - std::shuffle(result.begin(), result.end(), csrng); - - return result; -} - -void Network::establish_connection( - std::string id, - service_node target, - std::optional timeout, - std::function error)> callback) { - log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, id); - auto currently_suspended = loop->call_get([this]() -> bool { return suspended; }); - - // If the network is currently suspended then don't try to open a connection - if (currently_suspended) - return callback( - {target, std::make_shared(0), nullptr, nullptr}, "Network is suspended."); - - auto conn_key_pair = ed25519::ed25519_key_pair(); - auto creds = quic::GNUTLSCreds::make_from_ed_seckey(to_string_view(conn_key_pair.second)); - auto cb_called = std::make_shared(); - auto cb = std::make_shared)>>( - std::move(callback)); - auto conn_promise = std::promise>(); - auto conn_future = conn_promise.get_future().share(); - auto handshake_timeout = - timeout ? std::optional{quic::opt::handshake_timeout{ - std::chrono::duration_cast(*timeout)}} - : std::nullopt; - - auto c = get_endpoint()->connect( - target, - creds, - quic::opt::keep_alive{10s}, - handshake_timeout, - [this, id, target, cb, cb_called, conn_future](quic::Connection&) mutable { - log::trace(cat, "Connection established for {}.", id); - - // Just in case, call it within a `loop->call` - loop->call([&] { - std::call_once(*cb_called, [&]() { - if (cb) { - auto conn = conn_future.get(); - (*cb)({target, - std::make_shared(0), - conn, - conn->open_stream()}, - std::nullopt); - cb.reset(); - } - }); - }); - }, - [this, target, id, cb, cb_called, conn_future]( - quic::Connection& conn, uint64_t error_code) mutable { - if (error_code == static_cast(NGTCP2_ERR_HANDSHAKE_TIMEOUT)) - log::info( - cat, - "Unable to establish connection to {} for {}.", - target.to_string(), - id); - else - log::info(cat, "Connection to {} closed for {}.", target.to_string(), id); - - // Just in case, call it within a `loop->call` - loop->call([&] { - // Trigger the callback first before updating the paths in case this was - // triggered when try to establish a connection - std::call_once(*cb_called, [&]() { - if (cb) { - (*cb)({target, std::make_shared(0), nullptr, nullptr}, - std::nullopt); - cb.reset(); - } - }); - - // Remove the connection from `unused_connection` if present - std::erase_if(unused_connections, [&conn, &target](auto& unused_conn) { - return (unused_conn.node == target && unused_conn.conn && - unused_conn.conn->reference_id() == conn.reference_id()); - }); - - // If this connection is being used in an existing path then we should drop it - // (as the path is no longer valid) - for (const auto& [path_type, paths_for_type] : paths) { - for (const auto& path : paths_for_type) { - if (!path.nodes.empty() && path.nodes.front() == target && - path.conn_info.conn && - conn.reference_id() == path.conn_info.conn->reference_id()) { - drop_path_when_empty(id, path_type, path); - break; - } - } - } - - // Since a connection was closed we should also clear any pending path drops - // in case this connection was one of those - clear_empty_pending_path_drops(); - - // If the connection failed with a handshake timeout then the node is - // unreachable, either due to a device network issue or because the node - // is down so set the failure count to the failure threshold so it won't - // be used for subsequent requests - if (error_code == static_cast(NGTCP2_ERR_HANDSHAKE_TIMEOUT)) - snode_failure_counts[target.to_string()] = snode_failure_threshold; - }); - }); - - conn_promise.set_value(c); -} - -void Network::establish_and_store_connection(std::string path_id) { - // If we are suspended then don't try to establish a new connection - if (suspended) - return; - - // If we haven't set a connection status yet then do so now - if (status == ConnectionStatus::unknown) - update_status(ConnectionStatus::connecting); - - // Re-populate the unused nodes if it ends up being empty - if (unused_nodes.empty()) - unused_nodes = get_unused_nodes(); - - // If there aren't enough unused nodes then trigger a cache refresh - if (unused_nodes.size() < min_snode_cache_size()) { - log::trace( - cat, - "Unable to establish new connection due to lack of unused nodes, refreshing snode " - "cache ({}).", - path_id); - return loop->call_soon([this, path_id]() { refresh_snode_cache(path_id); }); - } - - // Otherwise check if it's been too long since the last cache update and, if so, trigger a - // refresh - auto cache_lifetime = std::chrono::duration_cast( - std::chrono::system_clock::now() - last_snode_cache_update); - - if (cache_lifetime < 0s || cache_lifetime > snode_cache_expiration_duration) - loop->call_soon([this]() { refresh_snode_cache(); }); - - // If there are no in progress connections then reset the failure count - if (in_progress_connections.empty()) - connection_failures = 0; - - // Grab a node from the `unused_nodes` list to establish a connection to - auto target_node = unused_nodes.back(); - unused_nodes.pop_back(); - - // Try to establish a new connection to the target (this has a 3s handshake timeout as we - // wouldn't want to use any nodes which take longer than that anyway) - log::info(cat, "Establishing connection to {} for {}.", target_node.to_string(), path_id); - in_progress_connections.emplace(path_id, target_node); - - establish_connection( - path_id, - target_node, - 3s, - [this, target_node, path_id](connection_info info, std::optional) { - // If we failed to get a connection then try again after a delay (may as well try - // indefinitely because there is no way to recover from this issue) - if (!info.is_valid()) { - connection_failures++; - auto connection_retry_delay = retry_delay(connection_failures); - log::error( - cat, - "Failed to connect to {}, will try another after {}ms.", - target_node.to_string(), - connection_retry_delay.count()); - return loop->call_later(connection_retry_delay, [this, path_id]() { - establish_and_store_connection(path_id); - }); - } - - // We were able to connect to the node so add it to the unused_connections queue - log::info(cat, "Connection to {} valid for {}.", target_node.to_string(), path_id); - unused_connections.emplace_back(info); - - // Kick off the next pending path build since we now have a valid connection - if (!path_build_queue.empty()) { - in_progress_path_builds[path_id] = path_build_queue.front(); - loop->call_soon([this, path_type = path_build_queue.front(), path_id]() { - build_path(path_id, path_type); - }); - path_build_queue.pop_front(); - } - - // If there are still pending path builds but no in progress connections then kick - // off enough additional connections for remaining builds (this shouldn't happen but - // better to be safe and avoid a situation where a path build gets orphaned) - if (!path_build_queue.empty() && in_progress_connections.empty()) - for ([[maybe_unused]] const auto& _ : path_build_queue) - loop->call_soon([this]() { - auto conn_id = "EC-{}"_format(random::random_base32(4)); - establish_and_store_connection(conn_id); - }); - }); -} - -void Network::refresh_snode_cache_complete(std::vector nodes) { - // Shuffle the nodes so we don't have a specific order - std::shuffle(nodes.begin(), nodes.end(), csrng); - - // Update the disk cache if the snode pool was updated - { - std::lock_guard lock{snode_cache_mutex}; - snode_cache = nodes; - last_snode_cache_update = std::chrono::system_clock::now(); - need_write = true; - } - update_disk_cache_throttled(); - - // Reset the cache refresh state - current_snode_cache_refresh_request_id = std::nullopt; - snode_cache_refresh_failure_count = 0; - in_progress_snode_cache_refresh_count = 0; - unused_snode_refresh_nodes = std::nullopt; - snode_refresh_results.reset(); - - // Reset the snode failure counts (assume if the snode refresh includes - // nodes then they are valid) - snode_failure_counts.clear(); - - // Since we've updated the snode cache the swarm cache could be invalid - // so we need to regenerate it (the resulting `all_swarms` needs to be - // stored in ascending order as it is required for the logic to find the - // appropriate swarm for a given pubkey) - all_swarms.clear(); - swarm_cache.clear(); - all_swarms = detail::generate_swarms(nodes); - - // Run any post-refresh processes - for (const auto& callback : after_snode_cache_refresh) - loop->call_soon([cb = std::move(callback)]() { cb(); }); - after_snode_cache_refresh.clear(); - - // Resume any queued path builds - for (const auto& path_type : path_build_queue) { - auto path_id = "P-{}"_format(random::random_base32(4)); - in_progress_path_builds[path_id] = path_type; - loop->call_soon([this, path_type, path_id]() { build_path(path_id, path_type); }); - } - path_build_queue.clear(); -} - -void Network::refresh_snode_cache_from_seed_nodes(std::string request_id, bool reset_unused_nodes) { - if (suspended) { - log::info(cat, "Ignoring snode cache refresh as network is suspended ({}).", request_id); - return; - } - - // Only allow a single cache refresh at a time (this gets cleared in `_close_connections` so if - // it happens to loop after going to, and returning from, the background a subsequent refresh - // won't be blocked) - if (current_snode_cache_refresh_request_id && - current_snode_cache_refresh_request_id != request_id) { - log::info( - cat, - "Snode cache refresh from seed node {} ignored as it doesn't match the current " - "refresh id ({}).", - request_id, - current_snode_cache_refresh_request_id.value_or("NULL")); - return; - } - - // If the unused nodes is empty then reset it (if we are refreshing from seed nodes it means the - // local cache is not usable so we are just going to have to call this endlessly until it works) - if (reset_unused_nodes || !unused_snode_refresh_nodes || unused_snode_refresh_nodes->empty()) { - log::info( - cat, - "Existing cache is insufficient, refreshing from seed nodes ({}).", - request_id); - - // Shuffle to ensure we pick random nodes to fetch from - unused_snode_refresh_nodes = (use_testnet ? seed_nodes_testnet : seed_nodes_mainnet); - std::shuffle(unused_snode_refresh_nodes->begin(), unused_snode_refresh_nodes->end(), csrng); - } - - auto target_node = unused_snode_refresh_nodes->back(); - unused_snode_refresh_nodes->pop_back(); - - establish_connection( - request_id, - target_node, - 3s, - [this, request_id](connection_info info, std::optional) { - // If we failed to get a connection then try again after a delay (may as well try - // indefinitely because there is no way to recover from this issue) - if (!info.is_valid()) { - snode_cache_refresh_failure_count++; - auto cache_refresh_retry_delay = retry_delay(snode_cache_refresh_failure_count); - log::error( - cat, - "Failed to connect to seed node to refresh snode cache, will retry " - "after {}ms ({}).", - cache_refresh_retry_delay.count(), - request_id); - return loop->call_later(cache_refresh_retry_delay, [this, request_id]() { - refresh_snode_cache_from_seed_nodes(request_id, false); - }); - } - - get_service_nodes( - request_id, - info, - std::nullopt, - [this, request_id]( - std::vector nodes, std::optional error) { - // If we got no nodes then we will need to try again - if (nodes.empty()) { - snode_cache_refresh_failure_count++; - auto cache_refresh_retry_delay = - retry_delay(snode_cache_refresh_failure_count); - log::error( - cat, - "Failed to retrieve nodes from seed node to refresh cache " - "due to error: {}, will retry after {}ms ({}).", - error.value_or("Unknown Error"), - cache_refresh_retry_delay.count(), - request_id); - return loop->call_later( - cache_refresh_retry_delay, [this, request_id]() { - refresh_snode_cache_from_seed_nodes(request_id, false); - }); - } - - log::info( - cat, - "Refreshing snode cache from seed nodes completed with {} " - "nodes ({}).", - nodes.size(), - request_id); - seed_node_cache_size = nodes.size(); - refresh_snode_cache_complete(nodes); - }); - }); -} - -void Network::refresh_snode_cache(std::optional existing_request_id) { - auto request_id = existing_request_id.value_or("RSC-{}"_format(random::random_base32(4))); - - if (suspended) { - log::info(cat, "Ignoring snode cache refresh as network is suspended ({}).", request_id); - return; - } - - // Only allow a single cache refresh at a time (this gets cleared in `_close_connections` so if - // it happens to loop after going to, and returning from, the background a subsequent refresh - // won't be blocked) - if (current_snode_cache_refresh_request_id && - current_snode_cache_refresh_request_id != request_id) { - log::info( - cat, - "Snode cache refresh {} ignored due to in progress refresh ({}).", - request_id, - current_snode_cache_refresh_request_id.value_or("NULL")); - return; - } - - // We are starting a new cache refresh so store an identifier for it (we also initialise - // `snode_refresh_results` so we can use it to track the results from the different requests) - if (!current_snode_cache_refresh_request_id) { - log::info(cat, "Refreshing snode cache ({}).", request_id); - snode_cache_refresh_failure_count = 0; - in_progress_snode_cache_refresh_count = 0; - current_snode_cache_refresh_request_id = request_id; - snode_refresh_results = std::make_shared>>(); - } - - // If we don't have enough nodes in the unused nodes then refresh it - if (unused_nodes.size() < min_snode_cache_size()) - unused_nodes = get_unused_nodes(); - - // If we still don't have enough nodes in the unused nodes it likely means we didn't - // have enough nodes in the cache so instead just fetch from the seed nodes (which is - // a trusted source so we can update the cache from a single response) - if (unused_nodes.size() < min_snode_cache_size()) - return refresh_snode_cache_from_seed_nodes(request_id, true); - - // Target an unused node and increment the in progress refresh counter - auto target_node = unused_nodes.back(); - unused_nodes.pop_back(); - in_progress_snode_cache_refresh_count++; - - // If there are still more concurrent refresh_snode_cache requests we want to trigger then - // trigger the next one to run in the next run loop - if (in_progress_snode_cache_refresh_count < num_snodes_to_refresh_cache_from) - loop->call_soon([this, request_id]() { refresh_snode_cache(request_id); }); - - // Prepare and send the request to retrieve service nodes - nlohmann::json payload{ - {"method", "oxend_request"}, - {"params", - {{"endpoint", "get_service_nodes"}, - {"params", detail::get_service_nodes_params(std::nullopt)}}}, - }; - auto info = request_info::make( - target_node, - to_vector(payload.dump()), - std::nullopt, - quic::DEFAULT_TIMEOUT, - std::nullopt, - PathType::standard, - request_id); - _send_onion_request( - info, - [this, request_id]( - bool success, - bool timeout, - int16_t, - std::vector>, - std::optional response) { - // If the 'snode_refresh_results' value doesn't exist it means we have already - // completed/cancelled this snode cache refresh and have somehow gotten into an - // invalid state, so just ignore this request - if (!snode_refresh_results) { - log::warning( - cat, - "Ignoring snode cache response after cache update already completed " - "({}).", - request_id); - return; - } - - try { - if (!success || timeout || !response) - throw std::runtime_error{response.value_or("Unknown error.")}; - - nlohmann::json response_json = nlohmann::json::parse(*response); - std::vector result = - detail::process_get_service_nodes_response(response_json); - snode_refresh_results->emplace_back(result); - - // Update the in progress request count - in_progress_snode_cache_refresh_count--; - } catch (const std::exception& e) { - // The request failed so increment the failure counter and retry after a short - // delay - snode_cache_refresh_failure_count++; - - auto cache_refresh_retry_delay = retry_delay(snode_cache_refresh_failure_count); - log::error( - cat, - "Failed to retrieve nodes from one target when refreshing cache due to " - "error: {}, Will try another target after {}ms ({}).", - e.what(), - cache_refresh_retry_delay.count(), - request_id); - return loop->call_later(cache_refresh_retry_delay, [this, request_id]() { - refresh_snode_cache(request_id); - }); - } - - // If we haven't received all results then do nothing - if (snode_refresh_results->size() != num_snodes_to_refresh_cache_from) { - log::info( - cat, - "Received snode cache refresh result {}/{} ({}).", - snode_refresh_results->size(), - num_snodes_to_refresh_cache_from, - request_id); - return; - } - - auto any_nodes_request_failed = std::any_of( - snode_refresh_results->begin(), - snode_refresh_results->end(), - [](const auto& n) { return n.empty(); }); - - // If the current cache is still usable just send a warning and don't bother - // retrying - if (any_nodes_request_failed) { - log::warning(cat, "Failed to refresh snode cache ({}).", request_id); - current_snode_cache_refresh_request_id = std::nullopt; - snode_cache_refresh_failure_count = 0; - in_progress_snode_cache_refresh_count = 0; - snode_refresh_results.reset(); - return; - } - - // Sort the vectors (so make it easier to find the intersection) - auto compare_service_nodes = [](const service_node& a, const service_node& b) { - if (auto cmp = quic::Address(a) <=> quic::Address(b); cmp != 0) - return cmp < 0; - - return std::tie(a.get_remote_key(), a.swarm_id, a.storage_server_version) < - std::tie(b.get_remote_key(), b.swarm_id, b.storage_server_version); - }; - - for (auto& nodes : *snode_refresh_results) - std::stable_sort(nodes.begin(), nodes.end(), compare_service_nodes); - - auto nodes = (*snode_refresh_results)[0]; - - // If we triggered multiple requests then get the intersection of all vectors - if (snode_refresh_results->size() > 1) { - for (size_t i = 1; i < snode_refresh_results->size(); ++i) { - std::vector temp; - std::set_intersection( - nodes.begin(), - nodes.end(), - (*snode_refresh_results)[i].begin(), - (*snode_refresh_results)[i].end(), - std::back_inserter(temp), - compare_service_nodes); - nodes = std::move(temp); - } - } - - log::info( - cat, - "Refreshing snode cache completed with {} nodes ({}).", - nodes.size(), - request_id); - refresh_snode_cache_complete(nodes); - }); -} - -void Network::build_path(std::string path_id, PathType path_type) { - if (suspended) { - log::info(cat, "Ignoring build_path call as network is suspended."); - return; - } - - auto path_name = path_type_name(path_type, single_path_mode); - - // If we don't have an unused connection for the first hop then enqueue the path build and - // establish a new connection - if (unused_connections.empty()) { - log::info( - cat, - "No unused connections available to build {} path, creating new connection for {}.", - path_name, - path_id); - path_build_queue.emplace_back(path_type); - in_progress_path_builds.erase(path_id); - return loop->call_soon([this, path_id]() { establish_and_store_connection(path_id); }); - } - - // Reset the unused nodes list if it's too small - if (unused_nodes.size() < path_size) - unused_nodes = get_unused_nodes(); - - // If we still don't have enough unused nodes then we need to refresh the cache - if (unused_nodes.size() < path_size) { - log::info( - cat, "Re-queing {} path build due to insufficient nodes ({}).", path_name, path_id); - path_build_failures = 0; - path_build_queue.emplace_back(path_type); - in_progress_path_builds.erase(path_id); - return loop->call_soon([this]() { refresh_snode_cache(); }); - } - - // Build the path - log::info(cat, "Building {} path ({}).", path_name, path_id); - in_progress_path_builds[path_id] = path_type; - - auto conn_info = std::move(unused_connections.front()); - unused_connections.pop_front(); - std::vector path_nodes = {conn_info.node}; - - while (path_nodes.size() < path_size) { - if (unused_nodes.empty()) { - // Log the error and try build again after a slight delay - log::info( - cat, - "Unable to build {} path due to lack of suitable unused nodes ({}).", - path_name, - path_id); - - // Delay the next path build attempt based on the error we received - path_build_failures++; - unused_connections.push_front(std::move(conn_info)); - auto delay = retry_delay(path_build_failures); - loop->call_later( - delay, [this, path_id, path_type]() { build_path(path_id, path_type); }); - return; - } - - // Grab the next unused node to continue building the path - auto node = unused_nodes.back(); - unused_nodes.pop_back(); - - // Ensure we don't put two nodes with the same IP into the same path - auto snode_with_ip_it = std::find_if( - path_nodes.begin(), path_nodes.end(), [&node](const auto& existing_node) { - return existing_node.to_ipv4() == node.to_ipv4(); - }); - - if (snode_with_ip_it == path_nodes.end()) - path_nodes.push_back(node); - } - - // Store the new path - auto path = onion_path{path_id, std::move(conn_info), path_nodes, 0}; - paths[path_type].emplace_back(path); - in_progress_path_builds.erase(path_id); - - // Log that a path was built - log::info( - cat, - "Built new onion request path [{}], now have {} {} path(s) ({}).", - path.to_string(), - paths[path_type].size(), - path_name, - path_id); - - // If the connection info is valid and it's a standard path then update the - // connection status to connected - if (path_type == PathType::standard) { - update_status(ConnectionStatus::connected); - - // If a paths_changed callback was provided then call it - if (paths_changed) { - std::vector> raw_paths; - for (const auto& path : paths[path_type]) - raw_paths.emplace_back(path.nodes); - - paths_changed(raw_paths); - } - } - - // Remove the nodes from unused_nodes which have the same IPs as nodes in - // the final path - std::vector path_ips; - for (const auto& node : path_nodes) - path_ips.emplace_back(node.to_ipv4()); - - std::erase_if(unused_nodes, [&path_ips](const auto& node) { - return std::find(path_ips.begin(), path_ips.end(), node.to_ipv4()) != path_ips.end(); - }); - - // If there are pending requests which this path is valid for then resume them - std::erase_if(request_queue[path_type], [this, &path](const auto& request) { - if (!find_valid_path(request.first, {path})) - return false; - - loop->call_soon([this, info = request.first, cb = std::move(request.second)]() { - _send_onion_request(std::move(info), std::move(cb)); - }); - return true; - }); - - // If there are still pending requests and there are no pending path builds for them then kick - // off a subsequent path build in an effort to resume the remaining requests - if (!request_queue[path_type].empty()) { - auto additional_path_id = "P-{}"_format(random::random_base32(4)); - in_progress_path_builds[additional_path_id] = path_type; - loop->call_soon([this, path_type, additional_path_id] { - build_path(additional_path_id, path_type); - }); - } else - request_queue.erase(path_type); -} - -std::optional Network::find_valid_path( - const request_info info, const std::vector paths) { - if (paths.empty()) - return std::nullopt; - - // Only include paths with valid connections as options - std::vector possible_paths; - std::copy_if( - paths.begin(), paths.end(), std::back_inserter(possible_paths), [&](const auto& path) { - return path.is_valid(); - }); - - // If the request destination is a node then only select a path that doesn't include the IP of - // the destination - if (auto target = detail::node_for_destination(info.destination)) { - std::vector ip_excluded_paths; - std::copy_if( - possible_paths.begin(), - possible_paths.end(), - std::back_inserter(ip_excluded_paths), - [&](const onion_path& p) { return not p.contains_node(*target); }); - - if (single_path_mode && ip_excluded_paths.empty()) - log::warning( - cat, - "Path should have been excluded due to matching IP for {} but network is in " - "single path mode.", - info.request_id); - else - possible_paths = ip_excluded_paths; - } - - if (possible_paths.empty()) - return std::nullopt; - - // Randomise the possible paths (if all paths are equal for the PathSelectionBehaviour then we - // want a random one to be selected) - std::shuffle(possible_paths.begin(), possible_paths.end(), csrng); - - // Select from the possible paths based on the desired behaviour - auto behaviour = path_selection_behaviour(info.path_type); - switch (behaviour) { - case PathSelectionBehaviour::new_or_least_busy: { - auto min_num_paths = min_path_count(info.path_type, single_path_mode); - std::sort( - possible_paths.begin(), possible_paths.end(), [](const auto& a, const auto& b) { - return a.num_pending_requests() < b.num_pending_requests(); - }); - - // If we have already have the min number of paths for this path type, or there is - // a path with no pending requests then return the first path - if (paths.size() >= min_num_paths || possible_paths.front().num_pending_requests() == 0) - return possible_paths.front(); - - // Otherwise we want to build a new path (for this PathSelectionBehaviour the assuption - // is that it'd be faster to build a new path and send the request along that rather - // than use an existing path) - return std::nullopt; - } - - // Random is the default behaviour - case PathSelectionBehaviour::random: return possible_paths.front(); - default: return possible_paths.front(); - } -}; - -void Network::build_path_if_needed(PathType path_type, bool found_path) { - const auto current_paths = paths[path_type]; - - // In `single_path_mode` we never build additional paths - if (current_paths.size() > 0 && single_path_mode) - return; - - // We only want to enqueue a new path build if: - // - We don't have the minimum number of paths for the specified type - // - We don't have any pending builds - // - The current paths are unsuitable for the request - auto min_paths = min_path_count(path_type, single_path_mode); - - // If we have enough existing paths and found a valid path then no need to build more paths - if (found_path && current_paths.size() >= min_paths) - return; - - // Get the number pending paths - auto queued = std::count(path_build_queue.begin(), path_build_queue.end(), path_type); - auto in_progress = std::count_if( - in_progress_path_builds.begin(), - in_progress_path_builds.end(), - [&path_type](const auto& build) { return build.second == path_type; }); - auto pending_paths = (queued + in_progress); - - // If we don't have enough current + pending paths, or the request couldn't be sent then - // kick off a new path build - if ((current_paths.size() + pending_paths) < min_paths || (!found_path && pending_paths == 0)) { - auto path_id = "P-{}"_format(random::random_base32(4)); - build_path(path_id, path_type); - } -} - -// MARK: Direct Requests - -void Network::get_service_nodes( - std::string request_id, - connection_info conn_info, - std::optional limit, - std::function nodes, std::optional error)> - callback) { - log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, request_id); - - if (!conn_info.is_valid()) - return callback({}, "Connection is not valid."); - - oxenc::bt_dict_producer payload; - payload.append("endpoint", "get_service_nodes"); - payload.append("params", detail::get_service_nodes_params(limit).dump()); - - conn_info.add_pending_request(); - conn_info.stream->command( - "oxend_request", - payload.view(), - [this, request_id, conn_info, cb = std::move(callback)](quic::message resp) { - log::trace(cat, "{} got response for {}.", __PRETTY_FUNCTION__, request_id); - std::vector result; - conn_info.remove_pending_request(); - - try { - auto [status_code, body] = validate_response(resp, true); - oxenc::bt_list_consumer result_bencode{body}; - result = detail::process_get_service_nodes_response(result_bencode); - } catch (const std::exception& e) { - return cb({}, e.what()); - } - - // Output the result - cb(result, std::nullopt); - - // After completing a request we should try to clear any pending path drops (just in - // case this request was the final one on a pending path drop) - if (!conn_info.has_pending_requests()) - clear_empty_pending_path_drops(); - }); -} - -// MARK: Swarm Management - -void Network::get_swarm( - session::onionreq::x25519_pubkey swarm_pubkey, - std::function swarm)> callback) { - log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, swarm_pubkey.hex()); - - loop->call([this, swarm_pubkey, cb = std::move(callback)]() { - // If we have a cached swarm then return it - auto cached_swarm = swarm_cache[swarm_pubkey.hex()]; - if (!cached_swarm.second.empty()) - return cb(cached_swarm.first, cached_swarm.second); - - // If we have no snode cache or no swarms then we need to rebuild the cache (which will also - // rebuild the swarms) and run this request again - if (snode_cache.empty() || all_swarms.empty()) { - after_snode_cache_refresh.emplace_back([this, swarm_pubkey, cb = std::move(cb)]() { - get_swarm(swarm_pubkey, std::move(cb)); - }); - return loop->call_soon([this]() { refresh_snode_cache(); }); - } - - // If there is only a single swarm then return it - if (all_swarms.size() == 1) - return cb(all_swarms.front().first, all_swarms.front().second); - - // Generate a swarm_id for the pubkey - const swarm_id_t swarm_id = detail::pubkey_to_swarm_space(swarm_pubkey); - - // Find the right boundary, i.e. first swarm with swarm_id >= res - auto right_it = std::lower_bound( - all_swarms.begin(), all_swarms.end(), swarm_id, [](const auto& s, uint64_t v) { - return s.first < v; - }); - - if (right_it == all_swarms.end()) - // res is > the top swarm_id, meaning it is big and in the wrapping space between last - // and first elements. - right_it = all_swarms.begin(); - - // Our "left" is the one just before that (with wraparound, if right is the first swarm) - auto left_it = std::prev(right_it == all_swarms.begin() ? all_swarms.end() : right_it); - - uint64_t dright = right_it->first - swarm_id; - uint64_t dleft = swarm_id - left_it->first; - auto swarm = &*(dright < dleft ? right_it : left_it); - - // Update the cache with the result - log::info( - cat, - "Found swarm with {} nodes for {}, adding to cache.", - swarm->second.size(), - swarm_pubkey.hex()); - swarm_cache[swarm_pubkey.hex()] = *swarm; - cb(swarm->first, swarm->second); - }); -} - -// MARK: Node Retrieval - -void Network::get_random_nodes( - uint16_t count, std::function nodes)> callback) { - auto request_id = "R-{}"_format(random::random_base32(4)); - log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, request_id); - - loop->call([this, request_id, count, cb = std::move(callback)]() mutable { - // If we don't have sufficient unused nodes then regenerate it - if (unused_nodes.size() < count) - unused_nodes = get_unused_nodes(); - - // If we still don't have sufficient nodes then we need to refresh the snode cache - if (unused_nodes.size() < count) { - after_snode_cache_refresh.emplace_back( - [this, count, cb = std::move(cb)]() { get_random_nodes(count, cb); }); - return loop->call_soon([this]() { refresh_snode_cache(); }); - } - - // Otherwise callback with the requested random number of nodes - auto random_nodes = - std::vector(unused_nodes.begin(), unused_nodes.begin() + count); - unused_nodes.erase(unused_nodes.begin(), unused_nodes.begin() + count); - cb(random_nodes); - }); -} - -// MARK: Request Handling - -void Network::check_request_queue_timeouts(std::optional request_timeout_id_) { - // If the network is suspended then don't bother checking for timeouts - if (suspended) - return; - - // If there is an existing timeout checking loop then we don't want to start a second - if (request_timeout_id != request_timeout_id_) - return; - - // If there wasn't an existing loop id then set it here - if (!request_timeout_id) - request_timeout_id = "RT-{}"_format(random::random_base32(4)); - - // Timeout and remove any pending requests which should timeout based on path build time - auto has_remaining_timeout_requests = false; - auto time_now = std::chrono::system_clock::now(); - - for (auto& [path_type, requests_for_path] : request_queue) - std::erase_if( - requests_for_path, - [&has_remaining_timeout_requests, &time_now](const auto& request) { - // If the request doesn't have a path build timeout then ignore it - if (!request.first.request_and_path_build_timeout) - return false; - - auto duration = std::chrono::duration_cast( - time_now - request.first.creation_time); - - if (duration > *request.first.request_and_path_build_timeout) { - request.second( - false, - true, - error_path_build_timeout, - {content_type_plain_text}, - "Timed out waiting for path build."); - return true; - } - - has_remaining_timeout_requests = true; - return false; - }); - - // If there are no more timeout requests then stop looping here - if (!has_remaining_timeout_requests) { - request_timeout_id = std::nullopt; - return; - } - - // Otherwise schedule the next check - loop->call_later(queued_request_path_build_timeout_frequency, [this]() { - check_request_queue_timeouts(request_timeout_id); - }); -} - -void Network::send_request( - request_info info, connection_info conn_info, network_response_callback_t handle_response) { - log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, info.request_id); - - if (!conn_info.is_valid()) - return handle_response( - false, false, -1, {content_type_plain_text}, "Network is unreachable."); - - std::span payload{}; - - if (info.body) - payload = to_span(*info.body); - - // Calculate the remaining timeout - std::chrono::milliseconds timeout = info.request_timeout; - - if (info.request_and_path_build_timeout) { - auto elapsed_time = std::chrono::duration_cast( - std::chrono::system_clock::now() - info.creation_time); - - timeout = *info.request_and_path_build_timeout - elapsed_time; - - // If the timeout was somehow negative then just fail the request (no point continuing if - // we have already timed out) - if (timeout < std::chrono::milliseconds(0)) - return handle_response( - false, - true, - error_path_build_timeout, - {content_type_plain_text}, - "Path Build Timed Out."); - } - - conn_info.add_pending_request(); - conn_info.stream->command( - info.endpoint, - payload, - timeout, - [this, info, conn_info, cb = std::move(handle_response)](quic::message resp) { - log::trace(cat, "{} got response for {}.", __PRETTY_FUNCTION__, info.request_id); - - std::pair result; - auto& [status_code, body] = result; - conn_info.remove_pending_request(); - - try { - result = validate_response(resp, false); - } catch (const status_code_exception& e) { - return handle_errors( - info, - conn_info, - resp.timed_out, - e.status_code, - e.headers, - e.what(), - cb); - } catch (const std::exception& e) { - return handle_errors( - info, - conn_info, - resp.timed_out, - -1, - {content_type_plain_text}, - e.what(), - cb); - } - - cb(true, false, status_code, {}, body); - - // After completing a request we should try to clear any pending path drops (just in - // case this request was the final one on a pending path drop) - if (!conn_info.has_pending_requests()) - clear_empty_pending_path_drops(); - }); -} - -void Network::send_onion_request( - onionreq::network_destination destination, - std::optional> body, - std::optional swarm_pubkey, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout, - PathType type) { - _send_onion_request( - request_info::make( - std::move(destination), - std::move(body), - std::move(swarm_pubkey), - request_timeout, - request_and_path_build_timeout, - type), - std::move(handle_response)); -} - -void Network::_send_onion_request(request_info info, network_response_callback_t handle_response) { - auto path_name = path_type_name(info.path_type, single_path_mode); - log::trace(cat, "{} called for {} path ({}).", __PRETTY_FUNCTION__, path_name, info.request_id); - - // Try to retrieve a valid path for this request, if we can't get one then add the request to - // the queue to be run once a path for it has successfully been built - auto path = loop->call_get([this, info]() { - auto result = find_valid_path(info, paths[info.path_type]); - loop->call_soon([this, path_type = info.path_type, found_path = result.has_value()]() { - build_path_if_needed(path_type, found_path); - }); - return result; - }); - - if (!path) { - return loop->call([this, info = std::move(info), cb = std::move(handle_response)]() { - // If the network is suspended then fail immediately - if (suspended) - return cb( - false, - false, - error_network_suspended, - {content_type_plain_text}, - "Network is suspended."); - - request_queue[info.path_type].emplace_back(std::move(info), std::move(cb)); - - // If the request has a path_build_timeout then start the timeout check loop - if (info.request_and_path_build_timeout) - loop->call_later(queued_request_path_build_timeout_frequency, [this]() { - check_request_queue_timeouts(); - }); - }); - } - - log::trace(cat, "{} got {} path for {}.", __PRETTY_FUNCTION__, path_name, info.request_id); - - // Construct the onion request - auto builder = Builder::make(info.destination, path->nodes); - try { - builder.generate(info); - } catch (const std::exception& e) { - log::warning(cat, "Builder exception: {}", e.what()); - return handle_response( - false, false, error_building_onion_request, {content_type_plain_text}, e.what()); - } - - // Actually send the request - send_request( - info, - path->conn_info, - [this, - builder = std::move(builder), - info, - path = *path, - cb = std::move(handle_response)]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - log::trace(cat, "{} got response for {}.", __PRETTY_FUNCTION__, info.request_id); - - // If the request was reported as a failure or a timeout then we - // will have already handled the errors so just trigger the callback - if (!success || timeout) - return cb(success, timeout, status_code, headers, response); - - try { - // Ensure the response is long enough to be processed, if not - // then handle it as an error - if (!ResponseParser::response_long_enough(builder.enc_type, response->size())) - throw status_code_exception{ - status_code, - {content_type_plain_text}, - "Response is too short to be an onion request response: " + - *response}; - - // Otherwise, process the onion request response - std::tuple< - int16_t, - std::vector>, - std::optional> - processed_response; - - // The SnodeDestination runs via V3 onion requests and the - // ServerDestination runs via V4 - if (std::holds_alternative(info.destination)) - processed_response = process_v3_onion_response(builder, *response); - else if (std::holds_alternative(info.destination)) - processed_response = process_v4_onion_response(builder, *response); - - // If we got a non 2xx status code, return the error - auto& [processed_status_code, processed_headers, processed_body] = - processed_response; - if (processed_status_code < 200 || processed_status_code > 299) - throw status_code_exception{ - processed_status_code, - {content_type_plain_text}, - processed_body.value_or("Request returned " - "non-success status " - "code.")}; - - // For debugging purposes we want to add a log if this was a successful request - // after we did an automatic retry - detail::log_retry_result_if_needed(info, single_path_mode); - - // Try process the body in case it was a batch request which - // failed - std::optional results; - if (processed_body) { - try { - auto processed_body_json = nlohmann::json::parse(*processed_body); - - // If it wasn't a batch/sequence request then assume it - // was successful and return no error - if (processed_body_json.contains("results")) - results = processed_body_json["results"]; - } catch (...) { - } - } - - // If there was no 'results' array then it wasn't a batch - // request so we can stop here and return - if (!results) - return cb( - true, - false, - processed_status_code, - processed_headers, - processed_body); - - // Otherwise we want to check if all of the results have the - // same status code and, if so, handle that failure case - // (default the 'error_body' to the 'processed_body' in case we - // don't get an explicit error) - int16_t single_status_code = -1; - std::vector> single_headers = { - content_type_plain_text}; - std::optional error_body = processed_body; - for (const auto& result : results->items()) { - if (result.value().contains("code") && result.value()["code"].is_number() && - (single_status_code == -1 || - result.value()["code"].get() != single_status_code)) - single_status_code = result.value()["code"].get(); - else { - // Either there was no code, or the code was different - // from a former code in which case there wasn't an - // individual detectable error (ie. it needs specific - // handling) so return no error - single_status_code = 200; - break; - } - - if (result.value().contains("headers")) { - single_headers = {}; - auto header_vals = result.value()["headers"]; - - for (auto it = header_vals.begin(); it != header_vals.end(); ++it) - single_headers.emplace_back(it.key(), it.value()); - } - - if (result.value().contains("body") && result.value()["body"].is_string()) - error_body = result.value()["body"].get(); - } - - // If all results contained the same error then handle it as a - // single error - if (single_status_code < 200 || single_status_code > 299) - throw status_code_exception{ - single_status_code, - single_headers, - error_body.value_or("Sub-request returned " - "non-success status code.")}; - - // Otherwise some requests succeeded and others failed so - // succeed with the processed data - return cb( - true, false, processed_status_code, processed_headers, processed_body); - } catch (const status_code_exception& e) { - handle_errors( - info, path.conn_info, false, e.status_code, e.headers, e.what(), cb); - } catch (const std::exception& e) { - handle_errors( - info, - path.conn_info, - false, - -1, - {content_type_plain_text}, - e.what(), - cb); - } - }); -} - -void Network::upload_file_to_server( - std::vector data, - onionreq::ServerDestination server, - std::optional file_name, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout) { - std::vector> headers; - std::unordered_set existing_keys; - - if (server.headers) - for (auto& [key, value] : *server.headers) { - headers.emplace_back(key, value); - existing_keys.insert(key); - } - - // Add the required headers if they weren't provided - if (existing_keys.find("Content-Disposition") == existing_keys.end()) - headers.emplace_back( - "Content-Disposition", - (file_name ? "attachment; filename=\"{}\""_format(*file_name) : "attachment")); - - if (existing_keys.find("Content-Type") == existing_keys.end()) - headers.emplace_back("Content-Type", "application/octet-stream"); - - send_onion_request( - ServerDestination{ - server.protocol, - server.host, - server.endpoint, - server.x25519_pubkey, - server.port, - headers, - server.method}, - data, - std::nullopt, - handle_response, - request_timeout, - request_and_path_build_timeout, - PathType::upload); -} - -void Network::download_file( - std::string_view download_url, - session::onionreq::x25519_pubkey x25519_pubkey, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout) { - const auto& [proto, host, port, path] = parse_url(download_url); - - if (!path) - throw std::invalid_argument{"Invalid URL provided: Missing path"}; - - download_file( - ServerDestination{proto, host, *path, x25519_pubkey, port, std::nullopt, "GET"}, - handle_response, - request_timeout, - request_and_path_build_timeout); -} - -void Network::download_file( - onionreq::ServerDestination server, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout) { - send_onion_request( - server, - std::nullopt, - std::nullopt, - handle_response, - request_timeout, - request_and_path_build_timeout, - PathType::download); -} - -void Network::get_client_version( - Platform platform, - onionreq::ed25519_seckey seckey, - network_response_callback_t handle_response, - std::chrono::milliseconds request_timeout, - std::optional request_and_path_build_timeout) { - std::string endpoint; - - switch (platform) { - case Platform::android: endpoint = "/session_version?platform=android"; break; - case Platform::desktop: endpoint = "/session_version?platform=desktop"; break; - case Platform::ios: endpoint = "/session_version?platform=ios"; break; - } - - // Generate the auth signature - auto blinded_keys = blind_version_key_pair(to_span(seckey.view())); - auto timestamp = std::chrono::duration_cast( - (std::chrono::system_clock::now()).time_since_epoch()) - .count(); - auto signature = blind_version_sign(to_span(seckey.view()), platform, timestamp); - auto pubkey = x25519_pubkey::from_hex(file_server_pubkey); - std::string blinded_pk_hex; - blinded_pk_hex.reserve(66); - blinded_pk_hex += "07"; - oxenc::to_hex( - blinded_keys.first.begin(), - blinded_keys.first.end(), - std::back_inserter(blinded_pk_hex)); - - auto headers = std::vector>{}; - headers.emplace_back("X-FS-Pubkey", blinded_pk_hex); - headers.emplace_back("X-FS-Timestamp", "{}"_format(timestamp)); - headers.emplace_back("X-FS-Signature", oxenc::to_base64(signature.begin(), signature.end())); - - send_onion_request( - ServerDestination{ - "http", std::string(file_server), endpoint, pubkey, 80, headers, "GET"}, - std::nullopt, - pubkey, - handle_response, - request_timeout, - request_and_path_build_timeout, - PathType::standard); -} - -// MARK: Response Handling - -std::tuple>, std::optional> -Network::process_v3_onion_response(Builder builder, std::string response) { - std::string base64_iv_and_ciphertext; - try { - nlohmann::json response_json = nlohmann::json::parse(response); - - if (!response_json.contains("result") || !response_json["result"].is_string()) - throw std::runtime_error{"JSON missing result field."}; - - base64_iv_and_ciphertext = response_json["result"].get(); - } catch (...) { - base64_iv_and_ciphertext = response; - } - - if (!oxenc::is_base64(base64_iv_and_ciphertext)) - throw std::runtime_error{"Invalid base64 encoded IV and ciphertext."}; - - std::vector iv_and_ciphertext; - oxenc::from_base64( - base64_iv_and_ciphertext.begin(), - base64_iv_and_ciphertext.end(), - std::back_inserter(iv_and_ciphertext)); - auto parser = ResponseParser(builder); - auto result = parser.decrypt(iv_and_ciphertext); - auto result_json = nlohmann::json::parse(result); - int16_t status_code; - std::vector> headers; - std::string body; - - if (result_json.contains("status_code") && result_json["status_code"].is_number()) - status_code = result_json["status_code"].get(); - else if (result_json.contains("status") && result_json["status"].is_number()) - status_code = result_json["status"].get(); - else - throw std::runtime_error{"Invalid JSON response, missing required status_code field."}; - - if (result_json.contains("headers")) { - auto header_vals = result_json["headers"]; - - for (auto it = header_vals.begin(); it != header_vals.end(); ++it) - headers.emplace_back(it.key(), it.value()); - } - - if (result_json.contains("body") && result_json["body"].is_string()) - body = result_json["body"].get(); - else - body = result_json.dump(); - - return {status_code, headers, body}; -} - -std::tuple>, std::optional> -Network::process_v4_onion_response(Builder builder, std::string response) { - auto response_data = to_vector(response); - auto parser = ResponseParser(builder); - auto result = parser.decrypt(response_data); - - // Process the bencoded response - oxenc::bt_list_consumer result_bencode{to_span(result)}; - - if (result_bencode.is_finished() || !result_bencode.is_string()) - throw std::runtime_error{"Invalid bencoded response"}; - - auto response_info_string = result_bencode.consume_string(); - int16_t status_code; - std::vector> headers; - nlohmann::json response_info_json = nlohmann::json::parse(response_info_string); - - if (response_info_json.contains("code") && response_info_json["code"].is_number()) - status_code = response_info_json["code"].get(); - else - throw std::runtime_error{"Invalid JSON response, missing required code field."}; - - if (response_info_json.contains("headers")) { - auto header_vals = response_info_json["headers"]; - - for (auto it = header_vals.begin(); it != header_vals.end(); ++it) - headers.emplace_back(it.key(), it.value()); - } - - if (result_bencode.is_finished()) - return {status_code, headers, std::nullopt}; - - return {status_code, headers, result_bencode.consume_string()}; -} - -// MARK: Error Handling - -std::pair Network::validate_response(quic::message resp, bool is_bencoded) { - std::string body = std::string(resp.body()); - - if (resp.timed_out) - throw std::runtime_error{"Timed out"}; - if (resp.is_error()) - throw std::runtime_error{body.empty() ? "Unknown error" : body}; - - if (is_bencoded) { - // Process the bencoded response - oxenc::bt_list_consumer result_bencode{body}; - - if (result_bencode.is_finished() || !result_bencode.is_integer()) - throw std::runtime_error{"Invalid bencoded response"}; - - // If we have a status code that is not in the 2xx range, return the error - auto status_code = result_bencode.consume_integer(); - - if (status_code < 200 || status_code > 299) { - if (result_bencode.is_finished() || !result_bencode.is_string()) - throw status_code_exception{ - status_code, - {content_type_plain_text}, - "Request failed with status code: " + std::to_string(status_code)}; - - throw status_code_exception{ - status_code, {content_type_plain_text}, result_bencode.consume_string()}; - } - - // Can't convert the data to a string so just return the response body itself - return {status_code, body}; - } - - // Default to a 200 success if the response is empty but didn't timeout or error - int16_t status_code = 200; - std::pair content_type; - std::string response_string; - - try { - nlohmann::json response_json = nlohmann::json::parse(body); - content_type = content_type_json; - - if (response_json.is_array() && response_json.size() == 2) { - status_code = response_json[0].get(); - response_string = response_json[1].dump(); - } else - response_string = body; - } catch (...) { - response_string = body; - content_type = content_type_plain_text; - } - - if (status_code < 200 || status_code > 299) - throw status_code_exception{status_code, {content_type}, response_string}; - - return {status_code, response_string}; -} - -void Network::drop_path_when_empty(std::string id, PathType path_type, onion_path path) { - paths_pending_drop.emplace_back(path, path_type); - paths[path_type].erase( - std::remove(paths[path_type].begin(), paths[path_type].end(), path), - paths[path_type].end()); - - std::string reason; - if (id == path.id) - reason = "connection being closed"; - else - reason = "failure threshold passed with {} failure"_format(id); - - log::info( - cat, - "Flagging path {} [{}] to be dropped due to {}, now have {} {} paths(s).", - path.id, - path.to_string(), - reason, - paths[path_type].size(), - path_type_name(path_type, single_path_mode)); - - // Clear any paths which are waiting to be dropped - clear_empty_pending_path_drops(); -} - -void Network::clear_empty_pending_path_drops() { - auto remaining_standard_paths = 0; - std::erase_if(paths_pending_drop, [this, &remaining_standard_paths](const auto& path_info) { - // If the path is no longer valid then we can drop it - if (!path_info.first.has_pending_requests()) { - log::info( - cat, - "Removing flagged {} path {} that {}: [{}].", - path_type_name(path_info.second, single_path_mode), - path_info.first.id, - (path_info.first.is_valid() ? "has no remaining requests" - : "is no longer valid"), - path_info.first.to_string()); - return true; - } - remaining_standard_paths++; - return false; - }); - - // Update the network status if we've removed all standard paths - if (remaining_standard_paths == 0 && paths[PathType::standard].empty()) - update_status(ConnectionStatus::disconnected); -} - -void Network::handle_errors( - request_info info, - connection_info conn_info, - bool timeout_, - int16_t status_code_, - std::vector> headers_, - std::optional response, - std::optional handle_response) { - bool timeout = timeout_; - auto status_code = status_code_; - auto headers = headers_; - auto path_name = path_type_name(info.path_type, single_path_mode); - - // There is an issue which can occur where we get invalid data back and are unable to decrypt - // it, if we do see this behaviour then we want to retry the request on the off chance it - // resolves itself - // - // When testing this case the retry always resulted in a 421 error, if that occurs we want to go - // through the standard 421 behaviour (which, in this case, would involve a 3rd retry against - // another node in the swarm to confirm the redirect) - if (!info.retry_reason && response && *response == session::onionreq::decryption_failed_error) { - log::info( - cat, - "Received decryption failure in request {} on {} path, retrying.", - info.request_id, - path_name); - auto updated_info = info; - updated_info.retry_reason = request_info::RetryReason::decryption_failure; - return loop->call_soon([this, updated_info, cb = std::move(*handle_response)]() { - _send_onion_request(updated_info, std::move(cb)); - }); - } - - // A number of server errors can return HTML data but no status code, we want to extract those - // cases so they can be handled properly below - if (status_code == -1 && response) { - const std::unordered_map> response_map = { - {"400 Bad Request", {400, false}}, - {"403 Forbidden", {403, false}}, - {"500 Internal Server Error", {500, false}}, - {"502 Bad Gateway", {502, false}}, - {"503 Service Unavailable", {503, false}}, - {"504 Gateway Timeout", {504, true}}, - }; - - for (const auto& [prefix, result] : response_map) { - if (response->starts_with(prefix)) { - status_code = result.first; - timeout = (timeout || result.second); - } - } - } - - // In trace mode log all error info - log::trace( - cat, - "Received network error in request {} on {} path, status_code: {}, timeout: {}, " - "response: {}", - info.request_id, - path_name, - status_code, - timeout, - response.value_or("(No Response)")); - - // A timeout could be caused because the destination is unreachable rather than the the path - // (eg. if a user has an old SOGS which is no longer running on their device they will get a - // timeout) so if we timed out while sending a proxied request we assume something is wrong on - // the server side and don't update the path/snode state - if (!info.node_destination && timeout) { - if (handle_response) - return (*handle_response)(false, true, status_code, headers, response); - return; - } - - switch (status_code) { - // A 404 or a 400 is likely due to a bad/missing SOGS or file so - // shouldn't mark a path or snode as invalid - case 400: - case 404: - if (handle_response) - return (*handle_response)(false, false, status_code, headers, response); - return; - - // The user's clock is out of sync with the service node network (a - // snode will return 406, but V4 onion requests returns a 425) - case 406: - case 425: - if (handle_response) - return (*handle_response)(false, false, status_code, headers, response); - return; - - // The snode is reporting that it isn't associated with the given public key anymore. If - // this is the first 421 then we want to try another node in the swarm (just in case it - // was reported incorrectly). If this is the second occurrence of the 421 then the - // client needs to update the swarm (if the response contains updated swarm data), or - // increment the path failure count. - case 421: - try { - // If there is no response handler or no swarm information was provided then we - // should just replace the swarm - auto target = detail::node_for_destination(info.destination); - - if (!handle_response || !info.swarm_pubkey || !target) - throw std::invalid_argument{"Unable to handle redirect."}; - - switch (info.retry_reason.value_or(request_info::RetryReason::none)) { - // If this was the first 421 then we want to retry using another node in the - // swarm to get confirmation that we should switch to a different swarm - case request_info::RetryReason::none: - case request_info::RetryReason::decryption_failure: { - auto cached_swarm = swarm_cache[info.swarm_pubkey->hex()]; - - if (cached_swarm.second.empty()) - throw std::invalid_argument{ - "Unable to handle redirect due to lack of swarm."}; - - std::vector swarm_copy; - std::copy_if( - cached_swarm.second.begin(), - cached_swarm.second.end(), - std::back_inserter(swarm_copy), - [&target = *target](const auto& node) { return node != target; }); - std::shuffle(swarm_copy.begin(), swarm_copy.end(), csrng); - - if (swarm_copy.empty()) - throw std::invalid_argument{"No other nodes in the swarm."}; - - log::info( - cat, - "Received 421 error in request {} on {} path, retrying once before " - "updating swarm.", - info.request_id, - path_name); - auto updated_info = info; - updated_info.destination = swarm_copy.front(); - updated_info.retry_reason = request_info::RetryReason::redirect; - return loop->call_soon( - [this, updated_info, cb = std::move(*handle_response)]() { - _send_onion_request(updated_info, std::move(cb)); - }); - } - - // If we got a second 421 then it's likely that our cached swarm is out of date - // so we need to refresh our snode cache, regenerate our swarm and try one more - // time - case request_info::RetryReason::redirect: - log::info( - cat, - "Received second 421 error in request {} on {} path, refreshing " - "snode cache before trying one final time.", - info.request_id, - path_name); - after_snode_cache_refresh.emplace_back([this, - swarm_pubkey = info.swarm_pubkey, - info, - status_code, - headers, - response, - cb = std::move( - *handle_response)]() { - get_swarm( - *swarm_pubkey, - [this, - info, - status_code, - headers, - response, - cb = std::move(cb)]( - swarm_id_t, std::vector swarm) { - auto target = - detail::node_for_destination(info.destination); - - std::vector swarm_copy; - std::copy_if( - swarm.begin(), - swarm.end(), - std::back_inserter(swarm_copy), - [&target = *target](const auto& node) { - return node != target; - }); - std::shuffle(swarm_copy.begin(), swarm_copy.end(), csrng); - - // If there are no nodes in the swarm then don't bother - // trying again - if (swarm_copy.empty()) { - log::info( - cat, - "Second 421 retry for request {} resulted in " - "another 421 and had no other nodes in the " - "swarm.", - info.request_id); - return cb(false, false, status_code, headers, response); - } - - auto updated_info = info; - updated_info.retry_reason = - request_info::RetryReason::redirect_swarm_refresh; - updated_info.destination = swarm_copy.front(); - loop->call_soon([this, updated_info, cb = std::move(cb)]() { - _send_onion_request(updated_info, std::move(cb)); - }); - }); - }); - return loop->call_soon([this, request_id = info.request_id]() { - refresh_snode_cache(request_id); - }); - - // If we got a 421 after refreshing the swarm then there is some bigger issue - // (ie. our local swarm generation logic differs from the server or we are - // getting invalid swarm ids back when updating our cache) so the best we can - // do is handle this like any other error - case request_info::RetryReason::redirect_swarm_refresh: - log::info( - cat, - "Received another 421 for request {} after refreshing the snode " - "cache, failing request.", - info.request_id); - break; - - default: break; // Unhandled case should just behave like any other error - } - } catch (...) { - } - - // If we weren't able to retry or redirect the swarm then handle this like any other - // error - break; - - case 500: - case 504: - // If we are making a proxied request to a server then assume 500 errors are occurring - // on the server rather than in the service node network and don't update the path/snode - // state - if (!info.node_destination) { - if (handle_response) - return (*handle_response)(false, timeout, status_code, headers, response); - return; - } - break; - - default: break; - } - - // Retrieve the path for the connection_info (no paths share the same guard node so we can use - // that to find it) - std::optional path; - auto is_active_path = true; - - auto path_it = std::find_if( - paths[info.path_type].begin(), - paths[info.path_type].end(), - [guard_node = conn_info.node](const auto& path) { - return !path.nodes.empty() && path.nodes.front() == guard_node; - }); - - // Try to retrieve the path this request was on, if it's not in an active or pending drop path - // then log a warning (as this shouldn't be possible) and call the callback - if (path_it != paths[info.path_type].end()) - path = *path_it; - else { - auto path_pending_drop_it = std::find_if( - paths_pending_drop.begin(), - paths_pending_drop.end(), - [guard_node = conn_info.node](const auto& path_info) { - return !path_info.first.nodes.empty() && - path_info.first.nodes.front() == guard_node; - }); - - if (path_pending_drop_it == paths_pending_drop.end()) { - log::warning( - cat, - "Request {} failed but {} path with guard {} already dropped.", - info.request_id, - path_name, - conn_info.node.to_string()); - - if (handle_response) - (*handle_response)(false, timeout, status_code, headers, response); - return; - } - path = path_pending_drop_it->first; - is_active_path = false; - } - - // Update the failure counts and paths - auto updated_path = *path; - bool found_invalid_node = false; - - if (response) { - std::optional ed25519PublicKey; - - // Check if the response has one of the 'node_not_found' prefixes - if (response->starts_with(node_not_found_prefix)) - ed25519PublicKey = {response->data() + node_not_found_prefix.size()}; - else if (response->starts_with(node_not_found_prefix_no_status)) - ed25519PublicKey = {response->data() + node_not_found_prefix_no_status.size()}; - - // If we found a result then try to extract the pubkey and process it - if (ed25519PublicKey && ed25519PublicKey->size() == 64 && - oxenc::is_hex(*ed25519PublicKey)) { - session::onionreq::ed25519_pubkey edpk = - session::onionreq::ed25519_pubkey::from_hex(*ed25519PublicKey); - auto edpk_view = to_span(edpk.view()); - - auto snode_it = std::find_if( - updated_path.nodes.begin(), - updated_path.nodes.end(), - [&edpk_view](const auto& node) { - return to_string_view(node.view_remote_key()) == to_string_view(edpk_view); - }); - - if (snode_it != updated_path.nodes.end()) { - found_invalid_node = true; - - // If we get an explicit node failure then we should just immediately drop it and - // try to repair the existing path by replacing the bad node with another one - snode_failure_counts[snode_it->to_string()] = snode_failure_threshold; - - try { - // If the node that's gone bad is the guard node then we just have to - // drop the path - if (snode_it == updated_path.nodes.begin()) - throw std::runtime_error{"Cannot recover if guard node is bad"}; - - if (unused_nodes.empty()) - throw std::runtime_error{"No remaining nodes"}; - - auto target_node = unused_nodes.back(); - unused_nodes.pop_back(); - - std::replace( - updated_path.nodes.begin(), - updated_path.nodes.end(), - *snode_it, - target_node); - log::info( - cat, - "Found bad node ({}) in {} path, replacing node ({}).", - *ed25519PublicKey, - path_name, - updated_path.id); - } catch (...) { - // There aren't enough unused nodes remaining so we need to drop the - // path - updated_path.failure_count = path_failure_threshold; - log::info( - cat, - "Unable to replace bad node ({}) in {} path ({}).", - *ed25519PublicKey, - path_name, - updated_path.id); - } - } - } - } - - // If we didn't find the specific node or the paths connection was closed then increment the - // path failure count - if (!found_invalid_node || !updated_path.conn_info.is_valid()) { - updated_path.failure_count += 1; - - // If the path has failed too many times we want to drop the guard snode (marking it as - // invalid) and increment the failure count of each node in the path) - if (updated_path.failure_count >= path_failure_threshold) { - for (auto& it : updated_path.nodes) - ++snode_failure_counts[it.to_string()]; - - // Set the failure count of the guard node to match the threshold so we don't use it - // again until we refresh the cache - snode_failure_counts[updated_path.nodes[0].to_string()] = snode_failure_threshold; - } else if (updated_path.nodes.size() < path_size) - // triggered when trying to establish a new path and, as such, we should increase - // the failure count of the guard node since it is probably invalid - ++snode_failure_counts[updated_path.nodes[0].to_string()]; - } - - // Drop the path if invalid (and currently an active path) - if (is_active_path) { - if (updated_path.failure_count >= path_failure_threshold) - drop_path_when_empty(info.request_id, info.path_type, *path_it); - else - std::replace( - paths[info.path_type].begin(), - paths[info.path_type].end(), - *path_it, - updated_path); - } - - if (handle_response) - (*handle_response)(false, timeout, status_code, headers, response); -} - -} // namespace session::network - -// MARK: C API - -namespace { - -inline session::network::Network& unbox(network_object* network_) { - assert(network_ && network_->internals); - return *static_cast(network_->internals); -} - -inline bool set_error(char* error, const std::exception& e) { - if (!error) - return false; - - std::string msg = e.what(); - if (msg.size() > 255) - msg.resize(255); - std::memcpy(error, msg.c_str(), msg.size() + 1); - return false; -} - -} // namespace - -extern "C" { - -using namespace session; -using namespace session::network; - -LIBSESSION_C_API bool network_init( - network_object** network, - const char* cache_path_, - bool use_testnet, - bool single_path_mode, - bool pre_build_paths, - char* error) { - try { - std::optional cache_path; - if (cache_path_) - cache_path = cache_path_; - - auto n = std::make_unique( - cache_path, use_testnet, single_path_mode, pre_build_paths); - auto n_object = std::make_unique(); - - n_object->internals = n.release(); - *network = n_object.release(); - return true; - } catch (const std::exception& e) { - return set_error(error, e); - } -} - -LIBSESSION_C_API void network_free(network_object* network) { - delete static_cast(network->internals); - delete network; -} - -LIBSESSION_C_API void network_suspend(network_object* network) { - unbox(network).suspend(); -} - -LIBSESSION_C_API void network_resume(network_object* network) { - unbox(network).resume(); -} - -LIBSESSION_C_API void network_close_connections(network_object* network) { - unbox(network).close_connections(); -} - -LIBSESSION_C_API void network_clear_cache(network_object* network) { - unbox(network).clear_cache(); -} - -LIBSESSION_C_API size_t network_get_snode_cache_size(network_object* network) { - return unbox(network).snode_cache_size(); -} - -LIBSESSION_C_API void network_set_status_changed_callback( - network_object* network, void (*callback)(CONNECTION_STATUS status, void* ctx), void* ctx) { - if (!callback) - unbox(network).status_changed = nullptr; - else - unbox(network).status_changed = [cb = std::move(callback), ctx](ConnectionStatus status) { - cb(static_cast(status), ctx); - }; -} - -LIBSESSION_C_API void network_set_paths_changed_callback( - network_object* network, - void (*callback)(onion_request_path* paths, size_t paths_len, void* ctx), - void* ctx) { - if (!callback) - unbox(network).paths_changed = nullptr; - else - unbox(network).paths_changed = [cb = std::move(callback), - ctx](std::vector> paths) { - size_t paths_mem_size = 0; - for (auto& nodes : paths) - paths_mem_size += - sizeof(onion_request_path) + (sizeof(network_service_node) * nodes.size()); - - // Allocate the memory for the onion_request_paths* array - auto* c_paths_array = static_cast(std::malloc(paths_mem_size)); - for (size_t i = 0; i < paths.size(); ++i) { - auto c_nodes = network::detail::convert_service_nodes(paths[i]); - - // Allocate memory that persists outside the loop - size_t node_array_size = sizeof(network_service_node) * c_nodes.size(); - auto* c_nodes_array = - static_cast(std::malloc(node_array_size)); - std::copy(c_nodes.begin(), c_nodes.end(), c_nodes_array); - new (c_paths_array + i) onion_request_path{c_nodes_array, c_nodes.size()}; - } - - cb(c_paths_array, paths.size(), ctx); - }; -} - -LIBSESSION_C_API void network_get_swarm( - network_object* network, - const char* swarm_pubkey_hex, - void (*callback)(network_service_node* nodes, size_t nodes_len, void*), - void* ctx) { - assert(swarm_pubkey_hex && callback); - unbox(network).get_swarm( - x25519_pubkey::from_hex({swarm_pubkey_hex, 64}), - [cb = std::move(callback), ctx](swarm_id_t, std::vector nodes) { - auto c_nodes = network::detail::convert_service_nodes(nodes); - cb(c_nodes.data(), c_nodes.size(), ctx); - }); -} - -LIBSESSION_C_API void network_get_random_nodes( - network_object* network, - uint16_t count, - void (*callback)(network_service_node*, size_t, void*), - void* ctx) { - assert(callback); - unbox(network).get_random_nodes( - count, [cb = std::move(callback), ctx](std::vector nodes) { - auto c_nodes = network::detail::convert_service_nodes(nodes); - cb(c_nodes.data(), c_nodes.size(), ctx); - }); -} - -LIBSESSION_C_API void network_send_onion_request_to_snode_destination( - network_object* network, - const network_service_node node, - const unsigned char* body_, - size_t body_size, - const char* swarm_pubkey_hex, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx) { - assert(callback); - - try { - std::optional> body; - if (body_size > 0) - body.emplace(body_, body_ + body_size); - - std::optional swarm_pubkey; - if (swarm_pubkey_hex) - swarm_pubkey = x25519_pubkey::from_hex({swarm_pubkey_hex, 64}); - - std::optional request_and_path_build_timeout; - if (request_and_path_build_timeout_ms > 0) - request_and_path_build_timeout = - std::chrono::milliseconds{request_and_path_build_timeout_ms}; - - std::array ip; - std::memcpy(ip.data(), node.ip, ip.size()); - - unbox(network).send_onion_request( - service_node{ - oxenc::from_hex({node.ed25519_pubkey_hex, 64}), - {0}, - INVALID_SWARM_ID, - "{}"_format(fmt::join(ip, ".")), - node.quic_port}, - body, - swarm_pubkey, - [cb = std::move(callback), ctx]( - bool success, - bool timeout, - int status_code, - std::vector> headers, - std::optional response) { - std::vector cHeaders; - std::vector cHeaderValues; - cHeaders.reserve(headers.size()); - cHeaderValues.reserve(headers.size()); - - for (const auto& [header, value] : headers) { - cHeaders.push_back(header.c_str()); - cHeaderValues.push_back(value.c_str()); - } - - if (response) - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - (*response).c_str(), - (*response).size(), - ctx); - else - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - nullptr, - 0, - ctx); - }, - std::chrono::milliseconds{request_timeout_ms}, - request_and_path_build_timeout); - } catch (const std::exception& e) { - callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); - } -} - -LIBSESSION_C_API void network_send_onion_request_to_server_destination( - network_object* network, - const network_server_destination server, - const unsigned char* body_, - size_t body_size, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx) { - assert(server.method && server.protocol && server.host && server.endpoint && - server.x25519_pubkey && callback); - - try { - std::optional> body; - if (body_size > 0) - body.emplace(body_, body_ + body_size); - - std::optional request_and_path_build_timeout; - if (request_and_path_build_timeout_ms > 0) - request_and_path_build_timeout = - std::chrono::milliseconds{request_and_path_build_timeout_ms}; - - unbox(network).send_onion_request( - network::detail::convert_server_destination(server), - body, - std::nullopt, - [cb = std::move(callback), ctx]( - bool success, - bool timeout, - int status_code, - std::vector> headers, - std::optional response) { - std::vector cHeaders; - std::vector cHeaderValues; - cHeaders.reserve(headers.size()); - cHeaderValues.reserve(headers.size()); - - for (const auto& [header, value] : headers) { - cHeaders.push_back(header.c_str()); - cHeaderValues.push_back(value.c_str()); - } - - if (response) - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - (*response).c_str(), - (*response).size(), - ctx); - else - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - nullptr, - 0, - ctx); - }, - std::chrono::milliseconds{request_timeout_ms}, - request_and_path_build_timeout); - } catch (const std::exception& e) { - callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); - } -} - -LIBSESSION_C_API void network_upload_to_server( - network_object* network, - const network_server_destination server, - const unsigned char* data, - size_t data_len, - const char* file_name_, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx) { - assert(data && server.method && server.protocol && server.host && server.endpoint && - server.x25519_pubkey && callback); - - try { - std::optional file_name; - if (file_name_) - file_name = file_name_; - - std::optional request_and_path_build_timeout; - if (request_and_path_build_timeout_ms > 0) - request_and_path_build_timeout = - std::chrono::milliseconds{request_and_path_build_timeout_ms}; - - unbox(network).upload_file_to_server( - {data, data + data_len}, - network::detail::convert_server_destination(server), - file_name, - [cb = std::move(callback), ctx]( - bool success, - bool timeout, - int status_code, - std::vector> headers, - std::optional response) { - std::vector cHeaders; - std::vector cHeaderValues; - cHeaders.reserve(headers.size()); - cHeaderValues.reserve(headers.size()); - - for (const auto& [header, value] : headers) { - cHeaders.push_back(header.c_str()); - cHeaderValues.push_back(value.c_str()); - } - - if (response) - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - (*response).c_str(), - (*response).size(), - ctx); - else - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - nullptr, - 0, - ctx); - }, - std::chrono::milliseconds{request_timeout_ms}, - request_and_path_build_timeout); - } catch (const std::exception& e) { - callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); - } -} - -LIBSESSION_C_API void network_download_from_server( - network_object* network, - const network_server_destination server, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx) { - assert(server.method && server.protocol && server.host && server.endpoint && - server.x25519_pubkey && callback); - - try { - std::optional request_and_path_build_timeout; - if (request_and_path_build_timeout_ms > 0) - request_and_path_build_timeout = - std::chrono::milliseconds{request_and_path_build_timeout_ms}; - - unbox(network).download_file( - network::detail::convert_server_destination(server), - [cb = std::move(callback), ctx]( - bool success, - bool timeout, - int status_code, - std::vector> headers, - std::optional response) { - std::vector cHeaders; - std::vector cHeaderValues; - cHeaders.reserve(headers.size()); - cHeaderValues.reserve(headers.size()); - - for (const auto& [header, value] : headers) { - cHeaders.push_back(header.c_str()); - cHeaderValues.push_back(value.c_str()); - } - - if (response) - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - (*response).c_str(), - (*response).size(), - ctx); - else - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - nullptr, - 0, - ctx); - }, - std::chrono::milliseconds{request_timeout_ms}, - request_and_path_build_timeout); - } catch (const std::exception& e) { - callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); - } -} - -LIBSESSION_C_API void network_get_client_version( - network_object* network, - CLIENT_PLATFORM platform, - const unsigned char* ed25519_secret, - int64_t request_timeout_ms, - int64_t request_and_path_build_timeout_ms, - network_onion_response_callback_t callback, - void* ctx) { - assert(platform && callback); - - try { - std::optional request_and_path_build_timeout; - if (request_and_path_build_timeout_ms > 0) - request_and_path_build_timeout = - std::chrono::milliseconds{request_and_path_build_timeout_ms}; - - unbox(network).get_client_version( - static_cast(platform), - onionreq::ed25519_seckey::from_bytes({ed25519_secret, 64}), - [cb = std::move(callback), ctx]( - bool success, - bool timeout, - int status_code, - std::vector> headers, - std::optional response) { - std::vector cHeaders; - std::vector cHeaderValues; - cHeaders.reserve(headers.size()); - cHeaderValues.reserve(headers.size()); - - for (const auto& [header, value] : headers) { - cHeaders.push_back(header.c_str()); - cHeaderValues.push_back(value.c_str()); - } - - if (response) - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - (*response).c_str(), - (*response).size(), - ctx); - else - cb(success, - timeout, - status_code, - cHeaders.data(), - cHeaderValues.data(), - headers.size(), - nullptr, - 0, - ctx); - }, - std::chrono::milliseconds{request_timeout_ms}, - request_and_path_build_timeout); - } catch (const std::exception& e) { - callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); - } -} - -} // extern "C" diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b436795f..f783eae1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -27,14 +27,17 @@ set(LIB_SESSION_UTESTS_SOURCES test_proto.cpp test_random.cpp test_session_encrypt.cpp + test_utils.cpp test_xed25519.cpp case_logger.cpp ) -if (ENABLE_ONIONREQ) - list(APPEND LIB_SESSION_UTESTS_SOURCES test_session_network.cpp) +if(ENABLE_NETWORKING) + list(APPEND LIB_SESSION_UTESTS_SOURCES test_network_swarm.cpp) list(APPEND LIB_SESSION_UTESTS_SOURCES test_onionreq.cpp) + list(APPEND LIB_SESSION_UTESTS_SOURCES test_onion_request_router.cpp) + list(APPEND LIB_SESSION_UTESTS_SOURCES test_snode_pool.cpp) endif() add_library(test_libs INTERFACE) @@ -45,10 +48,10 @@ target_link_libraries(test_libs INTERFACE nlohmann_json::nlohmann_json oxen::logging) -if (ENABLE_ONIONREQ) - target_link_libraries(test_libs INTERFACE libsession::onionreq) +if (ENABLE_NETWORKING) + target_link_libraries(test_libs INTERFACE libsession::network) else() - target_compile_definitions(test_libs INTERFACE DISABLE_ONIONREQ) + target_compile_definitions(test_libs INTERFACE DISABLE_NETWORKING) endif() add_executable(testAll main.cpp ${LIB_SESSION_UTESTS_SOURCES}) diff --git a/tests/test_logging.cpp b/tests/test_logging.cpp index d1a38518..b8a90c30 100644 --- a/tests/test_logging.cpp +++ b/tests/test_logging.cpp @@ -8,7 +8,7 @@ #include "utils.hpp" -#ifndef DISABLE_ONIONREQ +#ifndef DISABLE_NETWORKING #include #endif @@ -89,7 +89,7 @@ TEST_CASE("Logging callbacks", "[logging]") { line1)); } -#ifndef DISABLE_ONIONREQ +#ifndef DISABLE_NETWORKING TEST_CASE("Logging callbacks with quic::Network", "[logging][network]") { oxen::log::clear_sinks(); simple_logs.clear(); diff --git a/tests/test_network_swarm.cpp b/tests/test_network_swarm.cpp new file mode 100644 index 00000000..887ea581 --- /dev/null +++ b/tests/test_network_swarm.cpp @@ -0,0 +1,200 @@ +#include +#include +#include +#include +#include + +#include "utils.hpp" + +using namespace session; +using namespace session::network; +using namespace session::network::swarm; + +swarm_id_t get_swarm_id( + std::string swarm_pubkey_hex, + std::vector>> swarms) { + if (swarm_pubkey_hex.size() == 66) + swarm_pubkey_hex = swarm_pubkey_hex.substr(2); + + auto pk = x25519_pubkey::from_hex(swarm_pubkey_hex); + return get_swarm(pk, swarms).first; +} + +TEST_CASE("Swarm", "[network][swarm][pubkey_to_swarm_space]") { + x25519_pubkey pk; + + pk = x25519_pubkey::from_hex( + "3506f4a71324b7dd114eddbf4e311f39dde243e1f2cb97c40db1961f70ebaae8"); + CHECK(pubkey_to_swarm_space(pk) == 17589930838143112648ULL); + pk = x25519_pubkey::from_hex( + "cf27da303a50ac8c4b2d43d27259505c9bcd73fc21cf2a57902c3d050730b604"); + CHECK(pubkey_to_swarm_space(pk) == 10370619079776428163ULL); + pk = x25519_pubkey::from_hex( + "d3511706b8b34f6e8411bf07bd22ba6b2435ca56846fbccf6eb1e166a6cd15cc"); + CHECK(pubkey_to_swarm_space(pk) == 2144983569669512198ULL); + pk = x25519_pubkey::from_hex( + "0f06693428fca9102a451e3f28d9cc743d8ea60a89ab6aa69eb119470c11cbd3"); + CHECK(pubkey_to_swarm_space(pk) == 9690840703409570833ULL); + pk = x25519_pubkey::from_hex( + "ffba630924aa1224bb930dde21c0d11bf004608f2812217f8ac812d6c7e3ad48"); + CHECK(pubkey_to_swarm_space(pk) == 4532060000165252872ULL); + pk = x25519_pubkey::from_hex( + "eeeeeeeeeeeeeeee777777777777777711111111111111118888888888888888"); + CHECK(pubkey_to_swarm_space(pk) == 0); + pk = x25519_pubkey::from_hex( + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"); + CHECK(pubkey_to_swarm_space(pk) == 0); + pk = x25519_pubkey::from_hex( + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + CHECK(pubkey_to_swarm_space(pk) == 1); + pk = x25519_pubkey::from_hex( + "ffffffffffffffffffffffffffffffffffffffffffffffff7fffffffffffffff"); + CHECK(pubkey_to_swarm_space(pk) == 1ULL << 63); + pk = x25519_pubkey::from_hex( + "000000000000000000000000000000000000000000000000ffffffffffffffff"); + CHECK(pubkey_to_swarm_space(pk) == (uint64_t)-1); + pk = x25519_pubkey::from_hex( + "0000000000000000000000000000000000000000000000000123456789abcdef"); + CHECK(pubkey_to_swarm_space(pk) == 0x0123456789abcdefULL); +} + +TEST_CASE("Swarm", "[network][swarm][get_swarm]") { + std::vector>> swarms = { + {100, {}}, {200, {}}, {300, {}}, {399, {}}, {498, {}}, {596, {}}, {694, {}}}; + + // Exact matches: + // 0x64 = 100, 0xc8 = 200, 0x1f2 = 498 + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000064", swarms) == + 100); + CHECK(get_swarm_id( + "0500000000000000000000000000000000000000000000000000000000000000c8", swarms) == + 200); + CHECK(get_swarm_id( + "0500000000000000000000000000000000000000000000000000000000000001f2", swarms) == + 498); + + // Nearest + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000000", swarms) == + 100); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000001", swarms) == + 100); + + // Nearest, with wraparound + // 0x8000... is closest to the top value + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000008000000000000000", swarms) == + 694); + + // 0xa000... is closest (via wraparound) to the smallest + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000a000000000000000", swarms) == + 100); + + // This is the invalid swarm id for swarms, but should still work for a client + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000ffffffffffffffff", swarms) == + 100); + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000fffffffffffffffe", swarms) == + 100); + + // Midpoint tests; we prefer the lower value when exactly in the middle between two swarms. + // 0x96 = 150 + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000095", swarms) == + 100); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000096", swarms) == + 100); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000097", swarms) == + 200); + + // 0xfa = 250 + CHECK(get_swarm_id( + "0500000000000000000000000000000000000000000000000000000000000000f9", swarms) == + 200); + CHECK(get_swarm_id( + "0500000000000000000000000000000000000000000000000000000000000000fa", swarms) == + 200); + CHECK(get_swarm_id( + "0500000000000000000000000000000000000000000000000000000000000000fb", swarms) == + 300); + + // 0x15d = 349 + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000000000000000015d", swarms) == + 300); + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000000000000000015e", swarms) == + 399); + + // 0x1c0 = 448 + CHECK(get_swarm_id( + "0500000000000000000000000000000000000000000000000000000000000001c0", swarms) == + 399); + CHECK(get_swarm_id( + "0500000000000000000000000000000000000000000000000000000000000001c1", swarms) == + 498); + + // 0x223 = 547 + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000222", swarms) == + 498); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000223", swarms) == + 498); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000224", swarms) == + 596); + + // 0x285 = 645 + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000285", swarms) == + 596); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000286", swarms) == + 694); + + // 0x800....d is the midpoint between 694 and 100 (the long way). We always round "down" (which + // in this case, means wrapping to the largest swarm). + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000800000000000018c", swarms) == + 694); + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000800000000000018d", swarms) == + 694); + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000800000000000018e", swarms) == + 100); + + // With a swarm at -20 the midpoint is now 40 (=0x28). When our value is the *low* value we + // prefer the *last* swarm in the case of a tie (while consistent with the general case of + // preferring the left edge, it means we're inconsistent with the other wraparound case, above. + // *sigh*). + swarms.push_back({(uint64_t)-20, {}}); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000027", swarms) == + swarms.back().first); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000028", swarms) == + swarms.back().first); + CHECK(get_swarm_id( + "050000000000000000000000000000000000000000000000000000000000000029", swarms) == + swarms.front().first); + + // The code used to have a broken edge case if we have a swarm at zero and a client at max-u64 + // because of an overflow in how the distance is calculated (the first swarm will be calculated + // as max-u64 away, rather than 1 away), and so the id always maps to the highest swarm (even + // though 0xfff...fe maps to the lowest swarm; the first check here, then, would fail. + swarms.insert(swarms.begin(), {0, {}}); + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000ffffffffffffffff", swarms) == + 0); + CHECK(get_swarm_id( + "05000000000000000000000000000000000000000000000000fffffffffffffffe", swarms) == + 0); +} diff --git a/tests/test_onion_request_router.cpp b/tests/test_onion_request_router.cpp new file mode 100644 index 00000000..8685822a --- /dev/null +++ b/tests/test_onion_request_router.cpp @@ -0,0 +1,742 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" + +using namespace session; +using namespace session::network; + +namespace session::network { +class TestOnionRequestRouter { + public: + static void set_paths( + std::shared_ptr router, + RequestCategory category, + std::vector paths) { + router->_paths.emplace(category, paths); + } + + static std::vector get_paths( + std::shared_ptr router, RequestCategory category) { + return router->_paths[category]; + } + + static void set_request_queues( + std::shared_ptr router, + std::unordered_map> queues) { + router->_request_queues = queues; + } + + static uint16_t failure_count( + std::shared_ptr router, + RequestCategory category, + std::string path_id) { + for (auto& path : router->_paths[category]) + if (path.id == path_id) + return path.failure_count; + + return 0; + } + + static void build_path( + std::shared_ptr router, + RequestCategory category, + std::optional initiating_req_id = std::nullopt, + const std::vector& nodes_to_exclude_ = {}, + std::optional original_path_id = std::nullopt) { + router->_build_path(category, initiating_req_id, nodes_to_exclude_, original_path_id); + } + + static OnionPath* find_valid_path( + std::shared_ptr router, const Request& request) { + return router->_find_valid_path(request); + } + + static void handle_transport_response( + std::shared_ptr router, + std::string path_id, + Request original_request, + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional decrypted_body, + network_response_callback_t callback) { + router->_handle_transport_response( + path_id, + original_request, + success, + timeout, + status_code, + std::move(headers), + std::move(decrypted_body), + std::move(callback)); + } +}; + +namespace detail { + class TestRequestQueue : public detail::RequestQueue, public CallTracker { + public: + TestRequestQueue( + std::shared_ptr loop, std::chrono::milliseconds check_frequency) : + detail::RequestQueue(loop, check_frequency) {}; + + void add(Request request, network_response_callback_t callback) override { + if (check_should_ignore_and_log_call("add")) + return; + detail::RequestQueue::add(std::move(request), std::move(callback)); + } + + void add_front(std::pair req_pair) override { + if (check_should_ignore_and_log_call("add_front")) + return; + detail::RequestQueue::add_front(std::move(req_pair)); + } + + std::deque> pop_all() override { + if (check_should_ignore_and_log_call("pop_all")) + return {}; + return detail::RequestQueue::pop_all(); + } + + private: + void check_timeouts() override { + if (check_should_ignore_and_log_call("check_timeouts")) + return; + detail::RequestQueue::check_timeouts(); + } + }; +} // namespace detail + +namespace { + class TestSnodePool : public SnodePool, public CallTracker { + public: + std::optional> mock_unused_nodes; + + TestSnodePool( + config::SnodePoolConfig config, + std::shared_ptr loop, + network_fetcher_t direct_fetcher = [](Request, network_response_callback_t) {}) : + SnodePool(std::move(config), std::move(loop), std::move(direct_fetcher)) {} + + void record_node_failure(const service_node& node, bool permanent = false) override { + if (check_should_ignore_and_log_call("record_node_failure(node)")) + return; + SnodePool::record_node_failure(node, permanent); + } + + void record_node_failure(const ed25519_pubkey& key, bool permanent = false) override { + if (check_should_ignore_and_log_call("record_node_failure(key)")) + return; + SnodePool::record_node_failure(key, permanent); + } + + void refresh_if_needed( + const std::vector& in_use_nodes, + std::function on_refresh_complete = nullptr) override { + func_called("refresh_if_needed"); + // Do nothing (don't want to trigger a cache refresh) + } + + void get_swarm( + session::network::x25519_pubkey swarm_pubkey, + std::function)> callback) + override { + func_called("get_swarm"); + // Do nothing (don't want to trigger a cache refresh) + } + + std::vector get_unused_nodes( + size_t count, const std::vector& exclude = {}) override { + if (check_should_ignore_and_log_call("get_unused_nodes")) + return {}; + + if (mock_unused_nodes) + return *mock_unused_nodes; + + return SnodePool::get_unused_nodes(count, exclude); + } + }; + + class TestTransport : public ITransport, public CallTracker { + public: + void suspend() override { func_called("suspend"); }; + void resume(bool automatically_reconnect = true) override { func_called("resume"); }; + void close_connections() override { func_called("close_connections"); }; + + ConnectionStatus get_status() const override { return ConnectionStatus::unknown; }; + void verify_connectivity( + service_node node, + std::chrono::milliseconds timeout, + const std::string& request_id, + std::function callback) override { + func_called("verify_connectivity"); + } + void add_failure_listener( + const ed25519_pubkey& pubkey, std::function listener) override { + func_called("add_failure_listener"); + } + void remove_failure_listeners(const ed25519_pubkey& pubkey) override { + func_called("remove_failure_listeners"); + } + + void send_request(Request request, network_response_callback_t callback) override { + func_called("send_request"); + } + }; + + struct Result { + bool success; + bool timeout; + int16_t status_code; + std::vector> headers; + std::optional response; + }; +} // namespace + +TEST_CASE("Network", "[network][onion_request_router][handle_errors]") { + config::SnodePoolConfig pool_config = { + std::nullopt, + std::chrono::minutes{5}, + std::chrono::minutes{5}, + false, + network::opt::retry_delay{50ms, 200ms}, + opt::netid::Target::testnet, + {}, + 0, + 0, + 3, // Node failure threshold + false}; + config::OnionRequestRouterConfig config = { + network::opt::retry_delay{50ms, 200ms}, + 50ms, + 3, + 3, + 10, + true, + true, + {{RequestCategory::standard, 1}}}; + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + auto ed_pk2 = "5ea34e72bb044654a6a23675690ef5ffaaf1656b02f93fb76655f9cbdbe89876"_hexbytes; + auto ed_pk3 = "e17a692033200ae41350df9709754edde7343e2cf2f23e88f993319e0720e5e5"_hexbytes; + auto ed_pk4 = "7b633fa6fb462b90db6f0f50384190ce7715e31b7aa93d87dbd7e94e33d4251f"_hexbytes; + auto target = service_node{ed_pk, oxen::quic::ipv4{"127.0.0.1"}, 20001, 30001, {2, 11, 0}, 0}; + auto target2 = service_node{ed_pk2, oxen::quic::ipv4{"127.0.0.1"}, 20002, 30002, {2, 11, 0}, 0}; + auto target3 = service_node{ed_pk3, oxen::quic::ipv4{"127.0.0.1"}, 20003, 30003, {2, 11, 0}, 0}; + auto target4 = service_node{ed_pk4, oxen::quic::ipv4{"127.0.0.1"}, 20004, 30004, {2, 11, 0}, 0}; + auto request = + Request{"AAAA", target, "info", to_vector("test"), RequestCategory::standard, 0ms}; + std::optional path; + Result result; + + auto loop = std::make_shared(); + auto snode_pool = std::make_shared(pool_config, loop); + auto transport = std::make_shared(); + std::shared_ptr router; + + // Check the handling of the codes which make no changes + auto codes_with_no_changes = {400, 404, 406, 425}; + + for (auto code : codes_with_no_changes) { + snode_pool->clear_node_failure_counts(); + snode_pool->reset_calls(); + path.emplace(OnionPath{"Test", {target2, target3, target4}}); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {*path}); + TestOnionRequestRouter::handle_transport_response( + router, + "Test", + request, + false, + false, + code, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == code); + CHECK(result.response.value_or("") == ""); + CHECK(snode_pool->did_not_call("record_node_failure(node)")); + CHECK(snode_pool->did_not_call("record_node_failure(key)")); + CHECK(snode_pool->node_failure_count(target2) == 0); + CHECK(snode_pool->node_failure_count(target3) == 0); + CHECK(snode_pool->node_failure_count(target4) == 0); + CHECK(TestOnionRequestRouter::failure_count(router, RequestCategory::standard, "Test") == + 0); + } + + // Check general error handling (first failure) + snode_pool->clear_node_failure_counts(); + snode_pool->reset_calls(); + path.emplace(OnionPath{"Test", {target2, target3, target4}}); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {*path}); + TestOnionRequestRouter::handle_transport_response( + router, + "Test", + request, + false, + false, + 500, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 500); + CHECK(result.response.value_or("") == ""); + CHECK(snode_pool->did_not_call("record_node_failure(node)")); + CHECK(snode_pool->did_not_call("record_node_failure(key)")); + CHECK(snode_pool->node_failure_count(target2) == 0); + CHECK(snode_pool->node_failure_count(target3) == 0); + CHECK(snode_pool->node_failure_count(target4) == 0); + CHECK(TestOnionRequestRouter::failure_count(router, RequestCategory::standard, "Test") == 1); + + // Check general error handling with no response (too many path failures) + snode_pool->clear_node_failure_counts(); + REQUIRE(snode_pool->node_failure_count(target2) == 0); + snode_pool->reset_calls(); + path.emplace(OnionPath{"Test", {target2, target3, target4}}); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths( + router, + RequestCategory::standard, + {OnionPath{"Test", {target2, target3, target4}, 0, 9}}); + TestOnionRequestRouter::handle_transport_response( + router, + "Test", + request, + false, + false, + 500, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 500); + CHECK(result.response.value_or("") == ""); + CHECK(snode_pool->called("record_node_failure(node)", 3)); + CHECK(snode_pool->node_failure_count(target2) == 1); + CHECK(snode_pool->node_failure_count(target3) == 1); + CHECK(snode_pool->node_failure_count(target4) == 1); + CHECK(TestOnionRequestRouter::failure_count(router, RequestCategory::standard, "Test") == + 0); // Path dropped and reset + + // Check general error handling with a path and specific node failure + snode_pool->clear_node_failure_counts(); + snode_pool->reset_calls(); + snode_pool->mock_unused_nodes = {target}; + path.emplace(OnionPath{"Test", {target2, target3, target4}}); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {*path}); + TestOnionRequestRouter::handle_transport_response( + router, + "Test", + request, + false, + false, + 500, + {}, + "Next node not found: {}"_format(ed25519_pubkey::from_bytes(ed_pk3).hex()), + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 500); + CHECK(result.response.value_or("") == + "Next node not found: {}"_format(ed25519_pubkey::from_bytes(ed_pk3).hex())); + CHECK(snode_pool->called("record_node_failure(node)", 1)); + CHECK(snode_pool->node_failure_count(target2) == 0); + CHECK(snode_pool->node_failure_count(target3) == 3); // Node will have been dropped + CHECK(snode_pool->node_failure_count(target4) == 0); + CHECK(TestOnionRequestRouter::failure_count(router, RequestCategory::standard, "Test") == 1); + CHECK(TestOnionRequestRouter::get_paths(router, RequestCategory::standard).front().nodes[1] != + target3); + + // Check a 421 doesn't impact the node failure counts + snode_pool->clear_node_failure_counts(); + snode_pool->reset_calls(); + path.emplace(OnionPath{"Test", {target2, target3, target4}}); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {*path}); + TestOnionRequestRouter::handle_transport_response( + router, + "Test", + request, + false, + false, + 421, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 421); + CHECK(result.response.value_or("") == ""); + CHECK(snode_pool->did_not_call("record_node_failure(node)")); + CHECK(snode_pool->did_not_call("record_node_failure(key)")); + CHECK(snode_pool->node_failure_count(target2) == 0); + CHECK(snode_pool->node_failure_count(target3) == 0); + CHECK(snode_pool->node_failure_count(target4) == 0); + CHECK(TestOnionRequestRouter::failure_count(router, RequestCategory::standard, "Test") == 1); + + // Check a timeout with a server destination doesn't impact the failure counts + auto server_request = + Request{"AAAA", + ServerDestination{ + "https", + "open.getsession.org", + x25519_pubkey::from_hex("a03c383cf63c3c4efe67acc52112a6dd734b3a946b9545" + "f488aaa93da7991238"), + 443, + std::nullopt, + "GET"}, + "info", + to_vector("test"), + RequestCategory::standard, + 0ms}; + snode_pool->clear_node_failure_counts(); + snode_pool->reset_calls(); + path.emplace(OnionPath{"Test", {target2, target3, target4}}); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {*path}); + TestOnionRequestRouter::handle_transport_response( + router, + "Test", + server_request, + false, + true, + -1, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK(result.timeout); + CHECK(result.status_code == -1); + CHECK(result.response.value_or("") == ""); + CHECK(snode_pool->did_not_call("record_node_failure(node)")); + CHECK(snode_pool->did_not_call("record_node_failure(key)")); + CHECK(snode_pool->node_failure_count(target2) == 0); + CHECK(snode_pool->node_failure_count(target3) == 0); + CHECK(snode_pool->node_failure_count(target4) == 0); + CHECK(TestOnionRequestRouter::failure_count(router, RequestCategory::standard, "Test") == 0); + + // Check the handling of the codes which should be ignored when the request was sent to a server + // make no changes + auto server_codes_with_no_changes = {500, 504}; + + for (auto code : server_codes_with_no_changes) { + snode_pool->clear_node_failure_counts(); + snode_pool->reset_calls(); + path.emplace(OnionPath{"Test", {target2, target3, target4}}); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {*path}); + TestOnionRequestRouter::handle_transport_response( + router, + "Test", + server_request, + false, + false, + code, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK(result.timeout == (code == 504)); + CHECK(result.status_code == code); + CHECK(result.response.value_or("") == ""); + CHECK(snode_pool->did_not_call("record_node_failure(node)")); + CHECK(snode_pool->did_not_call("record_node_failure(key)")); + CHECK(snode_pool->node_failure_count(target2) == 0); + CHECK(snode_pool->node_failure_count(target3) == 0); + CHECK(snode_pool->node_failure_count(target4) == 0); + CHECK(TestOnionRequestRouter::failure_count(router, RequestCategory::standard, "Test") == + 0); + } +} + +TEST_CASE("Network", "[network][onion_request_router][build_path]") { + config::SnodePoolConfig pool_config = { + std::nullopt, + std::chrono::minutes{5}, + std::chrono::minutes{5}, + false, + network::opt::retry_delay{50ms, 200ms}, + opt::netid::Target::testnet, + {}, + 0, + 0, + 3, // Node failure threshold + false}; + config::OnionRequestRouterConfig config = { + network::opt::retry_delay{50ms, 200ms}, + 50ms, + 3, + 3, + 10, + true, + true, + {{RequestCategory::standard, 1}}}; + auto loop = std::make_shared(); + auto snode_pool = std::make_shared(pool_config, loop); + auto transport = std::make_shared(); + std::shared_ptr router; + + // Nothing should happen if the network is suspended + snode_pool->reset_calls(); + router = std::make_shared(config, loop, snode_pool, transport); + router->suspend(); + TestOnionRequestRouter::build_path(router, RequestCategory::standard); + CHECK(snode_pool->did_not_call("get_unused_nodes")); + + // If the unused nodes are empty it refreshes them + snode_pool->reset_calls(); + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::build_path(router, RequestCategory::standard); + CHECK(snode_pool->called("get_unused_nodes")); + CHECK(snode_pool->called("refresh_if_needed")); +} + +TEST_CASE("Network", "[network][onion_request_router][find_valid_path]") { + config::SnodePoolConfig pool_config = { + std::nullopt, + std::chrono::minutes{5}, + std::chrono::minutes{5}, + false, + network::opt::retry_delay{50ms, 200ms}, + opt::netid::Target::testnet, + {}, + 0, + 0, + 3, // cache_node_failure_threshold + false}; + config::OnionRequestRouterConfig config = { + network::opt::retry_delay{50ms, 200ms}, + 50ms, + 3, + 3, + 10, + true, + false, + {{RequestCategory::standard, 1}}}; + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + auto ed_pk2 = "5ea34e72bb044654a6a23675690ef5ffaaf1656b02f93fb76655f9cbdbe89876"_hexbytes; + auto ed_pk3 = "e17a692033200ae41350df9709754edde7343e2cf2f23e88f993319e0720e5e5"_hexbytes; + auto ed_pk4 = "7b633fa6fb462b90db6f0f50384190ce7715e31b7aa93d87dbd7e94e33d4251f"_hexbytes; + auto target = service_node{ed_pk, oxen::quic::ipv4{"127.0.0.1"}, 20001, 30001, {2, 11, 0}, 0}; + auto target2 = service_node{ed_pk2, oxen::quic::ipv4{"127.0.0.1"}, 20002, 30002, {2, 11, 0}, 0}; + auto target3 = service_node{ed_pk3, oxen::quic::ipv4{"127.0.0.1"}, 20003, 30003, {2, 11, 0}, 0}; + auto target4 = service_node{ed_pk4, oxen::quic::ipv4{"127.0.0.1"}, 20004, 30004, {2, 11, 0}, 0}; + auto path1 = OnionPath{"Test1", {target, target2, target3}}; + auto path2 = OnionPath{"Test2", {target2, target3, target4}}; + auto request = + Request{"AAAA", target, "info", to_vector("test"), RequestCategory::standard, 0ms}; + + auto loop = std::make_shared(); + auto snode_pool = std::make_shared(pool_config, loop); + auto transport = std::make_shared(); + std::shared_ptr router; + + // It returns nothing when given no path options + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {}); + CHECK(TestOnionRequestRouter::find_valid_path(router, request) == nullptr); + + // It excludes paths which include the IP of the target + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {path1}); + CHECK(TestOnionRequestRouter::find_valid_path(router, request) == nullptr); + + // It returns a path when there is a valid one + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {path2}); + CHECK(TestOnionRequestRouter::find_valid_path(router, request) != nullptr); + + // In 'single_path_mode' it does allow the path to include the IP of the target (so that + // requests can still be made) + config = { + network::opt::retry_delay{50ms, 200ms}, + 50ms, + 3, + 3, + 10, + true, + true, // single path mode + {{RequestCategory::standard, 1}}}; + router = std::make_shared(config, loop, snode_pool, transport); + TestOnionRequestRouter::set_paths(router, RequestCategory::standard, {path1}); + CHECK(TestOnionRequestRouter::find_valid_path(router, request) != nullptr); +} + +TEST_CASE("Network", "[network][onion_request_router][check_request_queue_timeouts]") { + config::SnodePoolConfig pool_config = { + std::nullopt, + std::chrono::minutes{5}, + std::chrono::minutes{5}, + false, + network::opt::retry_delay{50ms, 200ms}, + opt::netid::Target::testnet, + {}, + 0, + 0, + 3, // cache_node_failure_threshold + false}; + config::OnionRequestRouterConfig config = { + network::opt::retry_delay{50ms, 200ms}, + 50ms, + 3, + 3, + 10, + true, + false, + {{RequestCategory::standard, 1}}}; + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + auto target = service_node{ed_pk, oxen::quic::ipv4{"127.0.0.1"}, 20001, 30001, {2, 11, 0}, 0}; + auto target2 = service_node{ed_pk, oxen::quic::ipv4{"127.0.0.1"}, 20002, 30002, {2, 11, 0}, 0}; + auto target3 = service_node{ed_pk, oxen::quic::ipv4{"127.0.0.1"}, 20003, 30003, {2, 11, 0}, 0}; + auto target4 = service_node{ed_pk, oxen::quic::ipv4{"127.0.0.1"}, 20004, 30004, {2, 11, 0}, 0}; + auto path = OnionPath{"Test1", {target2, target3, target4}}; + auto request = + Request{"AAAA", target, "info", to_vector("test"), RequestCategory::standard, 0ms}; + Result result; + + auto loop = std::make_shared(); + auto snode_pool = std::make_shared(pool_config, loop); + auto transport = std::make_shared(); + auto queue = std::make_shared(loop, 50ms); + std::shared_ptr router; + + // Test that it doesn't start checking for timeouts when the request doesn't have an overall + // timeout + request = + Request{"AAAA", + target, + "info", + to_vector("test"), + RequestCategory::standard, + 1000ms, + std::nullopt}; + router = std::make_shared(config, loop, snode_pool, transport); + queue = std::make_shared(loop, 50ms); + TestOnionRequestRouter::set_request_queues(router, {{RequestCategory::standard, queue}}); + router->send_request( + request, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK(queue->did_not_call("check_timeouts", 250ms)); + + // Test that it does start checking for timeouts when the request has an overall timeout + request = Request{ + "AAAA", target, "info", to_vector("test"), RequestCategory::standard, 1000ms, 1000ms}; + router = std::make_shared(config, loop, snode_pool, transport); + queue = std::make_shared(loop, 50ms); + TestOnionRequestRouter::set_request_queues(router, {{RequestCategory::standard, queue}}); + router->send_request( + request, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK(queue->called("add", 250ms)); + CHECK(queue->called("check_timeouts", 250ms)); + + // Test that it fails the request with a timeout if it has an overall timeout and the path build + // takes too long + std::promise prom; + request = Request{ + "AAAA", target, "info", to_vector("test"), RequestCategory::standard, 1000ms, 200ms}; + router = std::make_shared(config, loop, snode_pool, transport); + queue = std::make_shared(loop, 50ms); + TestOnionRequestRouter::set_request_queues(router, {{RequestCategory::standard, queue}}); + router->send_request( + request, + [&prom](bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + prom.set_value({success, timeout, status_code, headers, response}); + }); + + // Wait for the result to be set + result = prom.get_future().get(); + + CHECK_FALSE(result.success); + CHECK(result.timeout); +} + +} // namespace session::network diff --git a/tests/test_onionreq.cpp b/tests/test_onionreq.cpp index c73935c5..79f3bfd7 100644 --- a/tests/test_onionreq.cpp +++ b/tests/test_onionreq.cpp @@ -1,12 +1,12 @@ #include #include #include -#include #include "utils.hpp" using namespace session; using namespace session::onionreq; +using namespace session::network; TEST_CASE("Onion request encryption", "[encryption][onionreq]") { diff --git a/tests/test_session_network.cpp b/tests/test_session_network.cpp deleted file mode 100644 index 4a6b1b9d..00000000 --- a/tests/test_session_network.cpp +++ /dev/null @@ -1,1642 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "utils.hpp" - -using namespace session; -using namespace session::onionreq; -using namespace session::network; - -namespace { -struct TestServer { - std::shared_ptr loop; - std::shared_ptr endpoint; - service_node node; - - ~TestServer() { - loop->call_get([&]() { endpoint->close_conns(); }); - endpoint.reset(); - loop.reset(); - } -}; - -struct Result { - bool success; - bool timeout; - int16_t status_code; - std::vector> headers; - std::optional response; -}; - -service_node test_node( - const std::vector ed_pk, const uint16_t index, const bool unique_ip = true) { - return service_node{ - ed_pk, - {2, 8, 0}, - INVALID_SWARM_ID, - (unique_ip ? fmt::format("0.0.0.{}", index) : "1.1.1.1"), - static_cast(10000 + index)}; -} - -std::optional node_for_destination(network_destination destination) { - if (auto* dest = std::get_if(&destination)) - return *dest; - - return std::nullopt; -} - -std::shared_ptr create_test_server(uint16_t port) { - oxen::quic::opt::inbound_alpns server_alpns{"oxenstorage"}; - auto server_key_pair = session::ed25519::ed25519_key_pair(to_span(fmt::format("{:032}", port))); - auto server_x25519_pubkey = session::curve25519::to_curve25519_pubkey( - {server_key_pair.first.data(), server_key_pair.first.size()}); - auto server_x25519_seckey = session::curve25519::to_curve25519_seckey( - {server_key_pair.second.data(), server_key_pair.second.size()}); - auto creds = - oxen::quic::GNUTLSCreds::make_from_ed_seckey(to_string_view(server_key_pair.second)); - oxen::quic::Address server_local{port}; - session::onionreq::HopEncryption decryptor{ - x25519_seckey::from_bytes(to_span(server_x25519_seckey)), - x25519_pubkey::from_bytes(to_span(server_x25519_pubkey)), - true}; - - auto server_cb = [&](oxen::quic::message m) { - nlohmann::json response{{"hf", {1, 0, 0}}, {"t", 1234567890}, {"version", {2, 8, 0}}}; - m.respond(response.dump(), false); - }; - - auto onion_cb = [&](oxen::quic::message m) { - nlohmann::json response{{"hf", {2, 0, 0}}, {"t", 1234567890}, {"version", {2, 8, 0}}}; - m.respond(response.dump(), false); - }; - - oxen::quic::stream_constructor_callback server_constructor = - [&](oxen::quic::Connection& c, oxen::quic::Endpoint& e, std::optional) { - auto s = e.loop.make_shared(c, e); - s->register_handler("info", server_cb); - s->register_handler("onion_req", onion_cb); - return s; - }; - - auto loop = std::make_shared(); - auto endpoint = oxen::quic::Endpoint::endpoint(*loop, server_local, server_alpns); - endpoint->listen(creds, server_constructor); - - auto node = service_node{ - to_string_view(server_key_pair.first), - {2, 8, 0}, - INVALID_SWARM_ID, - "127.0.0.1"s, - endpoint->local().port()}; - - return std::make_shared(loop, endpoint, node); -} - -} // namespace - -namespace session::network { -class TestNetwork : public Network { - public: - std::unordered_map call_counts; - std::mutex call_counts_mutex; - std::condition_variable call_cv; - - std::vector calls_to_ignore; - std::chrono::milliseconds retry_delay_value = 0ms; - std::optional> find_valid_path_response; - std::optional last_request_info; - bool handle_onion_requests_as_plaintext = false; - - TestNetwork( - std::optional cache_path, - bool use_testnet, - bool single_path_mode, - bool pre_build_paths) : - Network{cache_path, use_testnet, single_path_mode, pre_build_paths} { - paths_changed = [this](std::vector>) { - func_called("paths_changed"); - }; - } - - void set_suspended(bool suspended_) { suspended = suspended_; } - - bool get_suspended() { return suspended; } - - ConnectionStatus get_status() { return status; } - - void set_snode_cache(std::vector cache) { - // Need to set the `last_snode_cache_update` to `10s` ago because otherwise it'll be - // considered invalid when checking the cache validity - snode_cache = cache; - last_snode_cache_update = (std::chrono::system_clock::now() - 10s); - } - - void set_unused_connections(std::deque unused_connections_) { - unused_connections = unused_connections_; - } - - void set_in_progress_connections( - std::unordered_map in_progress_connections_) { - in_progress_connections = in_progress_connections_; - } - - void add_path(PathType path_type, std::vector nodes) { - paths[path_type].emplace_back( - onion_path{"Test", {nodes[0], nullptr, nullptr, nullptr}, nodes, 0}); - } - - void set_paths(PathType path_type, std::vector paths_) { - paths[path_type] = paths_; - } - - std::vector get_paths(PathType path_type) { return paths[path_type]; } - - void set_all_swarms(std::vector>> all_swarms_) { - all_swarms = all_swarms_; - } - - void set_swarm( - session::onionreq::x25519_pubkey swarm_pubkey, - swarm_id_t swarm_id, - std::vector swarm) { - swarm_cache[swarm_pubkey.hex()] = {swarm_id, swarm}; - } - - std::pair> get_cached_swarm( - session::onionreq::x25519_pubkey swarm_pubkey) { - return swarm_cache[swarm_pubkey.hex()]; - } - - swarm_id_t get_swarm_id(std::string swarm_pubkey_hex) { - if (swarm_pubkey_hex.size() == 66) - swarm_pubkey_hex = swarm_pubkey_hex.substr(2); - - auto pk = x25519_pubkey::from_hex(swarm_pubkey_hex); - std::promise prom; - get_swarm(pk, [&prom](swarm_id_t result, std::vector) { - prom.set_value(result); - }); - return prom.get_future().get(); - } - - void set_failure_count(service_node node, uint8_t failure_count) { - snode_failure_counts[node.to_string()] = failure_count; - } - - uint8_t get_failure_count(service_node node) { - return snode_failure_counts.try_emplace(node.to_string(), 0).first->second; - } - - uint8_t get_failure_count(PathType path_type, onion_path path) { - auto current_paths = paths[path_type]; - auto target_path = std::find_if( - current_paths.begin(), current_paths.end(), [&path](const auto& path_it) { - return path_it.nodes[0] == path.nodes[0]; - }); - - if (target_path != current_paths.end()) - return target_path->failure_count; - - return 0; - } - - void set_path_build_queue(std::deque path_build_queue_) { - path_build_queue = path_build_queue_; - } - - std::deque get_path_build_queue() { return path_build_queue; } - - void set_path_build_failures(int path_build_failures_) { - path_build_failures = path_build_failures_; - } - - int get_path_build_failures() { return path_build_failures; } - - void set_unused_nodes(std::vector unused_nodes_) { unused_nodes = unused_nodes_; } - - std::vector get_unused_nodes() { return Network::get_unused_nodes(); } - - std::vector get_unused_nodes_value() { return unused_nodes; } - - void add_pending_request(PathType path_type, request_info info) { - request_queue[path_type].emplace_back( - std::move(info), - [](bool, - bool, - int16_t, - std::vector>, - std::optional) {}); - } - - std::pair>, onion_path> create_test_path() { - std::vector> path_servers; - std::vector path_nodes; - path_nodes.reserve(3); - - for (auto i = 0; i < 3; ++i) { - path_servers.emplace_back(create_test_server(static_cast(4390 + i))); - path_nodes.emplace_back(path_servers[i]->node); - } - - std::promise>> prom; - establish_connection( - "Test", - path_nodes[0], - 3s, - [&prom](connection_info conn_info, std::optional error) { - prom.set_value({std::move(conn_info), error}); - }); - - // Wait for the result to be set - auto result = prom.get_future().get(); - REQUIRE(result.first.is_valid()); - return {path_servers, onion_path{"Test", std::move(result.first), path_nodes, uint8_t{0}}}; - } - - // Overridden Functions - - std::chrono::milliseconds retry_delay(int, std::chrono::milliseconds) override { - return retry_delay_value; - } - - void update_disk_cache_throttled(bool force_immediate_write) override { - if (check_should_ignore_and_log_call("update_disk_cache_throttled")) - return; - - Network::update_disk_cache_throttled(force_immediate_write); - } - - void establish_and_store_connection(std::string request_id) override { - if (check_should_ignore_and_log_call("establish_and_store_connection")) - return; - - Network::establish_and_store_connection(request_id); - } - - void refresh_snode_cache(std::optional existing_request_id) override { - if (check_should_ignore_and_log_call("refresh_snode_cache")) - return; - - Network::refresh_snode_cache(existing_request_id); - } - - void build_path(std::string path_id, PathType path_type) override { - if (check_should_ignore_and_log_call("build_path")) - return; - - Network::build_path(path_id, path_type); - } - - std::optional find_valid_path( - request_info info, std::vector paths) override { - if (check_should_ignore_and_log_call("find_valid_path")) - return std::nullopt; - - if (find_valid_path_response) - return *find_valid_path_response; - - return Network::find_valid_path(info, paths); - } - - void check_request_queue_timeouts(std::optional request_timeout_id) override { - if (check_should_ignore_and_log_call("check_request_queue_timeouts")) - return; - - Network::check_request_queue_timeouts(request_timeout_id); - } - - void _send_onion_request( - request_info info, network_response_callback_t handle_response) override { - last_request_info = info; - - if (check_should_ignore_and_log_call("_send_onion_request")) - return; - - Network::_send_onion_request(std::move(info), std::move(handle_response)); - } - - // Exposing Private Functions - - void establish_connection( - std::string request_id, - service_node target, - std::optional timeout, - std::function error)> callback) { - Network::establish_connection(request_id, target, timeout, std::move(callback)); - } - - void build_path_if_needed(PathType path_type, bool found_valid_path) override { - return Network::build_path_if_needed(path_type, found_valid_path); - } - - void send_request( - request_info info, connection_info conn, network_response_callback_t handle_response) { - Network::send_request(info, conn, std::move(handle_response)); - } - - void handle_errors( - request_info info, - connection_info conn_info, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response, - std::optional handle_response) override { - func_called("handle_errors"); - Network::handle_errors( - info, - conn_info, - timeout, - status_code, - headers, - response, - std::move(handle_response)); - } - - std::tuple< - int16_t, - std::vector>, - std::optional> - process_v3_onion_response(session::onionreq::Builder builder, std::string response) override { - func_called("process_v3_onion_response"); - - if (handle_onion_requests_as_plaintext) - return {200, {}, response}; - - return Network::process_v3_onion_response(builder, response); - } - - std::tuple< - int16_t, - std::vector>, - std::optional> - process_v4_onion_response(session::onionreq::Builder builder, std::string response) override { - func_called("process_v4_onion_response"); - - if (handle_onion_requests_as_plaintext) - return {200, {}, response}; - - return Network::process_v4_onion_response(builder, response); - } - - // Mocking Functions - - template - void ignore_calls_to(Strings&&... __args) { - (calls_to_ignore.emplace_back(std::forward(__args)), ...); - } - - bool check_should_ignore_and_log_call(const std::string& name) { - func_called(name); - - return std::find(calls_to_ignore.begin(), calls_to_ignore.end(), name) != - calls_to_ignore.end(); - } - - void func_called(const std::string& name) { - bool notify = false; - { - std::lock_guard lock(call_counts_mutex); - ++call_counts[name]; - notify = true; - } - - if (notify) - call_cv.notify_all(); - } - - void reset_calls() { - std::lock_guard lock_counts(call_counts_mutex); - call_counts.clear(); - } - - int get_call_count(const std::string& name) { - std::lock_guard lock(call_counts_mutex); - auto it = call_counts.find(name); - return (it != call_counts.end()) ? it->second : 0; - } - - bool called(const std::string& name, int times = 1) { return (get_call_count(name) >= times); } - - [[nodiscard]] bool called( - const std::string& name, std::chrono::milliseconds timeout, int times = 1) { - if (times <= 0) - times = 1; - - std::unique_lock lock(call_counts_mutex); - - auto predicate = [&]() { - auto it = call_counts.find(name); - return (it != call_counts.end() && it->second >= times); - }; - - return call_cv.wait_for(lock, timeout, predicate); - } - - bool did_not_call(const std::string& name) { - std::lock_guard lock(call_counts_mutex); - return !call_counts.contains(name); - } - - [[nodiscard]] bool did_not_call(const std::string& name, std::chrono::milliseconds duration) { - std::unique_lock lock(call_counts_mutex); - auto predicate = [&]() { return call_counts.contains(name); }; - - if (predicate()) - return false; // Already called - - bool was_called_during_wait = call_cv.wait_for(lock, duration, predicate); - return !was_called_during_wait; - } -}; -} // namespace session::network - -TEST_CASE("Network", "[network][parse_url]") { - auto [proto1, host1, port1, path1] = parse_url("HTTPS://example.com/test"); - auto [proto2, host2, port2, path2] = parse_url("http://example2.com:1234/test/123456"); - auto [proto3, host3, port3, path3] = parse_url("https://example3.com"); - auto [proto4, host4, port4, path4] = parse_url("https://example4.com/test?value=test"); - - CHECK(proto1 == "https://"); - CHECK(proto2 == "http://"); - CHECK(proto3 == "https://"); - CHECK(proto4 == "https://"); - CHECK(host1 == "example.com"); - CHECK(host2 == "example2.com"); - CHECK(host3 == "example3.com"); - CHECK(host4 == "example4.com"); - CHECK(port1.value_or(9999) == 9999); - CHECK(port2.value_or(9999) == 1234); - CHECK(port3.value_or(9999) == 9999); - CHECK(port4.value_or(9999) == 9999); - CHECK(path1.value_or("NULL") == "/test"); - CHECK(path2.value_or("NULL") == "/test/123456"); - CHECK(path3.value_or("NULL") == "NULL"); - CHECK(path4.value_or("NULL") == "/test?value=test"); -} - -TEST_CASE("Network", "[network][handle_errors]") { - auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; - auto ed_pk2 = "5ea34e72bb044654a6a23675690ef5ffaaf1656b02f93fb76655f9cbdbe89876"_hexbytes; - auto ed_sk = - "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab78862834829a" - "87e0afadfed763fa8785e893dbde7f2c001ff1071aa55005c347f"_hexbytes; - auto x_pk_hex = "d2ad010eeb72d72e561d9de7bd7b6989af77dcabffa03a5111a6c859ae5c3a72"; - auto target = test_node(ed_pk, 0); - auto target2 = test_node(ed_pk2, 1); - auto target3 = test_node(ed_pk2, 2); - auto target4 = test_node(ed_pk2, 3); - auto path = - onion_path{"Test", {target, nullptr, nullptr, nullptr}, {target, target2, target3}, 0}; - auto mock_request = request_info{ - "AAAA", - target, - "test", - std::nullopt, - std::nullopt, - std::nullopt, - PathType::standard, - 0ms, - std::nullopt, - std::chrono::system_clock::now(), - std::nullopt, - true}; - Result result; - std::optional network; - - // Check the handling of the codes which make no changes - auto codes_with_no_changes = {400, 404, 406, 425}; - - for (auto code : codes_with_no_changes) { - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->set_paths(PathType::standard, {path}); - network->handle_errors( - mock_request, - {target, nullptr, nullptr, nullptr}, - false, - code, - {}, - std::nullopt, - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - - CHECK_FALSE(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == code); - CHECK_FALSE(result.response.has_value()); - CHECK(network->get_failure_count(target) == 0); - CHECK(network->get_failure_count(target2) == 0); - CHECK(network->get_failure_count(target3) == 0); - CHECK(network->get_failure_count(PathType::standard, path) == 0); - } - - // Check general error handling (first failure) - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->set_paths(PathType::standard, {path}); - network->handle_errors( - mock_request, - {target, nullptr, nullptr, nullptr}, - false, - 500, - {}, - std::nullopt, - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - CHECK_FALSE(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 500); - CHECK_FALSE(result.response.has_value()); - CHECK(network->get_failure_count(target) == 0); - CHECK(network->get_failure_count(target2) == 0); - CHECK(network->get_failure_count(target3) == 0); - CHECK(network->get_failure_count(PathType::standard, path) == 1); - - // Check general error handling with no response (too many path failures) - path = onion_path{"Test", {target, nullptr, nullptr, nullptr}, {target, target2, target3}, 9}; - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->set_paths(PathType::standard, {path}); - network->handle_errors( - mock_request, - {target, nullptr, nullptr, nullptr}, - false, - 500, - {}, - std::nullopt, - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - - CHECK_FALSE(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 500); - CHECK_FALSE(result.response.has_value()); - CHECK(network->get_failure_count(target) == 3); // Guard node dropped - CHECK(network->get_failure_count(target2) == 1); // Other nodes incremented - CHECK(network->get_failure_count(target3) == 1); // Other nodes incremented - CHECK(network->get_failure_count(PathType::standard, path) == 0); // Path dropped and reset - - // Check general error handling with a path and specific node failure - path = onion_path{"Test", {target, nullptr, nullptr, nullptr}, {target, target2, target3}, 0}; - auto response = std::string{"Next node not found: "} + ed25519_pubkey::from_bytes(ed_pk2).hex(); - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->set_snode_cache({target, target2, target3, target4}); - network->set_unused_nodes({target4}); - network->set_paths(PathType::standard, {path}); - network->handle_errors( - mock_request, - {target, nullptr, nullptr, nullptr}, - false, - 500, - {}, - response, - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - - CHECK_FALSE(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 500); - CHECK(result.response == response); - CHECK(network->get_failure_count(target) == 0); - CHECK(network->get_failure_count(target2) == 3); // Node will have been dropped - CHECK(network->get_failure_count(target3) == 0); - CHECK(network->get_paths(PathType::standard).front().nodes[1] != target2); - CHECK(network->get_failure_count(PathType::standard, path) == - 1); // Incremented because conn_info is invalid - - // Check a 421 with no swarm data throws (no good way to handle this case) - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->set_paths(PathType::standard, {path}); - network->handle_errors( - mock_request, - {target, nullptr, nullptr, nullptr}, - false, - 421, - {}, - std::nullopt, - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - CHECK_FALSE(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 421); - CHECK(network->get_failure_count(target) == 0); - CHECK(network->get_failure_count(target2) == 0); - CHECK(network->get_failure_count(target3) == 0); - CHECK(network->get_failure_count(PathType::standard, path) == 1); - - // Check a non redirect 421 triggers a retry using a different node - auto mock_request2 = request_info{ - "BBBB", - target, - "test", - std::nullopt, - std::nullopt, - x25519_pubkey::from_hex(x_pk_hex), - PathType::standard, - 0ms, - std::nullopt, - std::chrono::system_clock::now(), - std::nullopt, - true}; - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->set_swarm(x25519_pubkey::from_hex(x_pk_hex), 1, {target, target2, target3}); - network->set_paths(PathType::standard, {path}); - network->reset_calls(); - network->handle_errors( - mock_request2, - {target, nullptr, nullptr, nullptr}, - false, - 421, - {}, - std::nullopt, - [](bool, - bool, - int16_t, - std::vector>, - std::optional) {}); - CHECK(network->called("_send_onion_request", 100ms)); - REQUIRE(network->last_request_info.has_value()); - CHECK(node_for_destination(network->last_request_info->destination) != - node_for_destination(mock_request2.destination)); - - // Check that when a retry request of a 421 receives it's own 421 that it tries - // to update the snode cache - auto mock_request3 = request_info{ - "BBBB", - target, - "test", - std::nullopt, - std::nullopt, - x25519_pubkey::from_hex(x_pk_hex), - PathType::standard, - 0ms, - std::nullopt, - std::chrono::system_clock::now(), - request_info::RetryReason::redirect, - true}; - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to( - "_send_onion_request", "update_disk_cache_throttled", "refresh_snode_cache"); - network->set_paths(PathType::standard, {path}); - network->handle_errors( - mock_request3, - {target, nullptr, nullptr, nullptr}, - false, - 421, - {}, - std::nullopt, - [](bool, - bool, - int16_t, - std::vector>, - std::optional) {}); - CHECK(network->called("refresh_snode_cache", 100ms)); - - // Check when the retry after refreshing the snode cache due to a 421 receives it's own 421 it - // is handled like any other error - auto mock_request4 = request_info{ - "BBBB", - target, - "test", - std::nullopt, - std::nullopt, - x25519_pubkey::from_hex(x_pk_hex), - PathType::standard, - 0ms, - std::nullopt, - std::chrono::system_clock::now(), - request_info::RetryReason::redirect_swarm_refresh, - true}; - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->set_paths(PathType::standard, {path}); - network->handle_errors( - mock_request4, - {target, nullptr, nullptr, nullptr}, - false, - 421, - {}, - std::nullopt, - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - CHECK_FALSE(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 421); - CHECK(network->get_failure_count(target) == 0); - CHECK(network->get_failure_count(target2) == 0); - CHECK(network->get_failure_count(target3) == 0); - CHECK(network->get_failure_count(PathType::standard, path) == 1); - - // Check a timeout with a sever destination doesn't impact the failure counts - auto server = ServerDestination{ - "https", - "open.getsession.org", - "/rooms", - x25519_pubkey::from_hex("a03c383cf63c3c4efe67acc52112a6dd734b3a946b9545f488aaa93da79912" - "38"), - 443, - std::nullopt, - "GET"}; - auto mock_request5 = request_info{ - "CCCC", - server, - "test", - std::nullopt, - std::nullopt, - x25519_pubkey::from_hex(x_pk_hex), - PathType::standard, - 0ms, - std::nullopt, - std::chrono::system_clock::now(), - std::nullopt, - false}; - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->handle_errors( - mock_request5, - {target, nullptr, nullptr, nullptr}, - true, - -1, - {}, - "Test", - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - CHECK_FALSE(result.success); - CHECK(result.timeout); - CHECK(result.status_code == -1); - CHECK(network->get_failure_count(target) == 0); - CHECK(network->get_failure_count(target2) == 0); - CHECK(network->get_failure_count(target3) == 0); - CHECK(network->get_failure_count(PathType::standard, path) == 0); - - // Check a server response starting with '500 Internal Server Error' is reported as a `500` - // error and doesn't affect the failure count - network.emplace(std::nullopt, true, true, false); - network->set_suspended(true); // Make no requests in this test - network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); - network->handle_errors( - mock_request4, - {target, nullptr, nullptr, nullptr}, - false, - -1, - {}, - "500 Internal Server Error", - [&result]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result = {success, timeout, status_code, headers, response}; - }); - CHECK_FALSE(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 500); - CHECK(network->get_failure_count(target) == 0); - CHECK(network->get_failure_count(target2) == 0); - CHECK(network->get_failure_count(target3) == 0); - CHECK(network->get_failure_count(PathType::standard, path) == 0); -} - -TEST_CASE("Network", "[network][get_unused_nodes]") { - const auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; - std::optional network; - std::vector snode_cache; - std::vector unused_nodes; - for (uint16_t i = 0; i < 12; ++i) - snode_cache.emplace_back(test_node(ed_pk, i)); - auto invalid_info = connection_info{snode_cache[0], nullptr, nullptr, nullptr}; - auto path = - onion_path{"Test", invalid_info, {snode_cache[0], snode_cache[1], snode_cache[2]}, 0}; - - auto compare_service_nodes = [](const service_node& a, const service_node& b) { - if (auto cmp = oxen::quic::Address(a) <=> oxen::quic::Address(b); cmp != 0) - return cmp < 0; - - return std::tie(a.get_remote_key(), a.swarm_id, a.storage_server_version) < - std::tie(b.get_remote_key(), b.swarm_id, b.storage_server_version); - }; - - // Should shuffle the result - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - CHECK(network->get_unused_nodes() != network->get_unused_nodes()); - - // Should contain the entire snode cache initially - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - unused_nodes = network->get_unused_nodes(); - std::stable_sort(unused_nodes.begin(), unused_nodes.end(), compare_service_nodes); - CHECK(unused_nodes == snode_cache); - - // Should exclude nodes used in paths - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->set_paths(PathType::standard, {path}); - unused_nodes = network->get_unused_nodes(); - std::stable_sort(unused_nodes.begin(), unused_nodes.end(), compare_service_nodes); - CHECK(unused_nodes == std::vector{snode_cache.begin() + 3, snode_cache.end()}); - - // Should exclude nodes in unused connections - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->set_unused_connections({invalid_info}); - unused_nodes = network->get_unused_nodes(); - std::stable_sort(unused_nodes.begin(), unused_nodes.end(), compare_service_nodes); - CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); - - // Should exclude nodes in in-progress connections - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->set_in_progress_connections({{"Test", snode_cache.front()}}); - unused_nodes = network->get_unused_nodes(); - std::stable_sort(unused_nodes.begin(), unused_nodes.end(), compare_service_nodes); - CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); - - // Should exclude nodes destinations in pending requests - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->add_pending_request( - PathType::standard, - request_info::make( - snode_cache.front(), - std::nullopt, - std::nullopt, - 1s, - std::nullopt, - PathType::standard)); - unused_nodes = network->get_unused_nodes(); - std::stable_sort(unused_nodes.begin(), unused_nodes.end(), compare_service_nodes); - CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); - - // Should exclude nodes which have passed the failure threshold - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->set_failure_count(snode_cache.front(), 10); - unused_nodes = network->get_unused_nodes(); - std::stable_sort(unused_nodes.begin(), unused_nodes.end(), compare_service_nodes); - CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); - - // Should exclude nodes which have the same IP if one was excluded - std::vector same_ip_snode_cache; - auto unique_node = service_node{ed_pk, {2, 8, 0}, INVALID_SWARM_ID, "0.0.0.20", uint16_t{20}}; - for (uint16_t i = 0; i < 11; ++i) - same_ip_snode_cache.emplace_back(test_node(ed_pk, i, false)); - same_ip_snode_cache.emplace_back(unique_node); - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(same_ip_snode_cache); - network->set_failure_count(same_ip_snode_cache.front(), 10); - unused_nodes = network->get_unused_nodes(); - REQUIRE(unused_nodes.size() == 1); - CHECK(unused_nodes.front() == unique_node); -} - -TEST_CASE("Network", "[network][build_path]") { - const auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; - std::optional network; - std::vector snode_cache; - for (uint16_t i = 0; i < 12; ++i) - snode_cache.emplace_back(test_node(ed_pk, i)); - auto invalid_info = connection_info{snode_cache[0], nullptr, nullptr, nullptr}; - - // Nothing should happen if the network is suspended - network.emplace(std::nullopt, true, false, false); - network->set_suspended(true); - network->build_path("Test1", PathType::standard); - CHECK(network->did_not_call("establish_and_store_connection", 25ms)); - - // If there are no unused connections it puts the path build in the queue and calls - // establish_and_store_connection - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->build_path("Test1", PathType::standard); - CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); - CHECK(network->called("establish_and_store_connection", 100ms)); - - // If the unused nodes are empty it refreshes them - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->set_unused_connections({invalid_info}); - network->set_in_progress_connections({{"TestInProgress", snode_cache.front()}}); - network->build_path("Test1", PathType::standard); - CHECK(network->get_unused_nodes_value().size() == snode_cache.size() - 3); - CHECK(network->get_path_build_queue().empty()); - - // It should exclude nodes that are already in existing paths - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->set_unused_connections({invalid_info}); - network->set_in_progress_connections({{"TestInProgress", snode_cache.front()}}); - network->add_path(PathType::standard, {snode_cache.begin() + 1, snode_cache.begin() + 1 + 3}); - network->build_path("Test1", PathType::standard); - CHECK(network->get_unused_nodes_value().size() == (snode_cache.size() - 3 - 3)); - CHECK(network->get_path_build_queue().empty()); - - // If there aren't enough unused nodes it resets the failure count, re-queues the path build and - // triggers a snode cache refresh - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("refresh_snode_cache"); - network->set_snode_cache(snode_cache); - network->set_unused_connections({invalid_info}); - network->set_path_build_failures(10); - network->add_path(PathType::standard, snode_cache); - network->build_path("Test1", PathType::standard); - CHECK(network->get_path_build_failures() == 0); - CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); - CHECK(network->called("refresh_snode_cache", 100ms)); - - // If it can't build a path after excluding nodes with the same IP it increments the - // failure count and re-tries the path build after a small delay - network.emplace(std::nullopt, true, false, false); - network->set_snode_cache(snode_cache); - network->set_unused_connections({invalid_info}); - network->set_unused_nodes(std::vector{ - snode_cache[0], snode_cache[0], snode_cache[0], snode_cache[0]}); - network->build_path("Test1", PathType::standard); - network->ignore_calls_to("build_path"); // Ignore the 2nd loop - CHECK(network->get_path_build_failures() == 1); - CHECK(network->get_path_build_queue().empty()); - CHECK(network->called("build_path", 100ms, 2)); - - // It stores a successful non-standard path and kicks of queued requests but doesn't update the - // status or call the 'paths_changed' hook - network.emplace(std::nullopt, true, false, false); - network->find_valid_path_response = - onion_path{"Test", invalid_info, {snode_cache.begin(), snode_cache.begin() + 3}, 0}; - network->ignore_calls_to("_send_onion_request"); - network->set_snode_cache(snode_cache); - network->set_unused_connections({invalid_info}); - network->add_pending_request( - PathType::download, - request_info::make( - snode_cache.back(), - std::nullopt, - std::nullopt, - 1s, - std::nullopt, - PathType::download)); - network->build_path("Test1", PathType::download); - CHECK(network->called("_send_onion_request", 100ms)); - CHECK(network->get_paths(PathType::download).size() == 1); - - // It stores a successful 'standard' path, updates the status, calls the 'paths_changed' hook - // and kicks of queued requests - network.emplace(std::nullopt, true, false, false); - network->find_valid_path_response = - onion_path{"Test", invalid_info, {snode_cache.begin(), snode_cache.begin() + 3}, 0}; - network->ignore_calls_to("_send_onion_request"); - network->set_snode_cache(snode_cache); - network->set_unused_connections({invalid_info}); - network->add_pending_request( - PathType::standard, - request_info::make( - snode_cache.back(), - std::nullopt, - std::nullopt, - 1s, - std::nullopt, - PathType::standard)); - network->build_path("Test1", PathType::standard); - CHECK(network->called("_send_onion_request", 100ms)); - CHECK(network->get_paths(PathType::standard).size() == 1); - CHECK(network->get_status() == ConnectionStatus::connected); - CHECK(network->called("paths_changed")); -} - -TEST_CASE("Network", "[network][find_valid_path]") { - auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; - auto test_server = create_test_server(5000); - auto target = test_node(ed_pk, 1); - auto info = request_info::make(target, std::nullopt, std::nullopt, 0ms); - - auto network = TestNetwork(std::nullopt, true, false, false); - auto invalid_path = onion_path{ - "Test", - {test_server->node, nullptr, nullptr, nullptr}, - {test_server->node}, - uint8_t{0}}; - - // It returns nothing when given no path options - CHECK_FALSE(network.find_valid_path(info, {}).has_value()); - - // It ignores invalid paths - CHECK_FALSE(network.find_valid_path(info, {invalid_path}).has_value()); - - // Need to get a valid path for subsequent tests - std::promise>> prom; - - network.establish_connection( - "Test", - test_server->node, - 3s, - [&prom](connection_info conn_info, std::optional error) { - prom.set_value({std::move(conn_info), error}); - }); - - // Wait for the result to be set - auto result = prom.get_future().get(); - REQUIRE(result.first.is_valid()); - auto valid_path = onion_path{ - "Test", - std::move(result.first), - std::vector{test_server->node}, - uint8_t{0}}; - - // It excludes paths which include the IP of the target - auto shared_ip_info = request_info::make(test_server->node, std::nullopt, std::nullopt, 0ms); - CHECK_FALSE(network.find_valid_path(shared_ip_info, {valid_path}).has_value()); - - // It returns a path when there is a valid one - CHECK(network.find_valid_path(info, {valid_path}).has_value()); - - // In 'single_path_mode' it does allow the path to include the IP of the target (so that - // requests can still be made) - auto network_single_path = TestNetwork(std::nullopt, true, true, false); - CHECK(network_single_path.find_valid_path(shared_ip_info, {valid_path}).has_value()); -} - -TEST_CASE("Network", "[network][build_path_if_needed]") { - auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; - auto target = test_node(ed_pk, 0); - - std::optional network; - auto invalid_path = onion_path{ - "Test", connection_info{target, nullptr, nullptr, nullptr}, {target}, uint8_t{0}}; - - // It does not add additional path builds if there is already a path and it's in - // 'single_path_mode' - network.emplace(std::nullopt, true, true, false); - network->ignore_calls_to("establish_and_store_connection"); - network->set_paths(PathType::standard, {invalid_path}); - network->build_path_if_needed(PathType::standard, false); - CHECK(network->did_not_call("establish_and_store_connection", 25ms)); - CHECK(network->get_path_build_queue().empty()); - - // Adds a path build to the queue - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->set_paths(PathType::standard, {}); - network->build_path_if_needed(PathType::standard, false); - CHECK(network->called("establish_and_store_connection", 100ms)); - CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); - - // Can only add the correct number of 'standard' path builds to the queue - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->build_path_if_needed(PathType::standard, false); - network->build_path_if_needed(PathType::standard, false); - CHECK(network->called("establish_and_store_connection", 100ms, 2)); - network->reset_calls(); // This triggers 'call_soon' so we need to wait until they are enqueued - network->build_path_if_needed(PathType::standard, false); - CHECK(network->did_not_call("establish_and_store_connection", 25ms)); - CHECK(network->get_path_build_queue() == - std::deque{PathType::standard, PathType::standard}); - - // Can add additional 'standard' path builds if below the minimum threshold - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->set_paths(PathType::standard, {invalid_path}); - network->build_path_if_needed(PathType::standard, false); - CHECK(network->called("establish_and_store_connection", 100ms)); - CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); - - // Can add more path builds if there are enough active paths of the same type, no pending paths - // and no `found_path` was provided - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->set_paths(PathType::standard, {invalid_path, invalid_path}); - network->build_path_if_needed(PathType::standard, false); - CHECK(network->called("establish_and_store_connection", 100ms)); - CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); - - // Cannot add more path builds if there are already enough active paths of the same type and a - // `found_path` was provided - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->set_paths(PathType::standard, {invalid_path, invalid_path}); - network->build_path_if_needed(PathType::standard, true); - CHECK(network->did_not_call("establish_and_store_connection", 25ms)); - CHECK(network->get_path_build_queue().empty()); - - // Cannot add more path builds if there is already a build of the same type in the queue and the - // number of active and pending builds of the same type meet the limit - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->set_paths(PathType::standard, {invalid_path}); - network->set_path_build_queue({PathType::standard}); - network->build_path_if_needed(PathType::standard, false); - CHECK(network->did_not_call("establish_and_store_connection", 25ms)); - CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); - - // Can only add the correct number of 'download' path builds to the queue - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->build_path_if_needed(PathType::download, false); - network->build_path_if_needed(PathType::download, false); - CHECK(network->called("establish_and_store_connection", 100ms, 2)); - network->reset_calls(); // This triggers 'call_soon' so we need to wait until they are enqueued - network->build_path_if_needed(PathType::download, false); - CHECK(network->did_not_call("establish_and_store_connection", 25ms)); - CHECK(network->get_path_build_queue() == - std::deque{PathType::download, PathType::download}); - - // Can only add the correct number of 'upload' path builds to the queue - network.emplace(std::nullopt, true, false, false); - network->ignore_calls_to("establish_and_store_connection"); - network->build_path_if_needed(PathType::upload, false); - network->build_path_if_needed(PathType::upload, false); - CHECK(network->called("establish_and_store_connection", 100ms, 2)); - network->reset_calls(); // This triggers 'call_soon' so we need to wait until they are enqueued - network->build_path_if_needed(PathType::upload, false); - CHECK(network->did_not_call("establish_and_store_connection", 25ms)); - CHECK(network->get_path_build_queue() == - std::deque{PathType::upload, PathType::upload}); -} - -TEST_CASE("Network", "[network][establish_connection]") { - auto test_server = create_test_server(5100); - auto network = TestNetwork(std::nullopt, true, true, false); - std::promise>> prom; - - network.establish_connection( - "Test", - test_server->node, - 3s, - [&prom](connection_info info, std::optional error) { - prom.set_value({info, error}); - }); - - // Wait for the result to be set - auto result = prom.get_future().get(); - - CHECK(result.first.is_valid()); - CHECK_FALSE(result.second.has_value()); -} - -TEST_CASE("Network", "[network][check_request_queue_timeouts]") { - std::optional network; - std::optional> test_server; - std::promise prom; - - // Test that it doesn't start checking for timeouts when the request doesn't have - // a build paths timeout - network.emplace(std::nullopt, true, true, false); - test_server.emplace(create_test_server(5201)); - network->send_onion_request( - (*test_server)->node, - to_vector("{\"method\":\"info\",\"params\":{}}"), - std::nullopt, - [](bool, - bool, - int16_t, - std::vector>, - std::optional) {}, - oxen::quic::DEFAULT_TIMEOUT, - std::nullopt); - CHECK(network->did_not_call("check_request_queue_timeouts", 300ms)); - - // Test that it does start checking for timeouts when the request has a - // paths build timeout - network.emplace(std::nullopt, true, true, false); - test_server.emplace(create_test_server(5202)); - network->ignore_calls_to("build_path"); - network->send_onion_request( - (*test_server)->node, - to_vector("{\"method\":\"info\",\"params\":{}}"), - std::nullopt, - [](bool, - bool, - int16_t, - std::vector>, - std::optional) {}, - oxen::quic::DEFAULT_TIMEOUT, - oxen::quic::DEFAULT_TIMEOUT); - CHECK(network->called("check_request_queue_timeouts", 300ms)); - - // Test that it fails the request with a timeout if it has a build path timeout - // and the path build takes too long - network.emplace(std::nullopt, true, true, false); - test_server.emplace(create_test_server(5203)); - network->ignore_calls_to("build_path"); - network->send_onion_request( - (*test_server)->node, - to_vector("{\"method\":\"info\",\"params\":{}}"), - std::nullopt, - [&prom](bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - prom.set_value({success, timeout, status_code, headers, response}); - }, - oxen::quic::DEFAULT_TIMEOUT, - 100ms); - - // Wait for the result to be set - auto result = prom.get_future().get(); - - CHECK_FALSE(result.success); - CHECK(result.timeout); -} - -TEST_CASE("Network", "[network][send_request]") { - auto test_server = create_test_server(5300); - auto network = TestNetwork(std::nullopt, true, true, false); - std::promise prom; - - network.establish_connection( - "Test", - test_server->node, - 3s, - [&prom, &network, &test_server]( - connection_info info, std::optional error) { - if (!info.is_valid()) - return prom.set_value({false, false, -1, {}, error.value_or("Unknown Error")}); - - network.send_request( - request_info::make( - test_server->node, - to_vector("{}"), - std::nullopt, - 3s, - std::nullopt, - PathType::standard, - std::nullopt, - "info"), - std::move(info), - [&prom](bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - prom.set_value({success, timeout, status_code, headers, response}); - }); - }); - - // Wait for the result to be set - auto result = prom.get_future().get(); - - CHECK(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 200); - REQUIRE(result.response.has_value()); - INFO("*result.response is: " << *result.response); - REQUIRE_NOTHROW([&] { [[maybe_unused]] auto _ = nlohmann::json::parse(*result.response); }); - - auto response = nlohmann::json::parse(*result.response); - REQUIRE(response.contains("hf")); - auto hf = response["hf"].get>(); - CHECK(hf.size() == 3); - CHECK(hf[0] == 1); // Called the info callback - CHECK(response.contains("t")); - CHECK(response.contains("version")); -} - -TEST_CASE("Network", "[network][send_onion_request]") { - auto test_server = create_test_server(5400); - auto network = TestNetwork(std::nullopt, true, true, false); - auto [test_path_servers, test_path] = network.create_test_path(); - network.handle_onion_requests_as_plaintext = true; - network.set_paths(PathType::standard, {test_path}); - std::promise result_promise; - - network.send_onion_request( - test_server->node, - to_vector("{\"method\":\"info\",\"params\":{}}"), - std::nullopt, - [&result_promise]( - bool success, - bool timeout, - int16_t status_code, - std::vector> headers, - std::optional response) { - result_promise.set_value({success, timeout, status_code, headers, response}); - }, - oxen::quic::DEFAULT_TIMEOUT, - oxen::quic::DEFAULT_TIMEOUT); - - // Wait for the result to be set - auto result = result_promise.get_future().get(); - - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 200); - REQUIRE(result.success); - REQUIRE(result.response.has_value()); - INFO("*result.response is: " << *result.response); - REQUIRE_NOTHROW([&] { [[maybe_unused]] auto _ = nlohmann::json::parse(*result.response); }); - - auto response = nlohmann::json::parse(*result.response); - REQUIRE(response.contains("hf")); - auto hf = response["hf"].get>(); - CHECK(hf.size() == 3); - CHECK(hf[0] == 2); // Called the onion_req callback - CHECK(response.contains("t")); - CHECK(response.contains("version")); -} - -TEST_CASE("Network", "[network][c][network_send_onion_request]") { - auto test_server_cpp = create_test_server(5500); - auto test_network = std::make_unique(std::nullopt, true, true, false); - std::optional>, onion_path>> test_path_data; - test_path_data.emplace(test_network->create_test_path()); - test_network->handle_onion_requests_as_plaintext = true; - test_network->set_paths(PathType::standard, {test_path_data->second}); - - // Convert TestNetwork to network_object to pass to C API - auto n_object = std::make_unique(); - n_object->internals = test_network.release(); - network_object* network = n_object.release(); - - // Convert test_server_cpp->node to network_service_node to pass to C API - auto ip_v4 = test_server_cpp->node.to_ipv4(); - std::array target_ip = { - static_cast(ip_v4.addr >> 24), - static_cast((ip_v4.addr >> 16) & 0xFF), - static_cast((ip_v4.addr >> 8) & 0xFF), - static_cast(ip_v4.addr & 0xFF)}; - auto test_service_node = network_service_node{}; - test_service_node.quic_port = test_server_cpp->node.port(); - std::copy(target_ip.begin(), target_ip.end(), test_service_node.ip); - auto test_pubkey_hex = oxenc::to_hex(test_server_cpp->node.view_remote_key()); - std::strcpy(test_service_node.ed25519_pubkey_hex, test_pubkey_hex.c_str()); - - // Make the request - auto body = to_vector("{\"method\":\"info\",\"params\":{}}"); - auto result_promise = std::make_shared>(); - - network_send_onion_request_to_snode_destination( - network, - test_service_node, - body.data(), - body.size(), - nullptr, - std::chrono::milliseconds{oxen::quic::DEFAULT_TIMEOUT}.count(), - std::chrono::milliseconds{oxen::quic::DEFAULT_TIMEOUT}.count(), - [](bool success, - bool timeout, - int16_t status_code, - const char* const* headers, - const char* const* header_values, - size_t headers_size, - const char* c_response, - size_t response_size, - void* ctx) { - auto result_promise = static_cast*>(ctx); - auto response_str = std::string(c_response, response_size); - std::vector> header_pairs; - header_pairs.reserve(headers_size); - - for (size_t i = 0; i < headers_size; ++i) { - if (headers[i] == nullptr) - continue; // Skip null entries - if (header_values[i] == nullptr) - continue; // Skip null entries - - header_pairs.emplace_back(headers[i], header_values[i]); - } - - result_promise->set_value( - {success, timeout, status_code, header_pairs, response_str}); - }, - static_cast(result_promise.get())); - - // Wait for the result to be set - auto result = result_promise->get_future().get(); - - CHECK(result.success); - CHECK_FALSE(result.timeout); - CHECK(result.status_code == 200); - REQUIRE(result.response.has_value()); - INFO("*result.response is: " << *result.response); - REQUIRE_NOTHROW([&] { [[maybe_unused]] auto _ = nlohmann::json::parse(*result.response); }); - - auto response = nlohmann::json::parse(*result.response); - REQUIRE(response.contains("hf")); - auto hf = response["hf"].get>(); - CHECK(hf.size() == 3); - CHECK(hf[0] == 2); // Called the onion_req callback - CHECK(response.contains("t")); - CHECK(response.contains("version")); - test_path_data.reset(); - network_free(network); -} - -TEST_CASE("Network", "[network][detail][pubkey_to_swarm_space]") { - x25519_pubkey pk; - - pk = x25519_pubkey::from_hex( - "3506f4a71324b7dd114eddbf4e311f39dde243e1f2cb97c40db1961f70ebaae8"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 17589930838143112648ULL); - pk = x25519_pubkey::from_hex( - "cf27da303a50ac8c4b2d43d27259505c9bcd73fc21cf2a57902c3d050730b604"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 10370619079776428163ULL); - pk = x25519_pubkey::from_hex( - "d3511706b8b34f6e8411bf07bd22ba6b2435ca56846fbccf6eb1e166a6cd15cc"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 2144983569669512198ULL); - pk = x25519_pubkey::from_hex( - "0f06693428fca9102a451e3f28d9cc743d8ea60a89ab6aa69eb119470c11cbd3"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 9690840703409570833ULL); - pk = x25519_pubkey::from_hex( - "ffba630924aa1224bb930dde21c0d11bf004608f2812217f8ac812d6c7e3ad48"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 4532060000165252872ULL); - pk = x25519_pubkey::from_hex( - "eeeeeeeeeeeeeeee777777777777777711111111111111118888888888888888"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 0); - pk = x25519_pubkey::from_hex( - "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 0); - pk = x25519_pubkey::from_hex( - "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 1); - pk = x25519_pubkey::from_hex( - "ffffffffffffffffffffffffffffffffffffffffffffffff7fffffffffffffff"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 1ULL << 63); - pk = x25519_pubkey::from_hex( - "000000000000000000000000000000000000000000000000ffffffffffffffff"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == (uint64_t)-1); - pk = x25519_pubkey::from_hex( - "0000000000000000000000000000000000000000000000000123456789abcdef"); - CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 0x0123456789abcdefULL); -} - -TEST_CASE("Network", "[network][get_swarm]") { - auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; - std::vector>> swarms = { - {100, {}}, {200, {}}, {300, {}}, {399, {}}, {498, {}}, {596, {}}, {694, {}}}; - auto network = TestNetwork(std::nullopt, true, true, false); - network.set_snode_cache({test_node(ed_pk, 0)}); - network.set_all_swarms(swarms); - - // Exact matches: - // 0x64 = 100, 0xc8 = 200, 0x1f2 = 498 - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000006" - "4") == 100); - CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000c" - "8") == 200); - CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000001f" - "2") == 498); - - // Nearest - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000000" - "0") == 100); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000000" - "1") == 100); - - // Nearest, with wraparound - // 0x8000... is closest to the top value - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000000" - "0") == 694); - - // 0xa000... is closest (via wraparound) to the smallest - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000a00000000000000" - "0") == 100); - - // This is the invalid swarm id for swarms, but should still work for a client - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" - "f") == 100); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" - "e") == 100); - - // Midpoint tests; we prefer the lower value when exactly in the middle between two swarms. - // 0x96 = 150 - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000009" - "5") == 100); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000009" - "6") == 100); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000009" - "7") == 200); - - // 0xfa = 250 - CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000f" - "9") == 200); - CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000f" - "a") == 200); - CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000f" - "b") == 300); - - // 0x15d = 349 - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000015" - "d") == 300); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000015" - "e") == 399); - - // 0x1c0 = 448 - CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000001c" - "0") == 399); - CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000001c" - "1") == 498); - - // 0x223 = 547 - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000022" - "2") == 498); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000022" - "3") == 498); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000022" - "4") == 596); - - // 0x285 = 645 - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000028" - "5") == 596); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000028" - "6") == 694); - - // 0x800....d is the midpoint between 694 and 100 (the long way). We always round "down" (which - // in this case, means wrapping to the largest swarm). - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000018" - "c") == 694); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000018" - "d") == 694); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000018" - "e") == 100); - - // With a swarm at -20 the midpoint is now 40 (=0x28). When our value is the *low* value we - // prefer the *last* swarm in the case of a tie (while consistent with the general case of - // preferring the left edge, it means we're inconsistent with the other wraparound case, above. - // *sigh*). - swarms.push_back({(uint64_t)-20, {}}); - network.set_all_swarms(swarms); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000002" - "7") == swarms.back().first); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000002" - "8") == swarms.back().first); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000002" - "9") == swarms.front().first); - - // The code used to have a broken edge case if we have a swarm at zero and a client at max-u64 - // because of an overflow in how the distance is calculated (the first swarm will be calculated - // as max-u64 away, rather than 1 away), and so the id always maps to the highest swarm (even - // though 0xfff...fe maps to the lowest swarm; the first check here, then, would fail. - swarms.insert(swarms.begin(), {0, {}}); - network.set_all_swarms(swarms); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" - "f") == 0); - CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" - "e") == 0); -} diff --git a/tests/test_snode_pool.cpp b/tests/test_snode_pool.cpp new file mode 100644 index 00000000..e74b457b --- /dev/null +++ b/tests/test_snode_pool.cpp @@ -0,0 +1,154 @@ +#include +#include + +#include "utils.hpp" + +using namespace session; +using namespace session::network; + +namespace session::network { + +class TestSnodePool : public SnodePool { + public: + std::optional> mock_unused_nodes; + + TestSnodePool( + config::SnodePoolConfig config, + std::shared_ptr loop, + network_fetcher_t direct_fetcher = [](Request, network_response_callback_t) {}) : + SnodePool(std::move(config), std::move(loop), std::move(direct_fetcher)) {} + + void reset_state_with_cache(std::vector cache) { + std::unique_lock lock{_cache_mutex}; + _snode_cache = cache; + _snode_failure_counts.clear(); + } + + void refresh_if_needed( + const std::vector& in_use_nodes, + std::function on_refresh_complete = nullptr) override { + // Do nothing (don't want to trigger a cache refresh) + } +}; +} // namespace session::network + +TEST_CASE("Network", "[network][get_unused_nodes]") { + session::network::config::SnodePoolConfig pool_config = { + std::nullopt, + std::chrono::minutes{5}, + std::chrono::minutes{5}, + false, // enforce_subnet_diversity + network::opt::retry_delay{50ms, 200ms}, + opt::netid::Target::testnet, + {}, + 0, + 0, + 3, // cache_node_failure_threshold + false}; + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + auto ed_pk2 = "5ea34e72bb044654a6a23675690ef5ffaaf1656b02f93fb76655f9cbdbe89876"_hexbytes; + auto ed_pk3 = "e17a692033200ae41350df9709754edde7343e2cf2f23e88f993319e0720e5e5"_hexbytes; + auto ed_pk4 = "7b633fa6fb462b90db6f0f50384190ce7715e31b7aa93d87dbd7e94e33d4251f"_hexbytes; + std::vector snode_cache; + std::vector unused_nodes; + + for (uint16_t i = 0; i < 5; ++i) { + snode_cache.emplace_back(service_node{ + ed_pk, + oxen::quic::ipv4{"192.168.0.{}"_format(i)}, + static_cast(20000 + i), + static_cast(30000 + i), + {2, 11, 0}, + 0}); + snode_cache.emplace_back(service_node{ + ed_pk2, + oxen::quic::ipv4{"192.168.1.{}"_format(i)}, + static_cast(20100 + i), + static_cast(30100 + i), + {2, 11, 0}, + 1}); + snode_cache.emplace_back(service_node{ + ed_pk3, + oxen::quic::ipv4{"192.168.2.{}"_format(i)}, + static_cast(20200 + i), + static_cast(30200 + i), + {2, 11, 0}, + 2}); + snode_cache.emplace_back(service_node{ + ed_pk4, + oxen::quic::ipv4{"192.168.3.{}"_format(i)}, + static_cast(20300 + i), + static_cast(30300 + i), + {2, 11, 0}, + 3}); + } + std::sort(snode_cache.begin(), snode_cache.end()); + + auto loop = std::make_shared(); + auto snode_pool = std::make_shared(pool_config, loop); + snode_pool->reset_state_with_cache(snode_cache); + + // Should return a result in a different order (since this is random, it's possible that it + // could return the same order so repeat up to 5 times to make the chance of this negligible) + snode_pool->reset_state_with_cache(snode_cache); + auto results_differed = false; + auto first_result = snode_pool->get_unused_nodes(20); + + for (auto i = 0; i < 5; ++i) { + auto next_result = snode_pool->get_unused_nodes(20); + + if (next_result != first_result) { + results_differed = true; + break; + } + } + INFO("get_unused_nodes() produced the same result 5 times in a row."); + CHECK(results_differed); + + // Should contain the entire snode cache initially + snode_pool->reset_state_with_cache(snode_cache); + unused_nodes = snode_pool->get_unused_nodes(20); + std::sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == snode_cache); + + // Should exclude nodes in the exclusion list + snode_pool->reset_state_with_cache(snode_cache); + std::vector excluded(snode_cache.begin(), snode_cache.begin() + 10); + std::vector remaining(snode_cache.begin() + 10, snode_cache.end()); + unused_nodes = snode_pool->get_unused_nodes(24, excluded); + std::sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == remaining); + + // Should exclude nodes which have passed the failure threshold + snode_pool->reset_state_with_cache(snode_cache); + for (uint16_t i = 0; i < 10; ++i) { + snode_pool->record_node_failure(snode_cache[i], true); + } + unused_nodes = snode_pool->get_unused_nodes(10); + std::sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == remaining); + + // Should exclude nodes which have the same subnet + pool_config = { + std::nullopt, + std::chrono::minutes{5}, + std::chrono::minutes{5}, + true, // enforce_subnet_diversity + network::opt::retry_delay{50ms, 200ms}, + opt::netid::Target::testnet, + {}, + 0, + 0, + 3, // cache_node_failure_threshold + false}; + snode_pool = std::make_shared(pool_config, loop); + snode_pool->reset_state_with_cache(snode_cache); + unused_nodes = snode_pool->get_unused_nodes(20); + std::sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes.size() == 4); + + std::set result_subnets; + for (const auto& node : unused_nodes) + result_subnets.insert(node.ip.to_base(24)); + CHECK(result_subnets.size() == 4); +} \ No newline at end of file diff --git a/tests/test_utils.cpp b/tests/test_utils.cpp new file mode 100644 index 00000000..5db8927e --- /dev/null +++ b/tests/test_utils.cpp @@ -0,0 +1,27 @@ +#include + +#include "utils.hpp" + +TEST_CASE("Network", "[network][parse_url]") { + auto [proto1, host1, port1, path1] = session::parse_url("HTTPS://example.com/test"); + auto [proto2, host2, port2, path2] = session::parse_url("http://example2.com:1234/test/123456"); + auto [proto3, host3, port3, path3] = session::parse_url("https://example3.com"); + auto [proto4, host4, port4, path4] = session::parse_url("https://example4.com/test?value=test"); + + CHECK(proto1 == "https://"); + CHECK(proto2 == "http://"); + CHECK(proto3 == "https://"); + CHECK(proto4 == "https://"); + CHECK(host1 == "example.com"); + CHECK(host2 == "example2.com"); + CHECK(host3 == "example3.com"); + CHECK(host4 == "example4.com"); + CHECK(port1.value_or(9999) == 9999); + CHECK(port2.value_or(9999) == 1234); + CHECK(port3.value_or(9999) == 9999); + CHECK(port4.value_or(9999) == 9999); + CHECK(path1.value_or("NULL") == "/test"); + CHECK(path2.value_or("NULL") == "/test/123456"); + CHECK(path3.value_or("NULL") == "NULL"); + CHECK(path4.value_or("NULL") == "/test?value=test"); +} \ No newline at end of file diff --git a/tests/utils.hpp b/tests/utils.hpp index cf4d70cd..f1bff145 100644 --- a/tests/utils.hpp +++ b/tests/utils.hpp @@ -49,6 +49,86 @@ struct log_level_lowerer : log_level_override { log_level_lowerer(oxen::log::Level l, std::string category) : log_level_override{std::min(l, oxen::log::get_level(category)), category} {} }; + +class CallTracker { + protected: + std::unordered_map call_counts_; + std::mutex call_counts_mutex_; + std::condition_variable call_cv_; + std::vector calls_to_ignore_; + + public: + virtual ~CallTracker() = default; + + void func_called(const std::string& name) { + bool notify = false; + { + std::lock_guard lock(call_counts_mutex_); + ++call_counts_[name]; + notify = true; + } + + if (notify) + call_cv_.notify_all(); + } + + std::vector calls_to_ignore() { return calls_to_ignore_; } + + bool check_should_ignore_and_log_call(const std::string& name) { + func_called(name); + return std::find(calls_to_ignore_.begin(), calls_to_ignore_.end(), name) != + calls_to_ignore_.end(); + } + + template + void ignore_calls_to(Strings&&... args) { + (calls_to_ignore_.emplace_back(std::forward(args)), ...); + } + + void reset_calls() { + std::lock_guard lock(call_counts_mutex_); + call_counts_.clear(); + calls_to_ignore_.clear(); + } + + int get_call_count(const std::string& name) { + std::lock_guard lock(call_counts_mutex_); + auto it = call_counts_.find(name); + return (it != call_counts_.end()) ? it->second : 0; + } + + bool called(const std::string& name, int times = 1) { return (get_call_count(name) >= times); } + + [[nodiscard]] bool called( + const std::string& name, std::chrono::milliseconds timeout, int times = 1) { + if (times <= 0) + times = 1; + + std::unique_lock lock(call_counts_mutex_); + auto predicate = [&]() { + auto it = call_counts_.find(name); + return (it != call_counts_.end() && it->second >= times); + }; + return call_cv_.wait_for(lock, timeout, predicate); + } + + bool did_not_call(const std::string& name) { + std::lock_guard lock(call_counts_mutex_); + return !call_counts_.contains(name); + } + + [[nodiscard]] bool did_not_call(const std::string& name, std::chrono::milliseconds duration) { + std::unique_lock lock(call_counts_mutex_); + auto predicate = [&]() { return call_counts_.contains(name); }; + + if (predicate()) + return false; // Already called + + bool was_called_during_wait = call_cv_.wait_for(lock, duration, predicate); + return !was_called_during_wait; + } +}; + } // namespace session inline std::vector operator""_bytes(const char* x, size_t n) {