From 1286d20f6382dd47a9eb28374c6dc3cbf4bff130 Mon Sep 17 00:00:00 2001 From: Ilya Repin Date: Thu, 23 Oct 2025 07:17:03 +0000 Subject: [PATCH 1/9] Add UseDetectedLocalDC Policy --- include/ydb-cpp-sdk/client/types/ydb.h | 3 + src/client/impl/internal/CMakeLists.txt | 1 + .../internal/common/balancing_policies.cpp | 6 + .../impl/internal/common/balancing_policies.h | 3 + .../internal/db_driver_state/CMakeLists.txt | 1 + .../db_driver_state/endpoint_pool.cpp | 6 + .../internal/db_driver_state/endpoint_pool.h | 3 + .../internal/local_dc_detector/CMakeLists.txt | 13 ++ .../local_dc_detector/local_dc_detector.cpp | 75 ++++++++++ .../local_dc_detector/local_dc_detector.h | 43 ++++++ .../internal/local_dc_detector/pinger.cpp | 17 +++ .../impl/internal/local_dc_detector/pinger.h | 15 ++ src/client/table/impl/table_client.cpp | 2 +- src/client/types/ydb.cpp | 4 + tests/unit/client/CMakeLists.txt | 13 ++ .../local_dc_detector_ut.cpp | 138 ++++++++++++++++++ 16 files changed, 342 insertions(+), 1 deletion(-) create mode 100644 src/client/impl/internal/local_dc_detector/CMakeLists.txt create mode 100644 src/client/impl/internal/local_dc_detector/local_dc_detector.cpp create mode 100644 src/client/impl/internal/local_dc_detector/local_dc_detector.h create mode 100644 src/client/impl/internal/local_dc_detector/pinger.cpp create mode 100644 src/client/impl/internal/local_dc_detector/pinger.h create mode 100644 tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp diff --git a/include/ydb-cpp-sdk/client/types/ydb.h b/include/ydb-cpp-sdk/client/types/ydb.h index c62652c609..ae689699df 100644 --- a/include/ydb-cpp-sdk/client/types/ydb.h +++ b/include/ydb-cpp-sdk/client/types/ydb.h @@ -54,6 +54,9 @@ class TBalancingPolicy { //! location is a name of datacenter (VLA, MAN), if location is nullopt local datacenter is used static TBalancingPolicy UsePreferableLocation(const std::optional& location = {}); + //! Use detected local dc + static TBalancingPolicy UseDetectedLocalDC(); + //! Use all available cluster nodes regardless datacenter locality static TBalancingPolicy UseAllNodes(); diff --git a/src/client/impl/internal/CMakeLists.txt b/src/client/impl/internal/CMakeLists.txt index 5370c34f57..56bfbc0045 100644 --- a/src/client/impl/internal/CMakeLists.txt +++ b/src/client/impl/internal/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(common) add_subdirectory(db_driver_state) add_subdirectory(grpc_connections) +add_subdirectory(local_dc_detector) add_subdirectory(logger) add_subdirectory(make_request) add_subdirectory(plain_status) diff --git a/src/client/impl/internal/common/balancing_policies.cpp b/src/client/impl/internal/common/balancing_policies.cpp index 22ec50a622..7d2301231e 100644 --- a/src/client/impl/internal/common/balancing_policies.cpp +++ b/src/client/impl/internal/common/balancing_policies.cpp @@ -16,6 +16,12 @@ TBalancingPolicy::TImpl TBalancingPolicy::TImpl::UsePreferableLocation(const std return impl; } +TBalancingPolicy::TImpl TBalancingPolicy::TImpl::UseDetectedLocalDC() { + TBalancingPolicy::TImpl impl; + impl.PolicyType = EPolicyType::UseDetectedLocalDC; + return impl; +} + TBalancingPolicy::TImpl TBalancingPolicy::TImpl::UsePreferablePileState(EPileState pileState) { TBalancingPolicy::TImpl impl; impl.PolicyType = EPolicyType::UsePreferablePileState; diff --git a/src/client/impl/internal/common/balancing_policies.h b/src/client/impl/internal/common/balancing_policies.h index f1180f37ed..49a3ae505b 100644 --- a/src/client/impl/internal/common/balancing_policies.h +++ b/src/client/impl/internal/common/balancing_policies.h @@ -14,6 +14,7 @@ class TBalancingPolicy::TImpl { enum class EPolicyType { UseAllNodes, UsePreferableLocation, + UseDetectedLocalDC, UsePreferablePileState }; @@ -21,6 +22,8 @@ class TBalancingPolicy::TImpl { static TImpl UsePreferableLocation(const std::optional& location); + static TImpl UseDetectedLocalDC(); + static TImpl UsePreferablePileState(EPileState pileState); EPolicyType PolicyType; diff --git a/src/client/impl/internal/db_driver_state/CMakeLists.txt b/src/client/impl/internal/db_driver_state/CMakeLists.txt index 09089e4c10..cc6ec36d7b 100644 --- a/src/client/impl/internal/db_driver_state/CMakeLists.txt +++ b/src/client/impl/internal/db_driver_state/CMakeLists.txt @@ -7,6 +7,7 @@ target_link_libraries(impl-internal-db_driver_state PUBLIC client-impl-ydb_endpoints impl-internal-logger impl-internal-plain_status + impl-internal-local_dc_detector client-types-credentials ) diff --git a/src/client/impl/internal/db_driver_state/endpoint_pool.cpp b/src/client/impl/internal/db_driver_state/endpoint_pool.cpp index 8bdbd19262..a6435a580b 100644 --- a/src/client/impl/internal/db_driver_state/endpoint_pool.cpp +++ b/src/client/impl/internal/db_driver_state/endpoint_pool.cpp @@ -41,6 +41,10 @@ std::pair, bool> TEndpointPool::Updat TListEndpointsResult result = future.GetValue(); std::vector removed; if (result.DiscoveryStatus.Status == EStatus::SUCCESS) { + if (BalancingPolicy_.PolicyType == TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC) { + LocalDCDetector_.DetectLocalDC(result.Result); + } + std::vector records; // Is used to convert float to integer load factor // same integer values will be selected randomly. @@ -182,6 +186,8 @@ bool TEndpointPool::IsPreferredEndpoint(const Ydb::Discovery::EndpointInfo& endp return true; case TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation: return endpoint.location() == BalancingPolicy_.Location.value_or(selfLocation); + case TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC: + return LocalDCDetector_.IsLocalDC(endpoint.location()); case TBalancingPolicy::TImpl::EPolicyType::UsePreferablePileState: if (auto it = pileStates.find(endpoint.bridge_pile_name()); it != pileStates.end()) { return GetPileState(it->second.state()) == BalancingPolicy_.PileState; diff --git a/src/client/impl/internal/db_driver_state/endpoint_pool.h b/src/client/impl/internal/db_driver_state/endpoint_pool.h index b534593337..a1ec7263b0 100644 --- a/src/client/impl/internal/db_driver_state/endpoint_pool.h +++ b/src/client/impl/internal/db_driver_state/endpoint_pool.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -57,7 +58,9 @@ class TEndpointPool { TEndpointElectorSafe Elector_; NThreading::TPromise DiscoveryPromise_; std::atomic_uint64_t LastUpdateTime_; + const TBalancingPolicy::TImpl BalancingPolicy_; + TLocalDCDetector LocalDCDetector_; NSdkStats::TStatCollector* StatCollector_ = nullptr; diff --git a/src/client/impl/internal/local_dc_detector/CMakeLists.txt b/src/client/impl/internal/local_dc_detector/CMakeLists.txt new file mode 100644 index 0000000000..94b448751e --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/CMakeLists.txt @@ -0,0 +1,13 @@ +_ydb_sdk_add_library(impl-internal-local_dc_detector) + +target_link_libraries(impl-internal-local_dc_detector PUBLIC + yutil + api-grpc +) + +target_sources(impl-internal-local_dc_detector PRIVATE + local_dc_detector.cpp + pinger.cpp +) + +_ydb_sdk_install_targets(TARGETS impl-internal-local_dc_detector) diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp new file mode 100644 index 0000000000..b9fc78b35e --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp @@ -0,0 +1,75 @@ +#define INCLUDE_YDB_INTERNAL_H +#include "local_dc_detector.h" + +namespace NYdb::inline V3 { + +TLocalDCDetector::TLocalDCDetector(TPinger pingEndpoint) + : PingEndpoint_(std::move(pingEndpoint)) +{} + +void TLocalDCDetector::DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpointsList) { + auto endpointsByLocation = GroupEndpointsByLocation(endpointsList); + SampleEndpoints(endpointsByLocation); + + if (endpointsByLocation.size() > 1) { + Location_ = FindNearestLocation(endpointsByLocation); + } else { + Location_.clear(); + } +} + +bool TLocalDCDetector::IsLocalDC(const std::string& location) const { + return Location_.empty() || Location_ == location; +} + +TLocalDCDetector::TEndpointsByLocation TLocalDCDetector::GroupEndpointsByLocation(const Ydb::Discovery::ListEndpointsResult& endpointsList) const { + TEndpointsByLocation endpointsByLocation; + for (const auto& endpoint : endpointsList.endpoints()) { + endpointsByLocation[endpoint.location()].emplace_back(endpoint); + } + return endpointsByLocation; +} + +void TLocalDCDetector::SampleEndpoints(TEndpointsByLocation& endpointsByLocation) const { + std::mt19937 gen(std::random_device{}()); + for (auto& [location, endpoints] : endpointsByLocation) { + if (endpoints.size() > MAX_ENDPOINTS_PER_LOCATION) { + std::vector sample; + sample.reserve(MAX_ENDPOINTS_PER_LOCATION); + std::sample(endpoints.begin(), endpoints.end(), std::back_inserter(sample), MAX_ENDPOINTS_PER_LOCATION, gen); + endpoints.swap(sample); + } + } +} + +std::uint64_t TLocalDCDetector::MeasureLocationRtt(const std::vector& endpoints) const { + if (endpoints.empty()) { + return std::numeric_limits::max(); + } + + std::vector timings; + timings.reserve(PING_COUNT); + for (size_t i = 0; i < PING_COUNT; ++i) { + const auto& ep = endpoints[i % endpoints.size()].get(); + timings.push_back(PingEndpoint_(ep).MicroSeconds()); + } + std::sort(timings.begin(), timings.end()); + + return std::midpoint(timings[(PING_COUNT - 1) / 2], timings[PING_COUNT / 2]); +} + + +std::string TLocalDCDetector::FindNearestLocation(const TEndpointsByLocation& endpointsByLocation) { + auto minRtt = std::numeric_limits::max(); + std::string nearestLocation; + for (const auto& [location, endpoints] : endpointsByLocation) { + auto rtt = MeasureLocationRtt(endpoints); + if (rtt < minRtt) { + minRtt = rtt; + nearestLocation = location; + } + } + return nearestLocation; +} + +} // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.h b/src/client/impl/internal/local_dc_detector/local_dc_detector.h new file mode 100644 index 0000000000..4afefdc010 --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace NYdb::inline V3 { + +class TLocalDCDetector { +public: + using TPinger = std::function; + explicit TLocalDCDetector(TPinger pingEndpoint = PingEndpoint); + + void DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpoints); + bool IsLocalDC(const std::string& location) const; + +private: + using TEndpoint = Ydb::Discovery::EndpointInfo; + using TEndpointRef = std::reference_wrapper; + using TEndpointsByLocation = std::unordered_map>; + using TMeasureResult = std::pair; + + TEndpointsByLocation GroupEndpointsByLocation(const Ydb::Discovery::ListEndpointsResult& endpointsList) const; + void SampleEndpoints(TEndpointsByLocation& endpointsByLocation) const; + std::uint64_t MeasureLocationRtt(const std::vector& endpoints) const; + std::string FindNearestLocation(const TEndpointsByLocation& endpointsByLocation); + +private: + static constexpr std::size_t MAX_ENDPOINTS_PER_LOCATION = 3; + static constexpr std::size_t PING_COUNT = 2 * MAX_ENDPOINTS_PER_LOCATION; + + TPinger PingEndpoint_; + std::string Location_; +}; + +} // namespace NYdb \ No newline at end of file diff --git a/src/client/impl/internal/local_dc_detector/pinger.cpp b/src/client/impl/internal/local_dc_detector/pinger.cpp new file mode 100644 index 0000000000..6489d0d12e --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/pinger.cpp @@ -0,0 +1,17 @@ +#define INCLUDE_YDB_INTERNAL_H +#include "pinger.h" + +namespace NYdb::inline V3 { + +TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint) { + try { + TNetworkAddress addr(endpoint.address().data(), static_cast(endpoint.port())); + auto start = TInstant::Now(); + TSocket sock(addr, TDuration::Seconds(PING_TIMEOUT_SECONDS)); + return TInstant::Now() - start; + } catch (...) { + return TDuration::Max(); + } +} + +} // namespace NYdb \ No newline at end of file diff --git a/src/client/impl/internal/local_dc_detector/pinger.h b/src/client/impl/internal/local_dc_detector/pinger.h new file mode 100644 index 0000000000..b514b8ace8 --- /dev/null +++ b/src/client/impl/internal/local_dc_detector/pinger.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include +#include + +namespace NYdb::inline V3 { + +static constexpr std::uint32_t PING_TIMEOUT_SECONDS = 5; + +TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint); + +} // namespace NYdb \ No newline at end of file diff --git a/src/client/table/impl/table_client.cpp b/src/client/table/impl/table_client.cpp index 26594dfb0b..8b1b424b14 100644 --- a/src/client/table/impl/table_client.cpp +++ b/src/client/table/impl/table_client.cpp @@ -229,7 +229,7 @@ void TTableClient::TImpl::StartPeriodicHostScanTask() { const auto balancingPolicy = strongClient->DbDriverState_->GetBalancingPolicyType(); // Try to find any host at foreign locations if prefer local dc - const ui64 foreignHost = (balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation) ? + const ui64 foreignHost = (balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation || balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC) ? ScanForeignLocations(strongClient) : 0; std::unordered_map hostMap; diff --git a/src/client/types/ydb.cpp b/src/client/types/ydb.cpp index f6083ec25e..2f7eef2b3a 100644 --- a/src/client/types/ydb.cpp +++ b/src/client/types/ydb.cpp @@ -22,6 +22,10 @@ TBalancingPolicy TBalancingPolicy::UsePreferableLocation(const std::optional(TImpl::UsePreferableLocation(location))); } +TBalancingPolicy TBalancingPolicy::UseDetectedLocalDC() { + return TBalancingPolicy(std::make_unique(TImpl::UseDetectedLocalDC())); +} + TBalancingPolicy TBalancingPolicy::UseAllNodes() { return TBalancingPolicy(std::make_unique(TImpl::UseAllNodes())); } diff --git a/tests/unit/client/CMakeLists.txt b/tests/unit/client/CMakeLists.txt index 8c3b142ee7..f3452a0c45 100644 --- a/tests/unit/client/CMakeLists.txt +++ b/tests/unit/client/CMakeLists.txt @@ -45,6 +45,19 @@ add_ydb_test(NAME client-impl-ydb_endpoints_ut unit ) +add_ydb_test(NAME client-impl-internal-local_dc_detector_ut + INCLUDE_DIRS + ${YDB_SDK_SOURCE_DIR}/src/client/impl/internal/local_dc_detector + SOURCES + local_dc_detector/local_dc_detector_ut.cpp + LINK_LIBRARIES + yutil + api-protos + impl-internal-local_dc_detector + LABELS + unit +) + add_ydb_test(NAME client-oauth2_ut SOURCES oauth2_token_exchange/credentials_ut.cpp diff --git a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp new file mode 100644 index 0000000000..8319b367d4 --- /dev/null +++ b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp @@ -0,0 +1,138 @@ +#include + +#define INCLUDE_YDB_INTERNAL_H +#include +#undef INCLUDE_YDB_INTERNAL_H + +#include +#include + +using namespace NYdb; + +class TMockedEndpoint { +public: + explicit TMockedEndpoint(std::vector measures) + : Measures_(std::move(measures)) + , Idx_(0) + {} + + TDuration Ping() { + std::size_t idx = Idx_++; + + if (idx < Measures_.size()) { + return Measures_.at(idx); + } + return TDuration::Max(); + } + +private: + const std::vector Measures_; + std::size_t Idx_; +}; + +class TMockedPinger { +public: + explicit TMockedPinger(std::unordered_map> measuresByAdress) { + EndpointByAdress_.reserve(measuresByAdress.size()); + + for (auto& [adress, measures] : measuresByAdress) { + EndpointByAdress_.emplace(std::move(adress), std::move(measures)); + } + } + + TDuration operator()(const Ydb::Discovery::EndpointInfo& endpoint) const { + auto it = EndpointByAdress_.find(endpoint.address()); + if (it == EndpointByAdress_.end()) { + return TDuration::Max(); + } + return it->second.Ping(); + } + +private: + mutable std::unordered_map EndpointByAdress_; +}; + +std::vector GenerateMeasures(size_t count, int minMs, int maxMs, std::mt19937& gen) { + std::vector measures; + measures.reserve(count); + std::uniform_int_distribution distrib(minMs, maxMs); + for (size_t i = 0; i < count; ++i) { + measures.push_back(TDuration::MicroSeconds(distrib(gen))); + } + return measures; +} + + +Y_UNIT_TEST_SUITE(LocalDCDetectionTest) { + Y_UNIT_TEST(Basic) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); + + const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; + const std::vector endpointsB = {"B1", "B2", "B3"}; + const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + + for (const auto& ep : endpointsA) { + mockData[ep] = GenerateMeasures(10, 20, 30, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsB) { + mockData[ep] = GenerateMeasures(10, 30, 45, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("B"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsC) { + mockData[ep] = GenerateMeasures(8, 50, 70, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("C"); + endpoint.set_address(ep); + } + + std::function pinger = TMockedPinger(mockData); + TLocalDCDetector detector(pinger); + + detector.DetectLocalDC(endpoints); + + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + } + + Y_UNIT_TEST(Fallback) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + + const std::vector endpointsA = {"A1", "A2", "A3"}; + const std::vector endpointsB = {"B1", "B2", "B3"}; + const std::vector endpointsC = {"C1", "C2", "C3"}; + + for (const auto& ep : endpointsA) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsB) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("B"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsC) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("C"); + endpoint.set_address(ep); + } + + std::function pinger = TMockedPinger(mockData); + TLocalDCDetector detector(pinger); + + detector.DetectLocalDC(endpoints); + + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("C")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + } +} From 75939d563ada756ab09fbb15626f0212a8e5b9b1 Mon Sep 17 00:00:00 2001 From: Ilya Repin Date: Thu, 23 Oct 2025 11:46:35 +0000 Subject: [PATCH 2/9] Add tests --- .../local_dc_detector_ut.cpp | 134 ++++++++++++++++-- 1 file changed, 121 insertions(+), 13 deletions(-) diff --git a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp index 8319b367d4..d68f0ebec7 100644 --- a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp +++ b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp @@ -42,14 +42,23 @@ class TMockedPinger { TDuration operator()(const Ydb::Discovery::EndpointInfo& endpoint) const { auto it = EndpointByAdress_.find(endpoint.address()); - if (it == EndpointByAdress_.end()) { + if (it == EndpointByAdress_.end() || Blacklist_.contains(endpoint.address())) { return TDuration::Max(); } return it->second.Ping(); } + void BanEndpoint(const std::string& adress) { + Blacklist_.insert(adress); + } + + void UnbanEndpoint(const std::string& adress) { + Blacklist_.erase(adress); + } + private: mutable std::unordered_map EndpointByAdress_; + std::unordered_set Blacklist_; }; std::vector GenerateMeasures(size_t count, int minMs, int maxMs, std::mt19937& gen) { @@ -62,7 +71,6 @@ std::vector GenerateMeasures(size_t count, int minMs, int maxMs, std: return measures; } - Y_UNIT_TEST_SUITE(LocalDCDetectionTest) { Y_UNIT_TEST(Basic) { Ydb::Discovery::ListEndpointsResult endpoints; @@ -73,20 +81,23 @@ Y_UNIT_TEST_SUITE(LocalDCDetectionTest) { const std::vector endpointsB = {"B1", "B2", "B3"}; const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + const std::size_t epoches = 3; + const std::size_t measuresAmount = 10 * epoches; + for (const auto& ep : endpointsA) { - mockData[ep] = GenerateMeasures(10, 20, 30, gen); + mockData[ep] = GenerateMeasures(measuresAmount, 20, 30, gen); auto& endpoint = *endpoints.add_endpoints(); endpoint.set_location("A"); endpoint.set_address(ep); } for (const auto& ep : endpointsB) { - mockData[ep] = GenerateMeasures(10, 30, 45, gen); + mockData[ep] = GenerateMeasures(measuresAmount, 30, 45, gen); auto& endpoint = *endpoints.add_endpoints(); endpoint.set_location("B"); endpoint.set_address(ep); } for (const auto& ep : endpointsC) { - mockData[ep] = GenerateMeasures(8, 50, 70, gen); + mockData[ep] = GenerateMeasures(measuresAmount, 50, 70, gen); auto& endpoint = *endpoints.add_endpoints(); endpoint.set_location("C"); endpoint.set_address(ep); @@ -95,44 +106,141 @@ Y_UNIT_TEST_SUITE(LocalDCDetectionTest) { std::function pinger = TMockedPinger(mockData); TLocalDCDetector detector(pinger); + for (std::size_t i = 0; i < epoches; ++i) { + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + } + } + + Y_UNIT_TEST(SingleLocation) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); + + const std::vector endpointsA = {"A1", "A2", "A3"}; + + for (const auto& ep : endpointsA) { + mockData[ep] = GenerateMeasures(10, 20, 30, gen); + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + + std::function pinger = TMockedPinger(mockData); + TLocalDCDetector detector(pinger); + detector.DetectLocalDC(endpoints); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); } - Y_UNIT_TEST(Fallback) { + Y_UNIT_TEST(UnavailableLocalDC) { Ydb::Discovery::ListEndpointsResult endpoints; std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); - const std::vector endpointsA = {"A1", "A2", "A3"}; - const std::vector endpointsB = {"B1", "B2", "B3"}; - const std::vector endpointsC = {"C1", "C2", "C3"}; + const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; + const std::vector endpointsB = {"B1", "B2", "B3", "B4", "B5", "B6", "B7"}; + const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + + const std::size_t epoches = 3; + const std::size_t measuresAmount = 10 * epoches; for (const auto& ep : endpointsA) { + mockData[ep] = GenerateMeasures(measuresAmount, 20, 30, gen); auto& endpoint = *endpoints.add_endpoints(); endpoint.set_location("A"); endpoint.set_address(ep); } for (const auto& ep : endpointsB) { + mockData[ep] = GenerateMeasures(measuresAmount, 30, 45, gen); auto& endpoint = *endpoints.add_endpoints(); endpoint.set_location("B"); endpoint.set_address(ep); } for (const auto& ep : endpointsC) { + mockData[ep] = GenerateMeasures(measuresAmount, 50, 70, gen); auto& endpoint = *endpoints.add_endpoints(); endpoint.set_location("C"); endpoint.set_address(ep); } - std::function pinger = TMockedPinger(mockData); + TMockedPinger mockPinger(mockData); + std::function pinger = + [&mockPinger](const Ydb::Discovery::EndpointInfo& endpoint) { return mockPinger(endpoint); }; + TLocalDCDetector detector(pinger); detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("C")); + for (const auto& ep : endpointsA) { + mockPinger.BanEndpoint(ep); + } + + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("A")); UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + + for (const auto& ep : endpointsA) { + mockPinger.UnbanEndpoint(ep); + } + + detector.DetectLocalDC(endpoints); UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + } + + Y_UNIT_TEST(OfflineDCs) { + Ydb::Discovery::ListEndpointsResult endpoints; + std::unordered_map> mockData; + std::mt19937 gen(std::random_device{}()); + + const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; + const std::vector endpointsB = {"B1", "B2", "B3", "B4", "B5", "B6", "B7"}; + const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + + for (const auto& ep : endpointsA) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsB) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("B"); + endpoint.set_address(ep); + } + for (const auto& ep : endpointsC) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("C"); + endpoint.set_address(ep); + } + + TMockedPinger mockPinger(mockData); + std::function pinger = + [&mockPinger](const Ydb::Discovery::EndpointInfo& endpoint) { return mockPinger(endpoint); }; + + TLocalDCDetector detector(pinger); + + for (const auto& ep : endpointsA) { + mockPinger.BanEndpoint(ep); + } + for (const auto& ep : endpointsB) { + mockPinger.BanEndpoint(ep); + } + for (const auto& ep : endpointsC) { + mockPinger.BanEndpoint(ep); + } + + detector.DetectLocalDC(endpoints); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("B")); + UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("C")); } } From 08d7d37801bfbdf02d644a7d818728812882cf42 Mon Sep 17 00:00:00 2001 From: Ilya Repin Date: Thu, 23 Oct 2025 11:57:25 +0000 Subject: [PATCH 3/9] Change CMakeLists.txt --- src/client/impl/internal/local_dc_detector/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client/impl/internal/local_dc_detector/CMakeLists.txt b/src/client/impl/internal/local_dc_detector/CMakeLists.txt index 94b448751e..71a33f6aad 100644 --- a/src/client/impl/internal/local_dc_detector/CMakeLists.txt +++ b/src/client/impl/internal/local_dc_detector/CMakeLists.txt @@ -2,7 +2,7 @@ _ydb_sdk_add_library(impl-internal-local_dc_detector) target_link_libraries(impl-internal-local_dc_detector PUBLIC yutil - api-grpc + api-protos ) target_sources(impl-internal-local_dc_detector PRIVATE From 097d5d0292552b1fa60fed3ef42c0b35824dbb63 Mon Sep 17 00:00:00 2001 From: Ilya Repin Date: Thu, 23 Oct 2025 12:11:53 +0000 Subject: [PATCH 4/9] Fix style --- .../impl/internal/local_dc_detector/local_dc_detector.h | 2 +- src/client/impl/internal/local_dc_detector/pinger.cpp | 2 +- src/client/impl/internal/local_dc_detector/pinger.h | 2 +- src/client/table/impl/table_client.cpp | 6 ++++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.h b/src/client/impl/internal/local_dc_detector/local_dc_detector.h index 4afefdc010..bee5ce42bb 100644 --- a/src/client/impl/internal/local_dc_detector/local_dc_detector.h +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.h @@ -40,4 +40,4 @@ class TLocalDCDetector { std::string Location_; }; -} // namespace NYdb \ No newline at end of file +} // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/pinger.cpp b/src/client/impl/internal/local_dc_detector/pinger.cpp index 6489d0d12e..1acb49c9b3 100644 --- a/src/client/impl/internal/local_dc_detector/pinger.cpp +++ b/src/client/impl/internal/local_dc_detector/pinger.cpp @@ -14,4 +14,4 @@ TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint) { } } -} // namespace NYdb \ No newline at end of file +} // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/pinger.h b/src/client/impl/internal/local_dc_detector/pinger.h index b514b8ace8..627f7a368f 100644 --- a/src/client/impl/internal/local_dc_detector/pinger.h +++ b/src/client/impl/internal/local_dc_detector/pinger.h @@ -12,4 +12,4 @@ static constexpr std::uint32_t PING_TIMEOUT_SECONDS = 5; TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint); -} // namespace NYdb \ No newline at end of file +} // namespace NYdb diff --git a/src/client/table/impl/table_client.cpp b/src/client/table/impl/table_client.cpp index 8b1b424b14..6d311308a1 100644 --- a/src/client/table/impl/table_client.cpp +++ b/src/client/table/impl/table_client.cpp @@ -229,8 +229,10 @@ void TTableClient::TImpl::StartPeriodicHostScanTask() { const auto balancingPolicy = strongClient->DbDriverState_->GetBalancingPolicyType(); // Try to find any host at foreign locations if prefer local dc - const ui64 foreignHost = (balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation || balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC) ? - ScanForeignLocations(strongClient) : 0; + const ui64 foreignHost = + balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UsePreferableLocation || + balancingPolicy == TBalancingPolicy::TImpl::EPolicyType::UseDetectedLocalDC ? + ScanForeignLocations(strongClient) : 0; std::unordered_map hostMap; From 4fcc1db1b39b3d808e209e7c918406f3c3e341df Mon Sep 17 00:00:00 2001 From: Ilya Repin Date: Thu, 23 Oct 2025 13:15:25 +0000 Subject: [PATCH 5/9] Remove TMeasureResult --- src/client/impl/internal/local_dc_detector/local_dc_detector.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.h b/src/client/impl/internal/local_dc_detector/local_dc_detector.h index bee5ce42bb..d255f8c047 100644 --- a/src/client/impl/internal/local_dc_detector/local_dc_detector.h +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.h @@ -25,7 +25,6 @@ class TLocalDCDetector { using TEndpoint = Ydb::Discovery::EndpointInfo; using TEndpointRef = std::reference_wrapper; using TEndpointsByLocation = std::unordered_map>; - using TMeasureResult = std::pair; TEndpointsByLocation GroupEndpointsByLocation(const Ydb::Discovery::ListEndpointsResult& endpointsList) const; void SampleEndpoints(TEndpointsByLocation& endpointsByLocation) const; From 7d1d1dd08d34b076a2afcd33bf8cd8657bfe4be6 Mon Sep 17 00:00:00 2001 From: Ilya-Repin Date: Mon, 10 Nov 2025 04:27:53 +0000 Subject: [PATCH 6/9] Add neh/asio library --- library/cpp/CMakeLists.txt | 2 + library/cpp/dns/CMakeLists.txt | 17 + library/cpp/dns/README.md | 9 + library/cpp/dns/cache.cpp | 198 +++++ library/cpp/dns/cache.h | 45 ++ library/cpp/dns/magic.cpp | 28 + library/cpp/dns/magic.h | 17 + library/cpp/dns/thread.cpp | 133 ++++ library/cpp/dns/thread.h | 12 + library/cpp/dns/ut/CMakeLists.txt | 9 + library/cpp/dns/ut/dns_ut.cpp | 25 + library/cpp/neh/asio/CMakeLists.txt | 22 + library/cpp/neh/asio/asio.cpp | 191 +++++ library/cpp/neh/asio/asio.h | 283 +++++++ library/cpp/neh/asio/deadline_timer_impl.cpp | 1 + library/cpp/neh/asio/deadline_timer_impl.h | 110 +++ library/cpp/neh/asio/executor.cpp | 1 + library/cpp/neh/asio/executor.h | 76 ++ library/cpp/neh/asio/io_service_impl.cpp | 161 ++++ library/cpp/neh/asio/io_service_impl.h | 762 +++++++++++++++++++ library/cpp/neh/asio/poll_interrupter.cpp | 1 + library/cpp/neh/asio/poll_interrupter.h | 107 +++ library/cpp/neh/asio/tcp_acceptor_impl.cpp | 25 + library/cpp/neh/asio/tcp_acceptor_impl.h | 76 ++ library/cpp/neh/asio/tcp_socket_impl.cpp | 117 +++ library/cpp/neh/asio/tcp_socket_impl.h | 332 ++++++++ library/cpp/neh/lfqueue.h | 53 ++ library/cpp/neh/pipequeue.h | 207 +++++ 28 files changed, 3020 insertions(+) create mode 100644 library/cpp/dns/CMakeLists.txt create mode 100644 library/cpp/dns/README.md create mode 100644 library/cpp/dns/cache.cpp create mode 100644 library/cpp/dns/cache.h create mode 100644 library/cpp/dns/magic.cpp create mode 100644 library/cpp/dns/magic.h create mode 100644 library/cpp/dns/thread.cpp create mode 100644 library/cpp/dns/thread.h create mode 100644 library/cpp/dns/ut/CMakeLists.txt create mode 100644 library/cpp/dns/ut/dns_ut.cpp create mode 100644 library/cpp/neh/asio/CMakeLists.txt create mode 100644 library/cpp/neh/asio/asio.cpp create mode 100644 library/cpp/neh/asio/asio.h create mode 100644 library/cpp/neh/asio/deadline_timer_impl.cpp create mode 100644 library/cpp/neh/asio/deadline_timer_impl.h create mode 100644 library/cpp/neh/asio/executor.cpp create mode 100644 library/cpp/neh/asio/executor.h create mode 100644 library/cpp/neh/asio/io_service_impl.cpp create mode 100644 library/cpp/neh/asio/io_service_impl.h create mode 100644 library/cpp/neh/asio/poll_interrupter.cpp create mode 100644 library/cpp/neh/asio/poll_interrupter.h create mode 100644 library/cpp/neh/asio/tcp_acceptor_impl.cpp create mode 100644 library/cpp/neh/asio/tcp_acceptor_impl.h create mode 100644 library/cpp/neh/asio/tcp_socket_impl.cpp create mode 100644 library/cpp/neh/asio/tcp_socket_impl.h create mode 100644 library/cpp/neh/lfqueue.h create mode 100644 library/cpp/neh/pipequeue.h diff --git a/library/cpp/CMakeLists.txt b/library/cpp/CMakeLists.txt index b3cfe65cee..194d9f3036 100644 --- a/library/cpp/CMakeLists.txt +++ b/library/cpp/CMakeLists.txt @@ -17,6 +17,7 @@ add_subdirectory(diff) add_subdirectory(digest/lower_case) add_subdirectory(digest/md5) add_subdirectory(digest/murmur) +add_subdirectory(dns) add_subdirectory(getopt) add_subdirectory(http/fetch) add_subdirectory(http/io) @@ -33,6 +34,7 @@ add_subdirectory(monlib/encode) add_subdirectory(monlib/exception) add_subdirectory(monlib/metrics) add_subdirectory(monlib/service) +add_subdirectory(neh/asio) add_subdirectory(openssl/holders) add_subdirectory(openssl/init) add_subdirectory(openssl/io) diff --git a/library/cpp/dns/CMakeLists.txt b/library/cpp/dns/CMakeLists.txt new file mode 100644 index 0000000000..634a5e96a9 --- /dev/null +++ b/library/cpp/dns/CMakeLists.txt @@ -0,0 +1,17 @@ +if (YDB_SDK_TESTS) + add_subdirectory(ut) +endif() + +_ydb_sdk_add_library(dns) + +target_sources(dns + PRIVATE + cache.cpp + thread.cpp + magic.cpp +) + +target_link_libraries(dns + PUBLIC + yutil +) diff --git a/library/cpp/dns/README.md b/library/cpp/dns/README.md new file mode 100644 index 0000000000..88bdba0d6a --- /dev/null +++ b/library/cpp/dns/README.md @@ -0,0 +1,9 @@ +Overview +=== +Библиотека кеширующего resolving-а - изначально писалась для имплементации neh http протокола, использующей корутины. +Для предотвращения пробоя короткого стека корутин есть метод, предусматривающий вынос в отдельный тред собственно вызов функции резолвинга. +Для предотвращения обращения к DNS серверам (использования вместо этого заранее заданных ip-адресов), +предусмотрена ручка добавления alias-ов hosname -> ip-address (требование от метапоискового движка). + +Из-за того, что библиотека разрабатывалась под задачу максимально быстрого резолвинга добавлены слои кеширования результатов +resoving-а, - возможности сбросить кеш для того, чтобы получить более свежие адреса для указанного host-а _нет_. diff --git a/library/cpp/dns/cache.cpp b/library/cpp/dns/cache.cpp new file mode 100644 index 0000000000..9414f072d5 --- /dev/null +++ b/library/cpp/dns/cache.cpp @@ -0,0 +1,198 @@ +#include "cache.h" + +#include "thread.h" + +#include +#include +#include +#include +#include +#include + +using namespace NDns; + +namespace { + struct TResolveTask { + enum EMethod { + Normal, + Threaded + }; + + inline TResolveTask(const TResolveInfo& info, EMethod method) + : Info(info) + , Method(method) + { + } + + const TResolveInfo& Info; + const EMethod Method; + }; + + class IDns { + public: + virtual ~IDns() = default; + virtual const TResolvedHost* Resolve(const TResolveTask&) = 0; + }; + + typedef TAtomicSharedPtr TResolvedHostPtr; + + struct THashResolveInfo { + inline size_t operator()(const TResolveInfo& ri) const { + return ComputeHash(ri.Host) ^ ri.Port; + } + }; + + struct TCompareResolveInfo { + inline bool operator()(const NDns::TResolveInfo& x, const NDns::TResolveInfo& y) const { + return x.Host == y.Host && x.Port == y.Port; + } + }; + + class TGlobalCachedDns: public IDns, public TNonCopyable { + public: + const TResolvedHost* Resolve(const TResolveTask& rt) override { + //2. search host in cache + { + TReadGuard guard(L_); + + TCache::const_iterator it = C_.find(rt.Info); + + if (it != C_.end()) { + return it->second.Get(); + } + } + + TResolvedHostPtr res = ResolveA(rt); + + //update cache + { + TWriteGuard guard(L_); + + std::pair updateResult = C_.insert(std::make_pair(TResolveInfo(res->Host, rt.Info.Port), res)); + TResolvedHost* rh = updateResult.first->second.Get(); + + if (updateResult.second) { + //fresh resolved host, set cache record id for it + rh->Id = C_.size() - 1; + } + + return rh; + } + } + + void AddAlias(const TString& host, const TString& alias) noexcept { + TWriteGuard guard(LA_); + + A_[host] = alias; + } + + static inline TGlobalCachedDns* Instance() { + return SingletonWithPriority(); + } + + private: + inline TResolvedHostPtr ResolveA(const TResolveTask& rt) { + TString originalHost(rt.Info.Host); + TString host(originalHost); + + //3. replace host to alias, if exist + if (A_.size()) { + TReadGuard guard(LA_); + TString names[] = {"*", host}; + + for (const auto& name : names) { + TAliases::const_iterator it = A_.find(name); + + if (it != A_.end()) { + host = it->second; + } + } + } + + if (host.length() > 2 && host[0] == '[') { + TString unbracedIpV6(host.data() + 1, host.size() - 2); + host.swap(unbracedIpV6); + } + + TAutoPtr na; + + //4. getaddrinfo (direct or in separate thread) + if (rt.Method == TResolveTask::Normal) { + na.Reset(new TNetworkAddress(host, rt.Info.Port)); + } else if (rt.Method == TResolveTask::Threaded) { + na = ThreadedResolve(host, rt.Info.Port); + } else { + Y_ASSERT(0); + throw yexception() << TStringBuf("invalid resolve method"); + } + + return new TResolvedHost(originalHost, *na); + } + + typedef THashMap TCache; + TCache C_; + TRWMutex L_; + typedef THashMap TAliases; + TAliases A_; + TRWMutex LA_; + }; + + class TCachedDns: public IDns { + public: + inline TCachedDns(IDns* slave) + : S_(slave) + { + } + + const TResolvedHost* Resolve(const TResolveTask& rt) override { + //1. search in local thread cache + { + TCache::const_iterator it = C_.find(rt.Info); + + if (it != C_.end()) { + return it->second; + } + } + + const TResolvedHost* res = S_->Resolve(rt); + + C_[TResolveInfo(res->Host, rt.Info.Port)] = res; + + return res; + } + + private: + typedef THashMap TCache; + TCache C_; + IDns* S_; + }; + + struct TThreadedDns: public TCachedDns { + inline TThreadedDns() + : TCachedDns(TGlobalCachedDns::Instance()) + { + } + }; + + inline IDns* ThrDns() { + return FastTlsSingleton(); + } +} + +namespace NDns { + const TResolvedHost* CachedResolve(const TResolveInfo& ri) { + TResolveTask rt(ri, TResolveTask::Normal); + + return ThrDns()->Resolve(rt); + } + + const TResolvedHost* CachedThrResolve(const TResolveInfo& ri) { + TResolveTask rt(ri, TResolveTask::Threaded); + + return ThrDns()->Resolve(rt); + } + + void AddHostAlias(const TString& host, const TString& alias) { + TGlobalCachedDns::Instance()->AddAlias(host, alias); + } +} diff --git a/library/cpp/dns/cache.h b/library/cpp/dns/cache.h new file mode 100644 index 0000000000..eda5dc4070 --- /dev/null +++ b/library/cpp/dns/cache.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +namespace NDns { + struct TResolveInfo { + inline TResolveInfo(const TStringBuf& host, ui16 port) + : Host(host) + , Port(port) + { + } + + TStringBuf Host; + ui16 Port; + }; + + struct TResolvedHost { + inline TResolvedHost(const TString& host, const TNetworkAddress& addr) noexcept + : Host(host) + , Addr(addr) + , Id(0) + { + } + + TString Host; //resolved hostname (from TResolveInfo, - before aliasing) + TNetworkAddress Addr; + size_t Id; //cache record id + }; + + // Resolving order: + // 1. check local thread cache, return if found + // 2. check global cache, return if found + // 3. search alias for hostname, if found, continue resolving alias + // 4. normal resolver + const TResolvedHost* CachedResolve(const TResolveInfo& ri); + + //like previous, but at stage 4 use separate thread for resolving (created on first usage) + //useful in green-threads with tiny stack + const TResolvedHost* CachedThrResolve(const TResolveInfo& ri); + + //create alias for host, which can be used for static resolving (when alias is ip address) + void AddHostAlias(const TString& host, const TString& alias); +} diff --git a/library/cpp/dns/magic.cpp b/library/cpp/dns/magic.cpp new file mode 100644 index 0000000000..b93792146f --- /dev/null +++ b/library/cpp/dns/magic.cpp @@ -0,0 +1,28 @@ +#include "magic.h" + +#include + +using namespace NDns; + +namespace { + namespace NX { + struct TError: public IError { + inline TError() + : E_(std::current_exception()) + { + } + + void Raise() override { + std::rethrow_exception(E_); + } + + std::exception_ptr E_; + }; + } +} + +IErrorRef NDns::SaveError() { + using namespace NX; + + return new NX::TError(); +} diff --git a/library/cpp/dns/magic.h b/library/cpp/dns/magic.h new file mode 100644 index 0000000000..d52cde0a6c --- /dev/null +++ b/library/cpp/dns/magic.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace NDns { + class IError { + public: + virtual ~IError() = default; + + virtual void Raise() = 0; + }; + + typedef TAutoPtr IErrorRef; + + IErrorRef SaveError(); +} diff --git a/library/cpp/dns/thread.cpp b/library/cpp/dns/thread.cpp new file mode 100644 index 0000000000..8b27d2d527 --- /dev/null +++ b/library/cpp/dns/thread.cpp @@ -0,0 +1,133 @@ +#include "thread.h" + +#include "magic.h" + +#include +#include +#include +#include +#include +#include + +using namespace NDns; + +namespace { + class TThreadedResolver: public IThreadFactory::IThreadAble, public TNonCopyable { + struct TResolveRequest { + inline TResolveRequest(const TString& host, ui16 port) + : Host(host) + , Port(port) + { + } + + inline TNetworkAddressPtr Wait() { + E.Wait(); + + if (!Error) { + if (!Result) { + ythrow TNetworkResolutionError(EAI_AGAIN) << TStringBuf(": resolver down"); + } + + return Result; + } + + Error->Raise(); + + ythrow TNetworkResolutionError(EAI_FAIL) << TStringBuf(": shit happen"); + } + + inline void Resolve() noexcept { + try { + Result = new TNetworkAddress(Host, Port); + } catch (...) { + Error = SaveError(); + } + + Wake(); + } + + inline void Wake() noexcept { + E.Signal(); + } + + TString Host; + ui16 Port; + TManualEvent E; + TNetworkAddressPtr Result; + IErrorRef Error; + }; + + public: + inline TThreadedResolver() + : E_(TSystemEvent::rAuto) + { + T_.push_back(SystemThreadFactory()->Run(this)); + } + + inline ~TThreadedResolver() override { + Schedule(nullptr); + + for (size_t i = 0; i < T_.size(); ++i) { + T_[i]->Join(); + } + + { + TResolveRequest* rr = nullptr; + + while (Q_.Dequeue(&rr)) { + if (rr) { + rr->Wake(); + } + } + } + } + + static inline TThreadedResolver* Instance() { + return Singleton(); + } + + inline TNetworkAddressPtr Resolve(const TString& host, ui16 port) { + TResolveRequest rr(host, port); + + Schedule(&rr); + + return rr.Wait(); + } + + private: + inline void Schedule(TResolveRequest* rr) { + Q_.Enqueue(rr); + E_.Signal(); + } + + void DoExecute() override { + while (true) { + TResolveRequest* rr = nullptr; + + while (!Q_.Dequeue(&rr)) { + E_.Wait(); + } + + if (rr) { + rr->Resolve(); + } else { + break; + } + } + + Schedule(nullptr); + } + + private: + TLockFreeQueue Q_; + TSystemEvent E_; + typedef TAutoPtr IThreadRef; + TVector T_; + }; +} + +namespace NDns { + TNetworkAddressPtr ThreadedResolve(const TString& host, ui16 port) { + return TThreadedResolver::Instance()->Resolve(host, port); + } +} diff --git a/library/cpp/dns/thread.h b/library/cpp/dns/thread.h new file mode 100644 index 0000000000..06b41d78ce --- /dev/null +++ b/library/cpp/dns/thread.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +#include +#include + +namespace NDns { + typedef TAutoPtr TNetworkAddressPtr; + + TNetworkAddressPtr ThreadedResolve(const TString& host, ui16 port); +} diff --git a/library/cpp/dns/ut/CMakeLists.txt b/library/cpp/dns/ut/CMakeLists.txt new file mode 100644 index 0000000000..9e315b8bf3 --- /dev/null +++ b/library/cpp/dns/ut/CMakeLists.txt @@ -0,0 +1,9 @@ +add_ydb_test(NAME dns-dns_ut + SOURCES + dns_ut.cpp + LINK_LIBRARIES + dns + cpp-testing-unittest_main + LABELS + unit +) diff --git a/library/cpp/dns/ut/dns_ut.cpp b/library/cpp/dns/ut/dns_ut.cpp new file mode 100644 index 0000000000..aae05a742c --- /dev/null +++ b/library/cpp/dns/ut/dns_ut.cpp @@ -0,0 +1,25 @@ +#include +#include +#include + +Y_UNIT_TEST_SUITE(TestDNS) { + using namespace NDns; + + Y_UNIT_TEST(TestMagic) { + UNIT_ASSERT_EXCEPTION(CachedThrResolve(TResolveInfo("?", 80)), yexception); + } + + Y_UNIT_TEST(TestAsteriskAlias) { + AddHostAlias("*", "localhost"); + const TResolvedHost* rh = CachedThrResolve(TResolveInfo("yandex.ru", 80)); + UNIT_ASSERT(rh != nullptr); + + const TNetworkAddress& addr = rh->Addr; + for (TNetworkAddress::TIterator ai = addr.Begin(); ai != addr.End(); ai++) { + if (ai->ai_family == AF_INET || ai->ai_family == AF_INET6) { + NAddr::TAddrInfo info(&*ai); + UNIT_ASSERT(IsLoopback(info)); + } + } + } +} diff --git a/library/cpp/neh/asio/CMakeLists.txt b/library/cpp/neh/asio/CMakeLists.txt new file mode 100644 index 0000000000..b053959419 --- /dev/null +++ b/library/cpp/neh/asio/CMakeLists.txt @@ -0,0 +1,22 @@ +_ydb_sdk_add_library(neh-asio) + +target_sources(neh-asio + PRIVATE + asio.cpp + deadline_timer_impl.cpp + executor.cpp + io_service_impl.cpp + poll_interrupter.cpp + tcp_acceptor_impl.cpp + tcp_socket_impl.cpp +) + +target_link_libraries(neh-asio + PUBLIC + yutil + contrib-libs-libc_compat + coroutine-engine + dns + PRIVATE + enum_serialization_runtime +) diff --git a/library/cpp/neh/asio/asio.cpp b/library/cpp/neh/asio/asio.cpp new file mode 100644 index 0000000000..e10f62f575 --- /dev/null +++ b/library/cpp/neh/asio/asio.cpp @@ -0,0 +1,191 @@ +#include "io_service_impl.h" +#include "deadline_timer_impl.h" +#include "tcp_socket_impl.h" +#include "tcp_acceptor_impl.h" + +using namespace NDns; +using namespace NAsio; + +namespace NAsio { + TIOService::TWork::TWork(TWork& w) + : Srv_(w.Srv_) + { + Srv_.GetImpl().WorkStarted(); + } + + TIOService::TWork::TWork(TIOService& srv) + : Srv_(srv) + { + Srv_.GetImpl().WorkStarted(); + } + + TIOService::TWork::~TWork() { + Srv_.GetImpl().WorkFinished(); + } + + TIOService::TIOService() + : Impl_(new TImpl()) + { + } + + TIOService::~TIOService() { + } + + void TIOService::Run() { + Impl_->Run(); + } + + size_t TIOService::GetOpQueueSize() noexcept { + return Impl_->GetOpQueueSize(); + } + + void TIOService::Post(TCompletionHandler h) { + Impl_->Post(std::move(h)); + } + + void TIOService::Abort() { + Impl_->Abort(); + } + + TDeadlineTimer::TDeadlineTimer(TIOService& srv) noexcept + : Srv_(srv) + , Impl_(nullptr) + { + } + + TDeadlineTimer::~TDeadlineTimer() { + if (Impl_) { + Srv_.GetImpl().ScheduleOp(new TUnregisterTimerOperation(Impl_)); + } + } + + void TDeadlineTimer::AsyncWaitExpireAt(TDeadline deadline, THandler h) { + if (!Impl_) { + Impl_ = new TDeadlineTimer::TImpl(Srv_.GetImpl()); + Srv_.GetImpl().ScheduleOp(new TRegisterTimerOperation(Impl_)); + } + Impl_->AsyncWaitExpireAt(deadline, h); + } + + void TDeadlineTimer::Cancel() { + Impl_->Cancel(); + } + + TTcpSocket::TTcpSocket(TIOService& srv) noexcept + : Srv_(srv) + , Impl_(new TImpl(srv.GetImpl())) + { + } + + TTcpSocket::~TTcpSocket() { + } + + void TTcpSocket::AsyncConnect(const TEndpoint& ep, TTcpSocket::TConnectHandler h, TDeadline deadline) { + Impl_->AsyncConnect(ep, h, deadline); + } + + void TTcpSocket::AsyncWrite(TSendedData& d, TTcpSocket::TWriteHandler h, TDeadline deadline) { + Impl_->AsyncWrite(d, h, deadline); + } + + void TTcpSocket::AsyncWrite(TContIOVector* vec, TWriteHandler h, TDeadline deadline) { + Impl_->AsyncWrite(vec, h, deadline); + } + + void TTcpSocket::AsyncWrite(const void* data, size_t size, TWriteHandler h, TDeadline deadline) { + class TBuffers: public IBuffers { + public: + TBuffers(const void* theData, size_t theSize) + : Part(theData, theSize) + , IOVec(&Part, 1) + { + } + + TContIOVector* GetIOvec() override { + return &IOVec; + } + + IOutputStream::TPart Part; + TContIOVector IOVec; + }; + + TSendedData d(new TBuffers(data, size)); + Impl_->AsyncWrite(d, h, deadline); + } + + void TTcpSocket::AsyncRead(void* buff, size_t size, TTcpSocket::TReadHandler h, TDeadline deadline) { + Impl_->AsyncRead(buff, size, h, deadline); + } + + void TTcpSocket::AsyncReadSome(void* buff, size_t size, TTcpSocket::TReadHandler h, TDeadline deadline) { + Impl_->AsyncReadSome(buff, size, h, deadline); + } + + void TTcpSocket::AsyncPollRead(TTcpSocket::TPollHandler h, TDeadline deadline) { + Impl_->AsyncPollRead(h, deadline); + } + + void TTcpSocket::AsyncPollWrite(TTcpSocket::TPollHandler h, TDeadline deadline) { + Impl_->AsyncPollWrite(h, deadline); + } + + void TTcpSocket::AsyncCancel() { + return Impl_->AsyncCancel(); + } + + size_t TTcpSocket::WriteSome(TContIOVector& d, TErrorCode& ec) noexcept { + return Impl_->WriteSome(d, ec); + } + + size_t TTcpSocket::WriteSome(const void* buff, size_t size, TErrorCode& ec) noexcept { + return Impl_->WriteSome(buff, size, ec); + } + + size_t TTcpSocket::ReadSome(void* buff, size_t size, TErrorCode& ec) noexcept { + return Impl_->ReadSome(buff, size, ec); + } + + bool TTcpSocket::IsOpen() const noexcept { + return Native() != INVALID_SOCKET; + } + + void TTcpSocket::Shutdown(TShutdownMode what, TErrorCode& ec) { + return Impl_->Shutdown(what, ec); + } + + SOCKET TTcpSocket::Native() const noexcept { + return Impl_->Fd(); + } + + TEndpoint TTcpSocket::RemoteEndpoint() const { + return Impl_->RemoteEndpoint(); + } + + ////////////////////////////////// + + TTcpAcceptor::TTcpAcceptor(TIOService& srv) noexcept + : Srv_(srv) + , Impl_(new TImpl(srv.GetImpl())) + { + } + + TTcpAcceptor::~TTcpAcceptor() { + } + + void TTcpAcceptor::Bind(TEndpoint& ep, TErrorCode& ec) noexcept { + return Impl_->Bind(ep, ec); + } + + void TTcpAcceptor::Listen(int backlog, TErrorCode& ec) noexcept { + return Impl_->Listen(backlog, ec); + } + + void TTcpAcceptor::AsyncAccept(TTcpSocket& s, TTcpAcceptor::TAcceptHandler h, TDeadline deadline) { + return Impl_->AsyncAccept(s, h, deadline); + } + + void TTcpAcceptor::AsyncCancel() { + Impl_->AsyncCancel(); + } + +} diff --git a/library/cpp/neh/asio/asio.h b/library/cpp/neh/asio/asio.h new file mode 100644 index 0000000000..87c3c6d525 --- /dev/null +++ b/library/cpp/neh/asio/asio.h @@ -0,0 +1,283 @@ +#pragma once + +// +//primary header for work with asio +// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +//#define DEBUG_ASIO + +class TContIOVector; + +namespace NAsio { + class TErrorCode { + public: + inline TErrorCode(int val = 0) noexcept + : Val_(val) + { + } + + typedef void (*TUnspecifiedBoolType)(); + + static void UnspecifiedBoolTrue() { + } + + //safe cast to bool value + operator TUnspecifiedBoolType() const noexcept { // true if error + return Val_ == 0 ? nullptr : UnspecifiedBoolTrue; + } + + bool operator!() const noexcept { + return Val_ == 0; + } + + void Assign(int val) noexcept { + Val_ = val; + } + + int Value() const noexcept { + return Val_; + } + + TString Text() const { + if (!Val_) { + return TString(); + } + return LastSystemErrorText(Val_); + } + + void Check() { + if (Val_) { + throw TSystemError(Val_); + } + } + + private: + int Val_; + }; + + //wrapper for TInstant, for enabling use TDuration (+TInstant::Now()) as deadline + class TDeadline: public TInstant { + public: + TDeadline() + : TInstant(TInstant::Max()) + { + } + + TDeadline(const TInstant& t) + : TInstant(t) + { + } + + TDeadline(const TDuration& d) + : TInstant(TInstant::Now() + d) + { + } + }; + + class IHandlingContext { + public: + virtual ~IHandlingContext() { + } + + //if handler throw exception, call this function be ignored + virtual void ContinueUseHandler(TDeadline deadline = TDeadline()) = 0; + }; + + typedef std::function TCompletionHandler; + + class TIOService: public TNonCopyable { + public: + TIOService(); + ~TIOService(); + + void Run(); + void Post(TCompletionHandler); //call handler in Run() thread-executor + void Abort(); //in Run() all exist async i/o operations + timers receive error = ECANCELED, Run() exited + + // not const since internal queue is lockfree and needs to increment and decrement its reference counters + size_t GetOpQueueSize() noexcept; + + //counterpart boost::asio::io_service::work + class TWork { + public: + TWork(TWork&); + TWork(TIOService&); + ~TWork(); + + private: + void operator=(const TWork&); //disable + + TIOService& Srv_; + }; + + class TImpl; + + TImpl& GetImpl() noexcept { + return *Impl_; + } + + private: + THolder Impl_; + }; + + class TDeadlineTimer: public TNonCopyable { + public: + typedef std::function THandler; + + TDeadlineTimer(TIOService&) noexcept; + ~TDeadlineTimer(); + + void AsyncWaitExpireAt(TDeadline, THandler); + void Cancel(); + + TIOService& GetIOService() const noexcept { + return Srv_; + } + + class TImpl; + + private: + TIOService& Srv_; + TImpl* Impl_; + }; + + class TTcpSocket: public TNonCopyable { + public: + class IBuffers { + public: + virtual ~IBuffers() { + } + virtual TContIOVector* GetIOvec() = 0; + }; + typedef TAutoPtr TSendedData; + + typedef std::function THandler; + typedef THandler TConnectHandler; + typedef std::function TWriteHandler; + typedef std::function TReadHandler; + typedef THandler TPollHandler; + + enum TShutdownMode { + ShutdownReceive = SHUT_RD, + ShutdownSend = SHUT_WR, + ShutdownBoth = SHUT_RDWR + }; + + TTcpSocket(TIOService&) noexcept; + ~TTcpSocket(); + + void AsyncConnect(const TEndpoint& ep, TConnectHandler, TDeadline deadline = TDeadline()); + void AsyncWrite(TSendedData&, TWriteHandler, TDeadline deadline = TDeadline()); + void AsyncWrite(TContIOVector* buff, TWriteHandler, TDeadline deadline = TDeadline()); + void AsyncWrite(const void* buff, size_t size, TWriteHandler, TDeadline deadline = TDeadline()); + void AsyncRead(void* buff, size_t size, TReadHandler, TDeadline deadline = TDeadline()); + void AsyncReadSome(void* buff, size_t size, TReadHandler, TDeadline deadline = TDeadline()); + void AsyncPollWrite(TPollHandler, TDeadline deadline = TDeadline()); + void AsyncPollRead(TPollHandler, TDeadline deadline = TDeadline()); + void AsyncCancel(); + + //sync, but non blocked methods + size_t WriteSome(TContIOVector&, TErrorCode&) noexcept; + size_t WriteSome(const void* buff, size_t size, TErrorCode&) noexcept; + size_t ReadSome(void* buff, size_t size, TErrorCode&) noexcept; + + bool IsOpen() const noexcept; + void Shutdown(TShutdownMode mode, TErrorCode& ec); + + TIOService& GetIOService() const noexcept { + return Srv_; + } + + SOCKET Native() const noexcept; + + TEndpoint RemoteEndpoint() const; + + inline size_t WriteSome(TContIOVector& v) { + TErrorCode ec; + size_t n = WriteSome(v, ec); + ec.Check(); + return n; + } + + inline size_t WriteSome(const void* buff, size_t size) { + TErrorCode ec; + size_t n = WriteSome(buff, size, ec); + ec.Check(); + return n; + } + + inline size_t ReadSome(void* buff, size_t size) { + TErrorCode ec; + size_t n = ReadSome(buff, size, ec); + ec.Check(); + return n; + } + + void Shutdown(TShutdownMode mode) { + TErrorCode ec; + Shutdown(mode, ec); + ec.Check(); + } + + class TImpl; + + TImpl& GetImpl() const noexcept { + return *Impl_; + } + + private: + TIOService& Srv_; + TIntrusivePtr Impl_; + }; + + class TTcpAcceptor: public TNonCopyable { + public: + typedef std::function TAcceptHandler; + + TTcpAcceptor(TIOService&) noexcept; + ~TTcpAcceptor(); + + void Bind(TEndpoint&, TErrorCode&) noexcept; + void Listen(int backlog, TErrorCode&) noexcept; + + void AsyncAccept(TTcpSocket&, TAcceptHandler, TDeadline deadline = TDeadline()); + + void AsyncCancel(); + + inline void Bind(TEndpoint& ep) { + TErrorCode ec; + Bind(ep, ec); + ec.Check(); + } + inline void Listen(int backlog) { + TErrorCode ec; + Listen(backlog, ec); + ec.Check(); + } + + TIOService& GetIOService() const noexcept { + return Srv_; + } + + class TImpl; + + TImpl& GetImpl() const noexcept { + return *Impl_; + } + + private: + TIOService& Srv_; + TIntrusivePtr Impl_; + }; +} diff --git a/library/cpp/neh/asio/deadline_timer_impl.cpp b/library/cpp/neh/asio/deadline_timer_impl.cpp new file mode 100644 index 0000000000..399a4338fb --- /dev/null +++ b/library/cpp/neh/asio/deadline_timer_impl.cpp @@ -0,0 +1 @@ +#include "deadline_timer_impl.h" diff --git a/library/cpp/neh/asio/deadline_timer_impl.h b/library/cpp/neh/asio/deadline_timer_impl.h new file mode 100644 index 0000000000..d9db625c94 --- /dev/null +++ b/library/cpp/neh/asio/deadline_timer_impl.h @@ -0,0 +1,110 @@ +#pragma once + +#include "io_service_impl.h" + +namespace NAsio { + class TTimerOperation: public TOperation { + public: + TTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline) + : TOperation(deadline) + , T_(t) + { + } + + void AddOp(TIOService::TImpl&) override { + Y_ASSERT(0); + } + + void Finalize() override { + DBGOUT("TTimerDeadlineOperation::Finalize()"); + T_->DelOp(this); + } + + protected: + TIOService::TImpl::TTimer* T_; + }; + + class TRegisterTimerOperation: public TTimerOperation { + public: + TRegisterTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline = TInstant::Max()) + : TTimerOperation(t, deadline) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + T_->GetIOServiceImpl().SyncRegisterTimer(T_); + return true; + } + }; + + class TTimerDeadlineOperation: public TTimerOperation { + public: + TTimerDeadlineOperation(TIOService::TImpl::TTimer* t, TDeadlineTimer::THandler h, TInstant deadline) + : TTimerOperation(t, deadline) + , H_(h) + { + } + + void AddOp(TIOService::TImpl&) override { + T_->AddOp(this); + } + + bool Execute(int errorCode) override { + DBGOUT("TTimerDeadlineOperation::Execute(" << errorCode << ")"); + H_(errorCode == ETIMEDOUT ? 0 : errorCode, *this); + return true; + } + + private: + TDeadlineTimer::THandler H_; + }; + + class TCancelTimerOperation: public TTimerOperation { + public: + TCancelTimerOperation(TIOService::TImpl::TTimer* t) + : TTimerOperation(t, TInstant::Max()) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + T_->FailOperations(ECANCELED); + return true; + } + }; + + class TUnregisterTimerOperation: public TTimerOperation { + public: + TUnregisterTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline = TInstant::Max()) + : TTimerOperation(t, deadline) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + DBGOUT("TUnregisterTimerOperation::Execute(" << errorCode << ")"); + T_->GetIOServiceImpl().SyncUnregisterAndDestroyTimer(T_); + return true; + } + }; + + class TDeadlineTimer::TImpl: public TIOService::TImpl::TTimer { + public: + TImpl(TIOService::TImpl& srv) + : TIOService::TImpl::TTimer(srv) + { + } + + void AsyncWaitExpireAt(TDeadline d, TDeadlineTimer::THandler h) { + Srv_.ScheduleOp(new TTimerDeadlineOperation(this, h, d)); + } + + void Cancel() { + Srv_.ScheduleOp(new TCancelTimerOperation(this)); + } + }; +} diff --git a/library/cpp/neh/asio/executor.cpp b/library/cpp/neh/asio/executor.cpp new file mode 100644 index 0000000000..03b26bf847 --- /dev/null +++ b/library/cpp/neh/asio/executor.cpp @@ -0,0 +1 @@ +#include "executor.h" diff --git a/library/cpp/neh/asio/executor.h b/library/cpp/neh/asio/executor.h new file mode 100644 index 0000000000..4f6549044d --- /dev/null +++ b/library/cpp/neh/asio/executor.h @@ -0,0 +1,76 @@ +#pragma once + +#include "asio.h" + +#include + +#include +#include + +namespace NAsio { + class TIOServiceExecutor: public IThreadFactory::IThreadAble { + public: + TIOServiceExecutor() + : Work_(new TIOService::TWork(Srv_)) + { + T_ = SystemThreadFactory()->Run(this); + } + + ~TIOServiceExecutor() override { + SyncShutdown(); + } + + void DoExecute() override { + TThread::SetCurrentThreadName("NehAsioExecutor"); + Srv_.Run(); + } + + inline TIOService& GetIOService() noexcept { + return Srv_; + } + + void SyncShutdown() { + if (Work_) { + Work_.Destroy(); + Srv_.Abort(); //cancel all async operations, break Run() execution + T_->Join(); + } + } + + private: + TIOService Srv_; + TAutoPtr Work_; + typedef TAutoPtr IThreadRef; + IThreadRef T_; + }; + + class TExecutorsPool { + public: + TExecutorsPool(size_t executors) + : C_(0) + { + for (size_t i = 0; i < executors; ++i) { + E_.push_back(new TIOServiceExecutor()); + } + } + + inline size_t Size() const noexcept { + return E_.size(); + } + + inline TIOServiceExecutor& GetExecutor() noexcept { + TAtomicBase next = AtomicIncrement(C_); + return *E_[next % E_.size()]; + } + + void SyncShutdown() { + for (size_t i = 0; i < E_.size(); ++i) { + E_[i]->SyncShutdown(); + } + } + + private: + TAtomic C_; + TVector> E_; + }; +} diff --git a/library/cpp/neh/asio/io_service_impl.cpp b/library/cpp/neh/asio/io_service_impl.cpp new file mode 100644 index 0000000000..d49b3fb03e --- /dev/null +++ b/library/cpp/neh/asio/io_service_impl.cpp @@ -0,0 +1,161 @@ +#include "io_service_impl.h" + +#include + +using namespace NAsio; + +void TFdOperation::AddOp(TIOService::TImpl& srv) { + srv.AddOp(this); +} + +void TFdOperation::Finalize() { + (*PH_)->DelOp(this); +} + +void TPollFdEventHandler::ExecuteOperations(TFdOperations& oprs, int errorCode) { + TFdOperations::iterator it = oprs.begin(); + + try { + while (it != oprs.end()) { + TFdOperation* op = it->Get(); + + if (op->Execute(errorCode)) { // throw ? + if (op->IsRequiredRepeat()) { + Srv_.UpdateOpDeadline(op); + ++it; //operation completed, but want be repeated + } else { + FinishedOperations_.push_back(*it); + it = oprs.erase(it); + } + } else { + ++it; //operation not completed + } + } + } catch (...) { + if (it != oprs.end()) { + FinishedOperations_.push_back(*it); + oprs.erase(it); + } + throw; + } +} + +void TPollFdEventHandler::DelOp(TFdOperation* op) { + TAutoPtr& evh = *op->PH_; + + if (op->IsPollRead()) { + Y_ASSERT(FinishOp(ReadOperations_, op)); + } else { + Y_ASSERT(FinishOp(WriteOperations_, op)); + } + Srv_.FixHandledEvents(evh); //alarm, - 'this' can be destroyed here! +} + +void TInterrupterHandler::OnFdEvent(int status, ui16 filter) { + if (!status && (filter & CONT_POLL_READ)) { + PI_.Reset(); + } +} + +void TIOService::TImpl::Run() { + TEvh& iEvh = Evh_.Get(I_.Fd()); + iEvh.Reset(new TInterrupterHandler(*this, I_)); + + TInterrupterKeeper ik(*this, iEvh); + Y_UNUSED(ik); + IPollerFace::TEvents evs; + AtomicSet(NeedCheckOpQueue_, 1); + TInstant deadline; + + while (Y_LIKELY(!Aborted_ && (AtomicGet(OutstandingWork_) || FdEventHandlersCnt_ > 1 || TimersOpCnt_ || AtomicGet(NeedCheckOpQueue_)))) { + //while + // expected work (external flag) + // or have event handlers (exclude interrupter) + // or have not completed timer operation + // or have any operation in queues + + AtomicIncrement(IsWaiting_); + if (!AtomicGet(NeedCheckOpQueue_)) { + P_->Wait(evs, deadline); + } + AtomicDecrement(IsWaiting_); + + if (evs.size()) { + for (IPollerFace::TEvents::const_iterator iev = evs.begin(); iev != evs.end() && !Aborted_; ++iev) { + const IPollerFace::TEvent& ev = *iev; + TEvh& evh = *(TEvh*)ev.Data; + + if (!evh) { + continue; //op. cancel (see ProcessOpQueue) can destroy evh + } + + int status = ev.Status; + if (ev.Status == EIO) { + int error = status; + if (GetSockOpt(evh->Fd(), SOL_SOCKET, SO_ERROR, error) == 0) { + status = error; + } + } + + OnFdEvent(evh, status, ev.Filter); //here handle fd events + //immediatly after handling events for one descriptor check op. queue + //often queue can contain another operation for this fd (next async read as sample) + //so we can optimize redundant epoll_ctl (or similar) calls + ProcessOpQueue(); + } + + evs.clear(); + } else { + ProcessOpQueue(); + } + + deadline = DeadlinesQueue_.NextDeadline(); //here handle timeouts/process timers + } +} + +void TIOService::TImpl::Abort() { + class TAbortOperation: public TNoneOperation { + public: + TAbortOperation(TIOService::TImpl& srv) + : TNoneOperation() + , Srv_(srv) + { + Speculative_ = true; + } + + private: + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + Srv_.ProcessAbort(); + return true; + } + + TIOService::TImpl& Srv_; + }; + AtomicSet(HasAbort_, 1); + ScheduleOp(new TAbortOperation(*this)); +} + +void TIOService::TImpl::ProcessAbort() { + Aborted_ = true; + + for (int fd = 0; fd <= MaxFd_; ++fd) { + TEvh& evh = Evh_.Get(fd); + if (!!evh && evh->Fd() != I_.Fd()) { + OnFdEvent(evh, ECANCELED, CONT_POLL_READ | CONT_POLL_WRITE); + } + } + + for (auto t : Timers_) { + t->FailOperations(ECANCELED); + } + + TOperationPtr op; + while (OpQueue_.Dequeue(&op)) { //cancel all enqueued operations + try { + op->Execute(ECANCELED); + } catch (...) { + } + op.Destroy(); + } +} diff --git a/library/cpp/neh/asio/io_service_impl.h b/library/cpp/neh/asio/io_service_impl.h new file mode 100644 index 0000000000..e1768df6d9 --- /dev/null +++ b/library/cpp/neh/asio/io_service_impl.h @@ -0,0 +1,762 @@ +#pragma once + +#include "asio.h" +#include "poll_interrupter.h" + +#include +#include + +#include + +#include +#include +#include +#include +#include + +#ifdef DEBUG_ASIO +#define DBGOUT(args) Cout << args << Endl; +#else +#define DBGOUT(args) +#endif + +namespace NAsio { +#if defined(_arm_) + template + struct TLockFreeSequence { + Y_NO_INLINE T& Get(size_t n) { + with_lock (M) { + return H[n]; + } + } + + TMutex M; + THashMap H; + }; +#else + //TODO: copypaste from neh, - need fix + template + class TLockFreeSequence { + public: + inline TLockFreeSequence() { + memset((void*)T_, 0, sizeof(T_)); + } + + inline ~TLockFreeSequence() { + for (size_t i = 0; i < Y_ARRAY_SIZE(T_); ++i) { + delete[] T_[i]; + } + } + + inline T& Get(size_t n) { + const size_t i = GetValueBitCount(n + 1) - 1; + + return GetList(i)[n + 1 - (((size_t)1) << i)]; + } + + private: + inline T* GetList(size_t n) { + T* volatile* t = T_ + n; + + while (!*t) { + TArrayHolder nt(new T[((size_t)1) << n]); + + if (AtomicCas(t, nt.Get(), nullptr)) { + return nt.Release(); + } + } + + return *t; + } + + private: + T* volatile T_[sizeof(size_t) * 8]; + }; +#endif + + struct TOperationCompare { + template + static inline bool Compare(const T& l, const T& r) noexcept { + return l.DeadLine() < r.DeadLine() || (l.DeadLine() == r.DeadLine() && &l < &r); + } + }; + + //async operation, execute in contex TIOService()::Run() thread-executor + //usualy used for call functors/callbacks + class TOperation: public TRbTreeItem, public IHandlingContext { + public: + TOperation(TInstant deadline = TInstant::Max()) + : D_(deadline) + , Speculative_(false) + , RequiredRepeatExecution_(false) + , ND_(deadline) + { + } + + //register this operation in svc.impl. + virtual void AddOp(TIOService::TImpl&) = 0; + + //return false, if operation not completed + virtual bool Execute(int errorCode = 0) = 0; + + void ContinueUseHandler(TDeadline deadline) override { + RequiredRepeatExecution_ = true; + ND_ = deadline; + } + + virtual void Finalize() = 0; + + inline TInstant Deadline() const noexcept { + return D_; + } + + inline TInstant DeadLine() const noexcept { + return D_; + } + + inline bool Speculative() const noexcept { + return Speculative_; + } + + inline bool IsRequiredRepeat() const noexcept { + return RequiredRepeatExecution_; + } + + inline void PrepareReExecution() noexcept { + RequiredRepeatExecution_ = false; + D_ = ND_; + } + + protected: + TInstant D_; + bool Speculative_; //if true, operation will be runned immediately after dequeue (even without wating any event) + //as sample used for optimisation writing, - obviously in buffers exist space for write + bool RequiredRepeatExecution_; //set to true, if required re-exec operation + TInstant ND_; //new deadline (for re-exec operation) + }; + + typedef TAutoPtr TOperationPtr; + + class TNoneOperation: public TOperation { + public: + TNoneOperation(TInstant deadline = TInstant::Max()) + : TOperation(deadline) + { + } + + void AddOp(TIOService::TImpl&) override { + Y_ASSERT(0); + } + + void Finalize() override { + } + }; + + class TPollFdEventHandler; + + //descriptor use operation + class TFdOperation: public TOperation { + public: + enum TPollType { + PollRead, + PollWrite + }; + + TFdOperation(SOCKET fd, TPollType pt, TInstant deadline = TInstant::Max()) + : TOperation(deadline) + , Fd_(fd) + , PT_(pt) + , PH_(nullptr) + { + Y_ASSERT(Fd() != INVALID_SOCKET); + } + + inline SOCKET Fd() const noexcept { + return Fd_; + } + + inline bool IsPollRead() const noexcept { + return PT_ == PollRead; + } + + void AddOp(TIOService::TImpl& srv) override; + + void Finalize() override; + + protected: + SOCKET Fd_; + TPollType PT_; + + public: + TAutoPtr* PH_; + }; + + typedef TAutoPtr TFdOperationPtr; + + class TPollFdEventHandler { + public: + TPollFdEventHandler(SOCKET fd, TIOService::TImpl& srv) + : Fd_(fd) + , HandledEvents_(0) + , Srv_(srv) + { + } + + virtual ~TPollFdEventHandler() { + Y_ASSERT(ReadOperations_.size() == 0); + Y_ASSERT(WriteOperations_.size() == 0); + } + + inline void AddReadOp(TFdOperationPtr op) { + ReadOperations_.push_back(op); + } + + inline void AddWriteOp(TFdOperationPtr op) { + WriteOperations_.push_back(op); + } + + virtual void OnFdEvent(int status, ui16 filter) { + DBGOUT("PollEvent(fd=" << Fd_ << ", " << status << ", " << filter << ")"); + if (status) { + ExecuteOperations(ReadOperations_, status); + ExecuteOperations(WriteOperations_, status); + } else { + if (filter & CONT_POLL_READ) { + ExecuteOperations(ReadOperations_, status); + } + if (filter & CONT_POLL_WRITE) { + ExecuteOperations(WriteOperations_, status); + } + } + } + + typedef TVector TFdOperations; + + void ExecuteOperations(TFdOperations& oprs, int errorCode); + + //return true if filter handled events changed and require re-configure events poller + virtual bool FixHandledEvents() noexcept { + DBGOUT("TPollFdEventHandler::FixHandledEvents()"); + ui16 filter = 0; + + if (WriteOperations_.size()) { + filter |= CONT_POLL_WRITE; + } + if (ReadOperations_.size()) { + filter |= CONT_POLL_READ; + } + + if (Y_LIKELY(HandledEvents_ == filter)) { + return false; + } + + HandledEvents_ = filter; + return true; + } + + inline bool FinishOp(TFdOperations& oprs, TFdOperation* op) noexcept { + for (TFdOperations::iterator it = oprs.begin(); it != oprs.end(); ++it) { + if (it->Get() == op) { + FinishedOperations_.push_back(*it); + oprs.erase(it); + return true; + } + } + return false; + } + + void DelOp(TFdOperation* op); + + inline SOCKET Fd() const noexcept { + return Fd_; + } + + inline ui16 HandledEvents() const noexcept { + return HandledEvents_; + } + + inline void AddHandlingEvent(ui16 ev) noexcept { + HandledEvents_ |= ev; + } + + inline void DestroyFinishedOperations() { + FinishedOperations_.clear(); + } + + TIOService::TImpl& GetServiceImpl() const noexcept { + return Srv_; + } + + protected: + SOCKET Fd_; + ui16 HandledEvents_; + TIOService::TImpl& Srv_; + + private: + TVector ReadOperations_; + TVector WriteOperations_; + // we can't immediatly destroy finished operations, this can cause closing used socket descriptor Fd_ + // (on cascade deletion operation object-handler), but later we use Fd_ for modify handled events at poller, + // so we collect here finished operations and destroy it only after update poller, - + // call FixHandledEvents(TPollFdEventHandlerPtr&) + TVector FinishedOperations_; + }; + + //additional descriptor for poller, used for interrupt current poll wait + class TInterrupterHandler: public TPollFdEventHandler { + public: + TInterrupterHandler(TIOService::TImpl& srv, TPollInterrupter& pi) + : TPollFdEventHandler(pi.Fd(), srv) + , PI_(pi) + { + HandledEvents_ = CONT_POLL_READ; + } + + ~TInterrupterHandler() override { + DBGOUT("~TInterrupterHandler"); + } + + void OnFdEvent(int status, ui16 filter) override; + + bool FixHandledEvents() noexcept override { + DBGOUT("TInterrupterHandler::FixHandledEvents()"); + return false; + } + + private: + TPollInterrupter& PI_; + }; + + namespace { + inline TAutoPtr CreatePoller() { + try { +#if defined(_linux_) + return IPollerFace::Construct(TStringBuf("epoll")); +#endif +#if defined(_freebsd_) || defined(_darwin_) + return IPollerFace::Construct(TStringBuf("kqueue")); +#endif + } catch (...) { + Cdbg << CurrentExceptionMessage() << Endl; + } + return IPollerFace::Default(); + } + } + + //some equivalent TContExecutor + class TIOService::TImpl: public TNonCopyable { + public: + typedef TAutoPtr TEvh; + typedef TLockFreeSequence TEventHandlers; + + class TTimer { + public: + typedef THashSet TOperations; + + TTimer(TIOService::TImpl& srv) + : Srv_(srv) + { + } + + virtual ~TTimer() { + FailOperations(ECANCELED); + } + + void AddOp(TOperation* op) { + THolder tmp(op); + Operations_.insert(op); + Y_UNUSED(tmp.Release()); + Srv_.RegisterOpDeadline(op); + Srv_.IncTimersOp(); + } + + void DelOp(TOperation* op) { + TOperations::iterator it = Operations_.find(op); + if (it != Operations_.end()) { + Srv_.DecTimersOp(); + delete op; + Operations_.erase(it); + } + } + + inline void FailOperations(int ec) { + for (auto operation : Operations_) { + try { + operation->Execute(ec); //throw ? + } catch (...) { + } + Srv_.DecTimersOp(); + delete operation; + } + Operations_.clear(); + } + + TIOService::TImpl& GetIOServiceImpl() const noexcept { + return Srv_; + } + + protected: + TIOService::TImpl& Srv_; + THashSet Operations_; + }; + + class TTimers: public THashSet { + public: + ~TTimers() { + for (auto it : *this) { + delete it; + } + } + }; + + TImpl() + : P_(CreatePoller()) + , DeadlinesQueue_(*this) + { + } + + ~TImpl() { + TOperationPtr op; + + while (OpQueue_.Dequeue(&op)) { //cancel all enqueued operations + try { + op->Execute(ECANCELED); + } catch (...) { + } + op.Destroy(); + } + } + + //similar TContExecutor::Execute() or io_service::run() + //process event loop (exit if none to do (no timers or event handlers)) + void Run(); + + //enqueue functor fo call in Run() eventloop (thread safing) + inline void Post(TCompletionHandler h) { + class TFuncOperation: public TNoneOperation { + public: + TFuncOperation(TCompletionHandler completionHandler) + : TNoneOperation() + , H_(std::move(completionHandler)) + { + Speculative_ = true; + } + + private: + //return false, if operation not completed + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + H_(); + return true; + } + + TCompletionHandler H_; + }; + + ScheduleOp(new TFuncOperation(std::move(h))); + } + + //cancel all current operations (handlers be called with errorCode == ECANCELED) + void Abort(); + bool HasAbort() { + return AtomicGet(HasAbort_); + } + + inline void ScheduleOp(TOperationPtr op) { //throw std::bad_alloc + Y_ASSERT(!Aborted_); + Y_ASSERT(!!op); + OpQueue_.Enqueue(op); + Interrupt(); + } + + inline void Interrupt() noexcept { + AtomicSet(NeedCheckOpQueue_, 1); + if (AtomicAdd(IsWaiting_, 0) == 1) { + I_.Interrupt(); + } + } + + inline void UpdateOpDeadline(TOperation* op) { + TInstant oldDeadline = op->Deadline(); + op->PrepareReExecution(); + + if (oldDeadline == op->Deadline()) { + return; + } + + if (oldDeadline != TInstant::Max()) { + op->UnLink(); + } + if (op->Deadline() != TInstant::Max()) { + DeadlinesQueue_.Register(op); + } + } + + inline size_t GetOpQueueSize() noexcept { + return OpQueue_.Size(); + } + + void SyncRegisterTimer(TTimer* t) { + Timers_.insert(t); + } + + inline void SyncUnregisterAndDestroyTimer(TTimer* t) { + Timers_.erase(t); + delete t; + } + + inline void IncTimersOp() noexcept { + ++TimersOpCnt_; + } + + inline void DecTimersOp() noexcept { + --TimersOpCnt_; + } + + inline void WorkStarted() { + AtomicIncrement(OutstandingWork_); + } + + inline void WorkFinished() { + if (AtomicDecrement(OutstandingWork_) == 0) { + Interrupt(); + } + } + + private: + void ProcessAbort(); + + inline TEvh& EnsureGetEvh(SOCKET fd) { + TEvh& evh = Evh_.Get(fd); + if (!evh) { + evh.Reset(new TPollFdEventHandler(fd, *this)); + } + return evh; + } + + inline void OnTimeoutOp(TOperation* op) { + DBGOUT("OnTimeoutOp"); + try { + op->Execute(ETIMEDOUT); //throw ? + } catch (...) { + op->Finalize(); + throw; + } + + if (op->IsRequiredRepeat()) { + //operation not completed + UpdateOpDeadline(op); + } else { + //destroy operation structure + op->Finalize(); + } + } + + public: + inline void FixHandledEvents(TEvh& evh) { + if (!!evh) { + if (evh->FixHandledEvents()) { + if (!evh->HandledEvents()) { + DelEventHandler(evh); + evh.Destroy(); + } else { + ModEventHandler(evh); + evh->DestroyFinishedOperations(); + } + } else { + evh->DestroyFinishedOperations(); + } + } + } + + private: + inline TEvh& GetHandlerForOp(TFdOperation* op) { + TEvh& evh = EnsureGetEvh(op->Fd()); + op->PH_ = &evh; + return evh; + } + + void ProcessOpQueue() { + if (!AtomicGet(NeedCheckOpQueue_)) { + return; + } + AtomicSet(NeedCheckOpQueue_, 0); + + TOperationPtr op; + + while (OpQueue_.Dequeue(&op)) { + if (op->Speculative()) { + if (op->Execute(Y_UNLIKELY(Aborted_) ? ECANCELED : 0)) { + op.Destroy(); + continue; //operation completed + } + + if (!op->IsRequiredRepeat()) { + op->PrepareReExecution(); + } + } + RegisterOpDeadline(op.Get()); + op.Get()->AddOp(*this); // ... -> AddOp() + Y_UNUSED(op.Release()); + } + } + + inline void RegisterOpDeadline(TOperation* op) { + if (op->DeadLine() != TInstant::Max()) { + DeadlinesQueue_.Register(op); + } + } + + public: + inline void AddOp(TFdOperation* op) { + DBGOUT("AddOp(" << op->Fd() << ")"); + TEvh& evh = GetHandlerForOp(op); + if (op->IsPollRead()) { + evh->AddReadOp(op); + EnsureEventHandled(evh, CONT_POLL_READ); + } else { + evh->AddWriteOp(op); + EnsureEventHandled(evh, CONT_POLL_WRITE); + } + } + + private: + inline void EnsureEventHandled(TEvh& evh, ui16 ev) { + if (!evh->HandledEvents()) { + evh->AddHandlingEvent(ev); + AddEventHandler(evh); + } else { + if ((evh->HandledEvents() & ev) == 0) { + evh->AddHandlingEvent(ev); + ModEventHandler(evh); + } + } + } + + public: + //cancel all current operations for socket + //method MUST be called from Run() thread-executor + void CancelFdOp(SOCKET fd) { + TEvh& evh = Evh_.Get(fd); + if (!evh) { + return; + } + + OnFdEvent(evh, ECANCELED, CONT_POLL_READ | CONT_POLL_WRITE); + } + + private: + //helper for fixing handled events even in case exception + struct TExceptionProofFixerHandledEvents { + TExceptionProofFixerHandledEvents(TIOService::TImpl& srv, TEvh& iEvh) + : Srv_(srv) + , Evh_(iEvh) + { + } + + ~TExceptionProofFixerHandledEvents() { + Srv_.FixHandledEvents(Evh_); + } + + TIOService::TImpl& Srv_; + TEvh& Evh_; + }; + + inline void OnFdEvent(TEvh& evh, int status, ui16 filter) { + TExceptionProofFixerHandledEvents fixer(*this, evh); + Y_UNUSED(fixer); + evh->OnFdEvent(status, filter); + } + + inline void AddEventHandler(TEvh& evh) { + if (evh->Fd() > MaxFd_) { + MaxFd_ = evh->Fd(); + } + SetEventHandler(&evh, evh->Fd(), evh->HandledEvents()); + ++FdEventHandlersCnt_; + } + + inline void ModEventHandler(TEvh& evh) { + SetEventHandler(&evh, evh->Fd(), evh->HandledEvents()); + } + + inline void DelEventHandler(TEvh& evh) { + SetEventHandler(&evh, evh->Fd(), 0); + --FdEventHandlersCnt_; + } + + inline void SetEventHandler(void* h, int fd, ui16 flags) { + DBGOUT("SetEventHandler(" << fd << ", " << flags << ")"); + P_->Set(h, fd, flags); + } + + //exception safe call DelEventHandler + struct TInterrupterKeeper { + TInterrupterKeeper(TImpl& srv, TEvh& iEvh) + : Srv_(srv) + , Evh_(iEvh) + { + Srv_.AddEventHandler(Evh_); + } + + ~TInterrupterKeeper() { + Srv_.DelEventHandler(Evh_); + } + + TImpl& Srv_; + TEvh& Evh_; + }; + + TAutoPtr P_; + TPollInterrupter I_; + TAtomic IsWaiting_ = 0; + TAtomic NeedCheckOpQueue_ = 0; + TAtomic OutstandingWork_ = 0; + + NNeh::TAutoLockFreeQueue OpQueue_; + + TEventHandlers Evh_; //i/o event handlers + TTimers Timers_; //timeout event handlers + + size_t FdEventHandlersCnt_ = 0; //i/o event handlers counter + size_t TimersOpCnt_ = 0; //timers op counter + SOCKET MaxFd_ = 0; //max used descriptor num + TAtomic HasAbort_ = 0; + bool Aborted_ = false; + + class TDeadlinesQueue { + public: + TDeadlinesQueue(TIOService::TImpl& srv) + : Srv_(srv) + { + } + + inline void Register(TOperation* op) { + Deadlines_.Insert(op); + } + + TInstant NextDeadline() { + TDeadlines::TIterator it = Deadlines_.Begin(); + + while (it != Deadlines_.End()) { + if (it->DeadLine() > TInstant::Now()) { + DBGOUT("TDeadlinesQueue::NewDeadline:" << (it->DeadLine().GetValue() - TInstant::Now().GetValue())); + return it->DeadLine(); + } + + TOperation* op = &*(it++); + Srv_.OnTimeoutOp(op); + } + + return Deadlines_.Empty() ? TInstant::Max() : Deadlines_.Begin()->DeadLine(); + } + + private: + typedef TRbTree TDeadlines; + TDeadlines Deadlines_; + TIOService::TImpl& Srv_; + }; + + TDeadlinesQueue DeadlinesQueue_; + }; +} diff --git a/library/cpp/neh/asio/poll_interrupter.cpp b/library/cpp/neh/asio/poll_interrupter.cpp new file mode 100644 index 0000000000..c96d40c4f3 --- /dev/null +++ b/library/cpp/neh/asio/poll_interrupter.cpp @@ -0,0 +1 @@ +#include "poll_interrupter.h" diff --git a/library/cpp/neh/asio/poll_interrupter.h b/library/cpp/neh/asio/poll_interrupter.h new file mode 100644 index 0000000000..faf815c512 --- /dev/null +++ b/library/cpp/neh/asio/poll_interrupter.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef _linux_ +#include +#endif + +#if defined(_bionic_) && !defined(EFD_SEMAPHORE) +#define EFD_SEMAPHORE 1 +#endif + +namespace NAsio { +#ifdef _linux_ + class TEventFdPollInterrupter { + public: + inline TEventFdPollInterrupter() { + F_ = eventfd(0, EFD_NONBLOCK | EFD_SEMAPHORE); + if (F_ < 0) { + ythrow TFileError() << "failed to create a eventfd"; + } + } + + inline ~TEventFdPollInterrupter() { + close(F_); + } + + inline void Interrupt() const noexcept { + const static eventfd_t ev(1); + ssize_t res = ::write(F_, &ev, sizeof ev); + Y_UNUSED(res); + } + + inline bool Reset() const noexcept { + eventfd_t ev(0); + + for (;;) { + ssize_t res = ::read(F_, &ev, sizeof ev); + if (res && res == EINTR) { + continue; + } + + return res > 0; + } + } + + int Fd() { + return F_; + } + + private: + int F_; + }; +#endif + + class TPipePollInterrupter { + public: + TPipePollInterrupter() { + TPipeHandle::Pipe(S_[0], S_[1]); + + SetNonBlock(S_[0]); + SetNonBlock(S_[1]); + } + + inline void Interrupt() const noexcept { + char byte = 0; + ssize_t res = S_[1].Write(&byte, 1); + Y_UNUSED(res); + } + + inline bool Reset() const noexcept { + char buff[256]; + + for (;;) { + ssize_t r = S_[0].Read(buff, sizeof buff); + + if (r < 0 && r == EINTR) { + continue; + } + + bool wasInterrupted = r > 0; + + while (r == sizeof buff) { + r = S_[0].Read(buff, sizeof buff); + } + + return wasInterrupted; + } + } + + PIPEHANDLE Fd() const noexcept { + return S_[0]; + } + + private: + TPipeHandle S_[2]; + }; + +#ifdef _linux_ + typedef TEventFdPollInterrupter TPollInterrupter; //more effective than pipe, but only linux impl. +#else + typedef TPipePollInterrupter TPollInterrupter; +#endif +} diff --git a/library/cpp/neh/asio/tcp_acceptor_impl.cpp b/library/cpp/neh/asio/tcp_acceptor_impl.cpp new file mode 100644 index 0000000000..7e1d75fcf5 --- /dev/null +++ b/library/cpp/neh/asio/tcp_acceptor_impl.cpp @@ -0,0 +1,25 @@ +#include "tcp_acceptor_impl.h" + +using namespace NAsio; + +bool TOperationAccept::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, *this); + + return true; + } + + struct sockaddr_storage addr; + socklen_t sz = sizeof(addr); + + SOCKET res = ::accept(Fd(), (sockaddr*)&addr, &sz); + + if (res == INVALID_SOCKET) { + H_(LastSystemError(), *this); + } else { + NS_.Assign(res, TEndpoint(new NAddr::TOpaqueAddr((sockaddr*)&addr))); + H_(0, *this); + } + + return true; +} diff --git a/library/cpp/neh/asio/tcp_acceptor_impl.h b/library/cpp/neh/asio/tcp_acceptor_impl.h new file mode 100644 index 0000000000..c990236efc --- /dev/null +++ b/library/cpp/neh/asio/tcp_acceptor_impl.h @@ -0,0 +1,76 @@ +#pragma once + +#include "asio.h" + +#include "tcp_socket_impl.h" + +namespace NAsio { + class TOperationAccept: public TFdOperation { + public: + TOperationAccept(SOCKET fd, TTcpSocket::TImpl& newSocket, TTcpAcceptor::TAcceptHandler h, TInstant deadline) + : TFdOperation(fd, PollRead, deadline) + , H_(h) + , NS_(newSocket) + { + } + + bool Execute(int errorCode) override; + + TTcpAcceptor::TAcceptHandler H_; + TTcpSocket::TImpl& NS_; + }; + + class TTcpAcceptor::TImpl: public TThrRefBase { + public: + TImpl(TIOService::TImpl& srv) noexcept + : Srv_(srv) + { + } + + inline void Bind(TEndpoint& ep, TErrorCode& ec) noexcept { + TSocketHolder s(socket(ep.SockAddr()->sa_family, SOCK_STREAM, 0)); + + if (s == INVALID_SOCKET) { + ec.Assign(LastSystemError()); + } + + FixIPv6ListenSocket(s); + CheckedSetSockOpt(s, SOL_SOCKET, SO_REUSEADDR, 1, "reuse addr"); + SetNonBlock(s); + + if (::bind(s, ep.SockAddr(), ep.SockAddrLen())) { + ec.Assign(LastSystemError()); + return; + } + + S_.Swap(s); + } + + inline void Listen(int backlog, TErrorCode& ec) noexcept { + if (::listen(S_, backlog)) { + ec.Assign(LastSystemError()); + return; + } + } + + inline void AsyncAccept(TTcpSocket& s, TTcpAcceptor::TAcceptHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationAccept((SOCKET)S_, s.GetImpl(), h, deadline)); //set callback + } + + inline void AsyncCancel() { + Srv_.ScheduleOp(new TOperationCancel(this)); + } + + inline TIOService::TImpl& GetIOServiceImpl() const noexcept { + return Srv_; + } + + inline SOCKET Fd() const noexcept { + return S_; + } + + private: + TIOService::TImpl& Srv_; + TSocketHolder S_; + }; +} diff --git a/library/cpp/neh/asio/tcp_socket_impl.cpp b/library/cpp/neh/asio/tcp_socket_impl.cpp new file mode 100644 index 0000000000..98cef97561 --- /dev/null +++ b/library/cpp/neh/asio/tcp_socket_impl.cpp @@ -0,0 +1,117 @@ +#include "tcp_socket_impl.h" + +using namespace NAsio; + +TSocketOperation::TSocketOperation(TTcpSocket::TImpl& s, TPollType pt, TInstant deadline) + : TFdOperation(s.Fd(), pt, deadline) + , S_(s) +{ +} + +bool TOperationWrite::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, Written_, *this); + + return true; //op. completed + } + + TErrorCode ec; + TContIOVector& iov = *Buffs_->GetIOvec(); + + size_t n = S_.WriteSome(iov, ec); + + if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) { + H_(ec, Written_ + n, *this); + + return true; + } + + if (n) { + Written_ += n; + iov.Proceed(n); + if (!iov.Bytes()) { + H_(ec, Written_, *this); + + return true; //op. completed + } + } + + return false; //operation not compleled +} + +bool TOperationWriteVector::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, Written_, *this); + + return true; //op. completed + } + + TErrorCode ec; + + size_t n = S_.WriteSome(V_, ec); + + if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) { + H_(ec, Written_ + n, *this); + + return true; + } + + if (n) { + Written_ += n; + V_.Proceed(n); + if (!V_.Bytes()) { + H_(ec, Written_, *this); + + return true; //op. completed + } + } + + return false; //operation not compleled +} + +bool TOperationReadSome::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, 0, *this); + + return true; //op. completed + } + + TErrorCode ec; + + H_(ec, S_.ReadSome(Buff_, Size_, ec), *this); + + return true; +} + +bool TOperationRead::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, Read_, *this); + + return true; //op. completed + } + + TErrorCode ec; + size_t n = S_.ReadSome(Buff_, Size_, ec); + Read_ += n; + + if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) { + H_(ec, Read_, *this); + + return true; //op. completed + } + + if (n) { + Size_ -= n; + if (!Size_) { + H_(ec, Read_, *this); + + return true; + } + Buff_ += n; + } else if (!ec) { // EOF while read not all + H_(ec, Read_, *this); + return true; + } + + return false; +} diff --git a/library/cpp/neh/asio/tcp_socket_impl.h b/library/cpp/neh/asio/tcp_socket_impl.h new file mode 100644 index 0000000000..44f8f42d87 --- /dev/null +++ b/library/cpp/neh/asio/tcp_socket_impl.h @@ -0,0 +1,332 @@ +#pragma once + +#include "asio.h" +#include "io_service_impl.h" + +#include + +#if defined(_bionic_) +# define IOV_MAX 1024 +#endif + +namespace NAsio { + // ownership/keep-alive references: + // Handlers <- TOperation...(TFdOperation) <- TPollFdEventHandler <- TIOService + + class TSocketOperation: public TFdOperation { + public: + TSocketOperation(TTcpSocket::TImpl& s, TPollType pt, TInstant deadline); + + protected: + TTcpSocket::TImpl& S_; + }; + + class TOperationConnect: public TSocketOperation { + public: + TOperationConnect(TTcpSocket::TImpl& s, TTcpSocket::TConnectHandler h, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + { + } + + bool Execute(int errorCode) override { + H_(errorCode, *this); + + return true; + } + + TTcpSocket::TConnectHandler H_; + }; + + class TOperationConnectFailed: public TSocketOperation { + public: + TOperationConnectFailed(TTcpSocket::TImpl& s, TTcpSocket::TConnectHandler h, int errorCode, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + , ErrorCode_(errorCode) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + H_(ErrorCode_, *this); + + return true; + } + + TTcpSocket::TConnectHandler H_; + int ErrorCode_; + }; + + class TOperationWrite: public TSocketOperation { + public: + TOperationWrite(TTcpSocket::TImpl& s, NAsio::TTcpSocket::TSendedData& buffs, TTcpSocket::TWriteHandler h, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + , Buffs_(buffs) + , Written_(0) + { + Speculative_ = true; + } + + //return true, if not need write more data + bool Execute(int errorCode) override; + + private: + TTcpSocket::TWriteHandler H_; + NAsio::TTcpSocket::TSendedData Buffs_; + size_t Written_; + }; + + class TOperationWriteVector: public TSocketOperation { + public: + TOperationWriteVector(TTcpSocket::TImpl& s, TContIOVector* v, TTcpSocket::TWriteHandler h, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + , V_(*v) + , Written_(0) + { + Speculative_ = true; + } + + //return true, if not need write more data + bool Execute(int errorCode) override; + + private: + TTcpSocket::TWriteHandler H_; + TContIOVector& V_; + size_t Written_; + }; + + class TOperationReadSome: public TSocketOperation { + public: + TOperationReadSome(TTcpSocket::TImpl& s, void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) + : TSocketOperation(s, PollRead, deadline) + , H_(h) + , Buff_(static_cast(buff)) + , Size_(size) + { + } + + //return true, if not need read more data + bool Execute(int errorCode) override; + + protected: + TTcpSocket::TReadHandler H_; + char* Buff_; + size_t Size_; + }; + + class TOperationRead: public TOperationReadSome { + public: + TOperationRead(TTcpSocket::TImpl& s, void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) + : TOperationReadSome(s, buff, size, h, deadline) + , Read_(0) + { + } + + bool Execute(int errorCode) override; + + private: + size_t Read_; + }; + + class TOperationPoll: public TSocketOperation { + public: + TOperationPoll(TTcpSocket::TImpl& s, TPollType pt, TTcpSocket::TPollHandler h, TInstant deadline) + : TSocketOperation(s, pt, deadline) + , H_(h) + { + } + + bool Execute(int errorCode) override { + H_(errorCode, *this); + + return true; + } + + private: + TTcpSocket::TPollHandler H_; + }; + + template + class TOperationCancel: public TNoneOperation { + public: + TOperationCancel(T* s) + : TNoneOperation() + , S_(s) + { + Speculative_ = true; + } + + ~TOperationCancel() override { + } + + private: + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + if (!errorCode && S_->Fd() != INVALID_SOCKET) { + S_->GetIOServiceImpl().CancelFdOp(S_->Fd()); + } + return true; + } + + TIntrusivePtr S_; + }; + + class TTcpSocket::TImpl: public TNonCopyable, public TThrRefBase { + public: + typedef TTcpSocket::TSendedData TSendedData; + + TImpl(TIOService::TImpl& srv) noexcept + : Srv_(srv) + { + } + + ~TImpl() override { + DBGOUT("TSocket::~TImpl()"); + } + + void Assign(SOCKET fd, TEndpoint ep) { + TSocketHolder(fd).Swap(S_); + RemoteEndpoint_ = ep; + } + + void AsyncConnect(const TEndpoint& ep, TTcpSocket::TConnectHandler h, TInstant deadline) { + TSocketHolder s(socket(ep.SockAddr()->sa_family, SOCK_STREAM, 0)); + + if (Y_UNLIKELY(s == INVALID_SOCKET || Srv_.HasAbort())) { + throw TSystemError() << TStringBuf("can't create socket"); + } + + SetNonBlock(s); + + int err; + do { + err = connect(s, ep.SockAddr(), (int)ep.SockAddrLen()); + if (Y_LIKELY(err)) { + err = LastSystemError(); + } +#if defined(_freebsd_) + if (Y_UNLIKELY(err == EINTR)) { + err = EINPROGRESS; + } + } while (0); +#elif defined(_linux_) + } while (Y_UNLIKELY(err == EINTR)); +#else + } while (0); +#endif + + RemoteEndpoint_ = ep; + S_.Swap(s); + + DBGOUT("AsyncConnect(): " << err); + if (Y_LIKELY(err == EINPROGRESS || err == EWOULDBLOCK || err == 0)) { + Srv_.ScheduleOp(new TOperationConnect(*this, h, deadline)); //set callback + } else { + Srv_.ScheduleOp(new TOperationConnectFailed(*this, h, err, deadline)); //set callback + } + } + + inline void AsyncWrite(TSendedData& d, TTcpSocket::TWriteHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationWrite(*this, d, h, deadline)); + } + + inline void AsyncWrite(TContIOVector* v, TTcpSocket::TWriteHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationWriteVector(*this, v, h, deadline)); + } + + inline void AsyncRead(void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationRead(*this, buff, size, h, deadline)); + } + + inline void AsyncReadSome(void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationReadSome(*this, buff, size, h, deadline)); + } + + inline void AsyncPollWrite(TTcpSocket::TPollHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationPoll(*this, TOperationPoll::PollWrite, h, deadline)); + } + + inline void AsyncPollRead(TTcpSocket::TPollHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationPoll(*this, TOperationPoll::PollRead, h, deadline)); + } + + inline void AsyncCancel() { + if (Y_UNLIKELY(Srv_.HasAbort())) { + return; + } + Srv_.ScheduleOp(new TOperationCancel(this)); + } + + inline bool SysCallHasResult(ssize_t& n, TErrorCode& ec) noexcept { + if (n >= 0) { + return true; + } + + int errn = LastSystemError(); + if (errn == EINTR) { + return false; + } + + ec.Assign(errn); + n = 0; + return true; + } + + size_t WriteSome(TContIOVector& iov, TErrorCode& ec) noexcept { + for (;;) { + ssize_t n = writev(S_, (const iovec*)iov.Parts(), Min(IOV_MAX, (int)iov.Count())); + DBGOUT("WriteSome(): n=" << n); + if (SysCallHasResult(n, ec)) { + return n; + } + } + } + + size_t WriteSome(const void* buff, size_t size, TErrorCode& ec) noexcept { + for (;;) { + ssize_t n = send(S_, (char*)buff, size, 0); + DBGOUT("WriteSome(): n=" << n); + if (SysCallHasResult(n, ec)) { + return n; + } + } + } + + size_t ReadSome(void* buff, size_t size, TErrorCode& ec) noexcept { + for (;;) { + ssize_t n = recv(S_, (char*)buff, size, 0); + DBGOUT("ReadSome(): n=" << n); + if (SysCallHasResult(n, ec)) { + return n; + } + } + } + + inline void Shutdown(TTcpSocket::TShutdownMode mode, TErrorCode& ec) { + if (shutdown(S_, mode)) { + ec.Assign(LastSystemError()); + } + } + + TIOService::TImpl& GetIOServiceImpl() const noexcept { + return Srv_; + } + + inline SOCKET Fd() const noexcept { + return S_; + } + + TEndpoint RemoteEndpoint() const { + return RemoteEndpoint_; + } + + private: + TIOService::TImpl& Srv_; + TSocketHolder S_; + TEndpoint RemoteEndpoint_; + }; +} diff --git a/library/cpp/neh/lfqueue.h b/library/cpp/neh/lfqueue.h new file mode 100644 index 0000000000..c957047a99 --- /dev/null +++ b/library/cpp/neh/lfqueue.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +namespace NNeh { + template + class TAutoLockFreeQueue { + struct TCounter : TAtomicCounter { + inline void IncCount(const T* const&) { + Inc(); + } + + inline void DecCount(const T* const&) { + Dec(); + } + }; + + public: + typedef TAutoPtr TRef; + + inline ~TAutoLockFreeQueue() { + TRef tmp; + + while (Dequeue(&tmp)) { + } + } + + inline bool Dequeue(TRef* t) { + T* res = nullptr; + + if (Q_.Dequeue(&res)) { + t->Reset(res); + + return true; + } + + return false; + } + + inline void Enqueue(TRef& t) { + Q_.Enqueue(t.Get()); + Y_UNUSED(t.Release()); + } + + inline size_t Size() { + return Q_.GetCounter().Val(); + } + + private: + TLockFreeQueue Q_; + }; +} diff --git a/library/cpp/neh/pipequeue.h b/library/cpp/neh/pipequeue.h new file mode 100644 index 0000000000..bed8d44bd2 --- /dev/null +++ b/library/cpp/neh/pipequeue.h @@ -0,0 +1,207 @@ +#pragma once + +#include "lfqueue.h" + +#include +#include +#include +#include + +#ifdef _linux_ +#include +#endif + +#if defined(_bionic_) && !defined(EFD_SEMAPHORE) +#define EFD_SEMAPHORE 1 +#endif + +namespace NNeh { +#ifdef _linux_ + class TSemaphoreEventFd { + public: + inline TSemaphoreEventFd() { + F_ = eventfd(0, EFD_NONBLOCK | EFD_SEMAPHORE); + if (F_ < 0) { + ythrow TFileError() << "failed to create a eventfd"; + } + } + + inline ~TSemaphoreEventFd() { + close(F_); + } + + inline size_t Acquire(TCont* c) { + ui64 ev; + return NCoro::ReadI(c, F_, &ev, sizeof ev).Processed(); + } + + inline void Release() { + const static ui64 ev(1); + (void)write(F_, &ev, sizeof ev); + } + + private: + int F_; + }; +#endif + + class TSemaphorePipe { + public: + inline TSemaphorePipe() { + TPipeHandle::Pipe(S_[0], S_[1]); + + SetNonBlock(S_[0]); + SetNonBlock(S_[1]); + } + + inline size_t Acquire(TCont* c) { + char ch; + return NCoro::ReadI(c, S_[0], &ch, 1).Processed(); + } + + inline size_t Acquire(TCont* c, char* buff, size_t buflen) { + return NCoro::ReadI(c, S_[0], buff, buflen).Processed(); + } + + inline void Release() { + char ch = 13; + S_[1].Write(&ch, 1); + } + + private: + TPipeHandle S_[2]; + }; + + class TPipeQueueBase { + public: + inline void Enqueue(void* job) { + Q_.Enqueue(job); + S_.Release(); + } + + inline void* Dequeue(TCont* c, char* ch, size_t buflen) { + void* ret = nullptr; + + while (!Q_.Dequeue(&ret) && S_.Acquire(c, ch, buflen)) { + } + + return ret; + } + + inline void* Dequeue() noexcept { + void* ret = nullptr; + + Q_.Dequeue(&ret); + + return ret; + } + + private: + TLockFreeQueue Q_; + TSemaphorePipe S_; + }; + + template + class TPipeQueue { + public: + template + inline void EnqueueSafe(TPtr req) { + Enqueue(req.Get()); + req.Release(); + } + + inline void Enqueue(T* req) { + Q_.Enqueue(req); + } + + template + inline void DequeueSafe(TCont* c, TPtr& ret) { + ret.Reset(Dequeue(c)); + } + + inline T* Dequeue(TCont* c) { + char ch[buflen]; + + return (T*)Q_.Dequeue(c, ch, sizeof(ch)); + } + + protected: + TPipeQueueBase Q_; + }; + + //optimized for avoiding unnecessary usage semaphore + use eventfd on linux + template + struct TOneConsumerPipeQueue { + inline TOneConsumerPipeQueue() + : Signaled_(0) + , SkipWait_(0) + { + } + + inline void Enqueue(T* job) { + Q_.Enqueue(job); + + AtomicSet(SkipWait_, 1); + if (AtomicCas(&Signaled_, 1, 0)) { + S_.Release(); + } + } + + inline T* Dequeue(TCont* c) { + T* ret = nullptr; + + while (!Q_.Dequeue(&ret)) { + AtomicSet(Signaled_, 0); + if (!AtomicCas(&SkipWait_, 0, 1)) { + if (!S_.Acquire(c)) { + break; + } + } + AtomicSet(Signaled_, 1); + } + + return ret; + } + + template + inline void EnqueueSafe(TPtr req) { + Enqueue(req.Get()); + Y_UNUSED(req.Release()); + } + + template + inline void DequeueSafe(TCont* c, TPtr& ret) { + ret.Reset(Dequeue(c)); + } + + protected: + TLockFreeQueue Q_; +#ifdef _linux_ + TSemaphoreEventFd S_; +#else + TSemaphorePipe S_; +#endif + TAtomic Signaled_; + TAtomic SkipWait_; + }; + + template + struct TAutoPipeQueue: public TPipeQueue { + ~TAutoPipeQueue() { + while (T* t = (T*)TPipeQueue::Q_.Dequeue()) { + delete t; + } + } + }; + + template + struct TAutoOneConsumerPipeQueue: public TOneConsumerPipeQueue { + ~TAutoOneConsumerPipeQueue() { + T* ret = nullptr; + + while (TOneConsumerPipeQueue::Q_.Dequeue(&ret)) { + delete ret; + } + } + }; +} From 452511af41fe78812d54066bf80c535a4437c0e2 Mon Sep 17 00:00:00 2001 From: Ilya-Repin Date: Mon, 10 Nov 2025 05:22:56 +0000 Subject: [PATCH 7/9] Add GetFastestLocation --- include/ydb-cpp-sdk/client/types/ydb.h | 3 +- .../internal/local_dc_detector/CMakeLists.txt | 2 + .../local_dc_detector/local_dc_detector.cpp | 52 +-- .../local_dc_detector/local_dc_detector.h | 24 +- .../internal/local_dc_detector/pinger.cpp | 106 +++++- .../impl/internal/local_dc_detector/pinger.h | 41 ++- tests/unit/client/CMakeLists.txt | 2 +- .../local_dc_detector_ut.cpp | 343 +++++++----------- 8 files changed, 309 insertions(+), 264 deletions(-) diff --git a/include/ydb-cpp-sdk/client/types/ydb.h b/include/ydb-cpp-sdk/client/types/ydb.h index ae689699df..6d0a323b52 100644 --- a/include/ydb-cpp-sdk/client/types/ydb.h +++ b/include/ydb-cpp-sdk/client/types/ydb.h @@ -54,7 +54,8 @@ class TBalancingPolicy { //! location is a name of datacenter (VLA, MAN), if location is nullopt local datacenter is used static TBalancingPolicy UsePreferableLocation(const std::optional& location = {}); - //! Use detected local dc + //! Use detected local DC + //! prefer datacenter with fastest tcp ping static TBalancingPolicy UseDetectedLocalDC(); //! Use all available cluster nodes regardless datacenter locality diff --git a/src/client/impl/internal/local_dc_detector/CMakeLists.txt b/src/client/impl/internal/local_dc_detector/CMakeLists.txt index 71a33f6aad..9abb12ae17 100644 --- a/src/client/impl/internal/local_dc_detector/CMakeLists.txt +++ b/src/client/impl/internal/local_dc_detector/CMakeLists.txt @@ -3,6 +3,8 @@ _ydb_sdk_add_library(impl-internal-local_dc_detector) target_link_libraries(impl-internal-local_dc_detector PUBLIC yutil api-protos + neh-asio + threading-future ) target_sources(impl-internal-local_dc_detector PRIVATE diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp index b9fc78b35e..b8e3c9cffb 100644 --- a/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp @@ -1,21 +1,26 @@ #define INCLUDE_YDB_INTERNAL_H #include "local_dc_detector.h" +#include +#include + namespace NYdb::inline V3 { -TLocalDCDetector::TLocalDCDetector(TPinger pingEndpoint) - : PingEndpoint_(std::move(pingEndpoint)) +TLocalDCDetector::TLocalDCDetector(std::unique_ptr pinger) + : Pinger_(std::move(pinger)) {} -void TLocalDCDetector::DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpointsList) { +std::string TLocalDCDetector::DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpointsList) { auto endpointsByLocation = GroupEndpointsByLocation(endpointsList); - SampleEndpoints(endpointsByLocation); + SelectEndpoints(endpointsByLocation); if (endpointsByLocation.size() > 1) { - Location_ = FindNearestLocation(endpointsByLocation); + Location_ = FindNearestLocation(); + Pinger_->Reset(); } else { Location_.clear(); } + return Location_; } bool TLocalDCDetector::IsLocalDC(const std::string& location) const { @@ -30,7 +35,7 @@ TLocalDCDetector::TEndpointsByLocation TLocalDCDetector::GroupEndpointsByLocatio return endpointsByLocation; } -void TLocalDCDetector::SampleEndpoints(TEndpointsByLocation& endpointsByLocation) const { +void TLocalDCDetector::SelectEndpoints(TEndpointsByLocation& endpointsByLocation) const { std::mt19937 gen(std::random_device{}()); for (auto& [location, endpoints] : endpointsByLocation) { if (endpoints.size() > MAX_ENDPOINTS_PER_LOCATION) { @@ -39,37 +44,18 @@ void TLocalDCDetector::SampleEndpoints(TEndpointsByLocation& endpointsByLocation std::sample(endpoints.begin(), endpoints.end(), std::back_inserter(sample), MAX_ENDPOINTS_PER_LOCATION, gen); endpoints.swap(sample); } + for (const auto& endpoint : endpoints) { + Pinger_->AddEndpoint(endpoint.get(), DETECT_TIMEOUT); + } } } -std::uint64_t TLocalDCDetector::MeasureLocationRtt(const std::vector& endpoints) const { - if (endpoints.empty()) { - return std::numeric_limits::max(); - } - - std::vector timings; - timings.reserve(PING_COUNT); - for (size_t i = 0; i < PING_COUNT; ++i) { - const auto& ep = endpoints[i % endpoints.size()].get(); - timings.push_back(PingEndpoint_(ep).MicroSeconds()); - } - std::sort(timings.begin(), timings.end()); - - return std::midpoint(timings[(PING_COUNT - 1) / 2], timings[PING_COUNT / 2]); -} - - -std::string TLocalDCDetector::FindNearestLocation(const TEndpointsByLocation& endpointsByLocation) { - auto minRtt = std::numeric_limits::max(); - std::string nearestLocation; - for (const auto& [location, endpoints] : endpointsByLocation) { - auto rtt = MeasureLocationRtt(endpoints); - if (rtt < minRtt) { - minRtt = rtt; - nearestLocation = location; - } +std::string TLocalDCDetector::FindNearestLocation() { + try { + return Pinger_->GetFastestLocation().GetValue(DETECT_TIMEOUT); + } catch(...) { + return EMPTY_LOCATION; } - return nearestLocation; } } // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.h b/src/client/impl/internal/local_dc_detector/local_dc_detector.h index d255f8c047..2cc0d576d6 100644 --- a/src/client/impl/internal/local_dc_detector/local_dc_detector.h +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.h @@ -4,10 +4,8 @@ #include #include -#include +#include #include -#include -#include #include #include @@ -15,27 +13,25 @@ namespace NYdb::inline V3 { class TLocalDCDetector { public: - using TPinger = std::function; - explicit TLocalDCDetector(TPinger pingEndpoint = PingEndpoint); + explicit TLocalDCDetector(std::unique_ptr pinger = std::make_unique()); - void DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpoints); + std::string DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpoints); bool IsLocalDC(const std::string& location) const; private: - using TEndpoint = Ydb::Discovery::EndpointInfo; - using TEndpointRef = std::reference_wrapper; + using TEndpointRef = std::reference_wrapper; using TEndpointsByLocation = std::unordered_map>; TEndpointsByLocation GroupEndpointsByLocation(const Ydb::Discovery::ListEndpointsResult& endpointsList) const; - void SampleEndpoints(TEndpointsByLocation& endpointsByLocation) const; - std::uint64_t MeasureLocationRtt(const std::vector& endpoints) const; - std::string FindNearestLocation(const TEndpointsByLocation& endpointsByLocation); + void SelectEndpoints(TEndpointsByLocation& endpointsByLocation) const; + std::string FindNearestLocation(); private: - static constexpr std::size_t MAX_ENDPOINTS_PER_LOCATION = 3; - static constexpr std::size_t PING_COUNT = 2 * MAX_ENDPOINTS_PER_LOCATION; + static constexpr std::size_t MAX_ENDPOINTS_PER_LOCATION = 5; + static constexpr auto DETECT_TIMEOUT = std::chrono::seconds(5); + static constexpr auto EMPTY_LOCATION = ""; - TPinger PingEndpoint_; + std::unique_ptr Pinger_; std::string Location_; }; diff --git a/src/client/impl/internal/local_dc_detector/pinger.cpp b/src/client/impl/internal/local_dc_detector/pinger.cpp index 1acb49c9b3..19fc5dd9b7 100644 --- a/src/client/impl/internal/local_dc_detector/pinger.cpp +++ b/src/client/impl/internal/local_dc_detector/pinger.cpp @@ -1,17 +1,107 @@ #define INCLUDE_YDB_INTERNAL_H #include "pinger.h" +#include + namespace NYdb::inline V3 { -TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint) { - try { - TNetworkAddress addr(endpoint.address().data(), static_cast(endpoint.port())); - auto start = TInstant::Now(); - TSocket sock(addr, TDuration::Seconds(PING_TIMEOUT_SECONDS)); - return TInstant::Now() - start; - } catch (...) { - return TDuration::Max(); +class TPinger::TLocationPinger : public IThreadFactory::IThreadAble { +public: + using THandleConnect = std::function; + + TLocationPinger() = default; + + ~TLocationPinger() { + if (Thread_) { + IOService_.Abort(); + Thread_->Join(); + } + } + + void Start(IThreadPool& pool) { + Thread_.reset(pool.Run(this).Release()); + } + + void DoExecute() override { + IOService_.Run(); + } + + std::size_t AddEndpointPings(const std::string& address, const std::uint32_t port, const std::chrono::seconds timeout, THandleConnect handleConnect) { + auto& addr = Addresses_.emplace_back(address.c_str(), port); + + std::size_t pingsAmount = 0; + for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) { + auto& socket = Sockets_.emplace_back( + std::make_unique(IOService_)); + auto& endpoint = Endpoints_.emplace_back( + MakeAtomicShared(&*it)); + + socket->AsyncConnect(endpoint, handleConnect, {timeout}); + ++pingsAmount; + } + return pingsAmount; + } + +private: + NAsio::TIOService IOService_; + std::unique_ptr Thread_; + + std::vector> Sockets_; + std::vector Endpoints_; + std::vector Addresses_; +}; + +TPinger::TPinger() + : PingContext_(std::make_shared()) + , ThreadPool_(std::make_unique()) +{ + ThreadPool_->Start(); +} + +TPinger::~TPinger() { + PingContext_.reset(); + ThreadPool_->Stop(); +} + +void TPinger::Reset() { + PingContext_ = std::make_shared(); +} + + +void TPinger::AddEndpoint(const Ydb::Discovery::EndpointInfo& endpoint, const std::chrono::seconds timeout) { + std::string location = endpoint.location(); + std::weak_ptr pingContext = PingContext_; + auto handleConnect = [location, pingContext](const NAsio::TErrorCode& err, NAsio::IHandlingContext& handlingContext) { + if (auto ctx = pingContext.lock()) { + if (!err) { + ctx->ResultPromise.TrySetValue(location); + } else if (ctx->PingsToFail.fetch_sub(1, std::memory_order_acq_rel) == 1) { + ctx->ResultPromise.TrySetException(std::make_exception_ptr( + std::runtime_error("All pings failed"))); + } + } + }; + + auto [it, inserted] = PingContext_->PingerByLocation.try_emplace(location, std::make_unique()); + auto& pinger = it->second; + + PingContext_->PingsToFail += + pinger->AddEndpointPings(endpoint.address(), endpoint.port(), timeout, handleConnect); +} + + +NThreading::TFuture TPinger::GetFastestLocation() { + auto future = PingContext_->ResultPromise.GetFuture(); + if (PingContext_->PingsToFail.load() > 0) { + for (auto& [location, pinger] : PingContext_->PingerByLocation) { + pinger->Start(*ThreadPool_); + } + } else { + PingContext_->ResultPromise.TrySetException(std::make_exception_ptr( + std::runtime_error("No pings to perform"))); } + + return future; } } // namespace NYdb diff --git a/src/client/impl/internal/local_dc_detector/pinger.h b/src/client/impl/internal/local_dc_detector/pinger.h index 627f7a368f..216237f0b3 100644 --- a/src/client/impl/internal/local_dc_detector/pinger.h +++ b/src/client/impl/internal/local_dc_detector/pinger.h @@ -1,15 +1,46 @@ #pragma once -#include #include +#include + +#include -#include -#include +#include + +#include namespace NYdb::inline V3 { -static constexpr std::uint32_t PING_TIMEOUT_SECONDS = 5; +class IPinger { +public: + virtual ~IPinger() = default; + + virtual void Reset() = 0; + virtual void AddEndpoint(const Ydb::Discovery::EndpointInfo& endpoint, const std::chrono::seconds timeout) = 0; + virtual NThreading::TFuture GetFastestLocation() = 0; +}; + +class TPinger : public IPinger { +public: + TPinger(); + ~TPinger() override; + + void Reset() override; + void AddEndpoint(const Ydb::Discovery::EndpointInfo& endpoint, const std::chrono::seconds timeout) override; + + NThreading::TFuture GetFastestLocation() override; + +private: + class TLocationPinger; + + struct TPingContext { + std::unordered_map> PingerByLocation; + std::atomic PingsToFail = 0; + NThreading::TPromise ResultPromise = NThreading::NewPromise(); + }; -TDuration PingEndpoint(const Ydb::Discovery::EndpointInfo& endpoint); + std::shared_ptr PingContext_; + std::unique_ptr ThreadPool_; +}; } // namespace NYdb diff --git a/tests/unit/client/CMakeLists.txt b/tests/unit/client/CMakeLists.txt index f3452a0c45..f9d9f08688 100644 --- a/tests/unit/client/CMakeLists.txt +++ b/tests/unit/client/CMakeLists.txt @@ -45,7 +45,7 @@ add_ydb_test(NAME client-impl-ydb_endpoints_ut unit ) -add_ydb_test(NAME client-impl-internal-local_dc_detector_ut +add_ydb_test(NAME client-impl-internal-local_dc_detector_ut GTEST INCLUDE_DIRS ${YDB_SDK_SOURCE_DIR}/src/client/impl/internal/local_dc_detector SOURCES diff --git a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp index d68f0ebec7..9b92aca62a 100644 --- a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp +++ b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp @@ -4,243 +4,182 @@ #include #undef INCLUDE_YDB_INTERNAL_H -#include -#include +#include -using namespace NYdb; - -class TMockedEndpoint { -public: - explicit TMockedEndpoint(std::vector measures) - : Measures_(std::move(measures)) - , Idx_(0) - {} - - TDuration Ping() { - std::size_t idx = Idx_++; - - if (idx < Measures_.size()) { - return Measures_.at(idx); - } - return TDuration::Max(); - } +#include -private: - const std::vector Measures_; - std::size_t Idx_; -}; +using namespace NYdb; -class TMockedPinger { +class TMockedPinger : public IPinger { public: - explicit TMockedPinger(std::unordered_map> measuresByAdress) { - EndpointByAdress_.reserve(measuresByAdress.size()); - - for (auto& [adress, measures] : measuresByAdress) { - EndpointByAdress_.emplace(std::move(adress), std::move(measures)); - } - } + void Reset() override {} - TDuration operator()(const Ydb::Discovery::EndpointInfo& endpoint) const { - auto it = EndpointByAdress_.find(endpoint.address()); - if (it == EndpointByAdress_.end() || Blacklist_.contains(endpoint.address())) { - return TDuration::Max(); - } - return it->second.Ping(); + void AddEndpoint(const Ydb::Discovery::EndpointInfo& endpoint, const std::chrono::seconds timeout) override { + EndpointsPerLocation_[endpoint.location()].insert(endpoint.address()); } - void BanEndpoint(const std::string& adress) { - Blacklist_.insert(adress); + NThreading::TFuture GetFastestLocation() override { + return NThreading::MakeFuture(""); } - void UnbanEndpoint(const std::string& adress) { - Blacklist_.erase(adress); + const std::unordered_map>& GetSelectedEndpoints() { + return EndpointsPerLocation_; } private: - mutable std::unordered_map EndpointByAdress_; - std::unordered_set Blacklist_; + std::unordered_map> EndpointsPerLocation_; }; -std::vector GenerateMeasures(size_t count, int minMs, int maxMs, std::mt19937& gen) { - std::vector measures; - measures.reserve(count); - std::uniform_int_distribution distrib(minMs, maxMs); - for (size_t i = 0; i < count; ++i) { - measures.push_back(TDuration::MicroSeconds(distrib(gen))); - } - return measures; -} - -Y_UNIT_TEST_SUITE(LocalDCDetectionTest) { - Y_UNIT_TEST(Basic) { - Ydb::Discovery::ListEndpointsResult endpoints; - std::unordered_map> mockData; - std::mt19937 gen(std::random_device{}()); - - const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; - const std::vector endpointsB = {"B1", "B2", "B3"}; - const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; - - const std::size_t epoches = 3; - const std::size_t measuresAmount = 10 * epoches; +struct LocalDCDetectionParams { + std::vector InputAddresses; + + size_t ExpectedLocationsAmount; + std::unordered_map ExpectedCountsPerLocation; + + std::string TestName; +}; - for (const auto& ep : endpointsA) { - mockData[ep] = GenerateMeasures(measuresAmount, 20, 30, gen); - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("A"); - endpoint.set_address(ep); - } - for (const auto& ep : endpointsB) { - mockData[ep] = GenerateMeasures(measuresAmount, 30, 45, gen); - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("B"); - endpoint.set_address(ep); - } - for (const auto& ep : endpointsC) { - mockData[ep] = GenerateMeasures(measuresAmount, 50, 70, gen); - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("C"); - endpoint.set_address(ep); - } +std::ostream& operator<<(std::ostream& os, const LocalDCDetectionParams& params) { + return os << params.TestName; +} - std::function pinger = TMockedPinger(mockData); - TLocalDCDetector detector(pinger); +class EndpointsSelectionTest : public ::testing::TestWithParam { +protected: + std::unique_ptr detector; + TMockedPinger* pinger; - for (std::size_t i = 0; i < epoches; ++i) { - detector.DetectLocalDC(endpoints); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); - } + void SetUp() override { + auto pingerPtr = std::make_unique(); + pinger = pingerPtr.get(); + detector = std::make_unique(std::move(pingerPtr)); } +}; - Y_UNIT_TEST(SingleLocation) { - Ydb::Discovery::ListEndpointsResult endpoints; - std::unordered_map> mockData; - std::mt19937 gen(std::random_device{}()); - - const std::vector endpointsA = {"A1", "A2", "A3"}; - - for (const auto& ep : endpointsA) { - mockData[ep] = GenerateMeasures(10, 20, 30, gen); - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("A"); - endpoint.set_address(ep); - } - - std::function pinger = TMockedPinger(mockData); - TLocalDCDetector detector(pinger); - - detector.DetectLocalDC(endpoints); +TEST_P(EndpointsSelectionTest, EndpointsSelection) { + const auto& params = GetParam(); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); + Ydb::Discovery::ListEndpointsResult endpoints; + for (const auto& ep : params.InputAddresses) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location(ep.substr(0, 1)); + endpoint.set_address(ep); } - Y_UNIT_TEST(UnavailableLocalDC) { - Ydb::Discovery::ListEndpointsResult endpoints; - std::unordered_map> mockData; - std::mt19937 gen(std::random_device{}()); - - const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; - const std::vector endpointsB = {"B1", "B2", "B3", "B4", "B5", "B6", "B7"}; - const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; - - const std::size_t epoches = 3; - const std::size_t measuresAmount = 10 * epoches; - - for (const auto& ep : endpointsA) { - mockData[ep] = GenerateMeasures(measuresAmount, 20, 30, gen); - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("A"); - endpoint.set_address(ep); - } - for (const auto& ep : endpointsB) { - mockData[ep] = GenerateMeasures(measuresAmount, 30, 45, gen); - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("B"); - endpoint.set_address(ep); - } - for (const auto& ep : endpointsC) { - mockData[ep] = GenerateMeasures(measuresAmount, 50, 70, gen); - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("C"); - endpoint.set_address(ep); - } - - TMockedPinger mockPinger(mockData); - std::function pinger = - [&mockPinger](const Ydb::Discovery::EndpointInfo& endpoint) { return mockPinger(endpoint); }; + detector->DetectLocalDC(endpoints); - TLocalDCDetector detector(pinger); + auto selected = pinger->GetSelectedEndpoints(); - detector.DetectLocalDC(endpoints); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + ASSERT_EQ(selected.size(), params.ExpectedLocationsAmount); - for (const auto& ep : endpointsA) { - mockPinger.BanEndpoint(ep); - } + for (const auto& [location, expectedCount] : params.ExpectedCountsPerLocation) { + auto it = selected.find(location); + + ASSERT_NE(it, selected.end()); + EXPECT_EQ(it->second.size(), expectedCount); + } - detector.DetectLocalDC(endpoints); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("A")); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("B")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); + for (const auto& [location, _] : selected) { + EXPECT_TRUE(params.ExpectedCountsPerLocation.count(location)); + } +} - for (const auto& ep : endpointsA) { - mockPinger.UnbanEndpoint(ep); - } +INSTANTIATE_TEST_SUITE_P( + LocalDCDetectorTest, + EndpointsSelectionTest, + ::testing::Values ( + LocalDCDetectionParams{ + {"A1", "A2", "A3", "A4", "A5", "A6", "B1", "B2", "B3", "B4", "B5", "C1", "C2", "C3", "C4", "C5"}, + 3, + {{"A", 5}, {"B", 5}, {"C", 5}}, + "Basic" + }, + LocalDCDetectionParams{ + {"A1", "A2", "A3", "B1"}, + 2, + {{"A", 3}, {"B", 1}}, + "BelowLimit" + }, + LocalDCDetectionParams{ + {}, + 0, + {}, + "EmptyInput" + }, + LocalDCDetectionParams{ + {"X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8"}, + 1, + {{"X", 5}}, + "SingleLocationOverLimit" + } + ) +); + +ui16 GetAssignedPort(SOCKET s) { + struct sockaddr_in6 bound_addr_sys; + socklen_t addrLen = sizeof(bound_addr_sys); + + if (getsockname(s, (struct sockaddr*)&bound_addr_sys, &addrLen) == 0) { + const ui16 assignedPort = ntohs(bound_addr_sys.sin6_port); + + return assignedPort; + } + return 0; +} - detector.DetectLocalDC(endpoints); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("B")); - UNIT_ASSERT_VALUES_EQUAL(false, detector.IsLocalDC("C")); +TEST (TPingerTest, FastestEndpoint) { + TInetStreamSocket socket; + { + TSockAddrInet addr("127.0.0.1", 0); + TBaseSocket::Check(socket.Bind(&addr)); + TBaseSocket::Check(socket.Listen(1)); } - Y_UNIT_TEST(OfflineDCs) { - Ydb::Discovery::ListEndpointsResult endpoints; - std::unordered_map> mockData; - std::mt19937 gen(std::random_device{}()); + Ydb::Discovery::EndpointInfo endpoint; + endpoint.set_address("127.0.0.1"); + endpoint.set_port(GetAssignedPort(socket)); + endpoint.set_location("real"); + + TInetStreamSocket fakeSocket; + { + TSockAddrInet addr("127.0.0.1", 0); + TBaseSocket::Check(fakeSocket.Bind(&addr)); + TBaseSocket::Check(fakeSocket.Listen(1)); + } - const std::vector endpointsA = {"A1", "A2", "A3", "A4", "A5", "A6", "A7"}; - const std::vector endpointsB = {"B1", "B2", "B3", "B4", "B5", "B6", "B7"}; - const std::vector endpointsC = {"C1", "C2", "C3", "C4", "C5", "C6", "C7"}; + Ydb::Discovery::EndpointInfo fakeEndpoint; + fakeEndpoint.set_address("127.0.0.1"); + fakeEndpoint.set_port(GetAssignedPort(fakeSocket)); + fakeEndpoint.set_location("fake"); - for (const auto& ep : endpointsA) { - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("A"); - endpoint.set_address(ep); - } - for (const auto& ep : endpointsB) { - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("B"); - endpoint.set_address(ep); - } - for (const auto& ep : endpointsC) { - auto& endpoint = *endpoints.add_endpoints(); - endpoint.set_location("C"); - endpoint.set_address(ep); - } + TPinger pinger; + pinger.AddEndpoint(endpoint, std::chrono::seconds(1)); + pinger.AddEndpoint(fakeEndpoint, std::chrono::seconds(1)); - TMockedPinger mockPinger(mockData); - std::function pinger = - [&mockPinger](const Ydb::Discovery::EndpointInfo& endpoint) { return mockPinger(endpoint); }; + fakeSocket.Close(); - TLocalDCDetector detector(pinger); + std::string localDC = pinger.GetFastestLocation().GetValue(std::chrono::seconds(1)); + pinger.Reset(); - for (const auto& ep : endpointsA) { - mockPinger.BanEndpoint(ep); - } - for (const auto& ep : endpointsB) { - mockPinger.BanEndpoint(ep); - } - for (const auto& ep : endpointsC) { - mockPinger.BanEndpoint(ep); - } + EXPECT_EQ(localDC, "real"); +} - detector.DetectLocalDC(endpoints); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("A")); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("B")); - UNIT_ASSERT_VALUES_EQUAL(true, detector.IsLocalDC("C")); +TEST (TPingerTest, FailedPings) { + TPinger pinger; + for (std::size_t i = 0; i < 3; ++i) { + TInetStreamSocket fakeSocket; + { + TSockAddrInet addr("127.0.0.1", 0); + TBaseSocket::Check(fakeSocket.Bind(&addr)); + TBaseSocket::Check(fakeSocket.Listen(1)); + } + Ydb::Discovery::EndpointInfo fakeEndpoint; + fakeEndpoint.set_address("127.0.0.1"); + fakeEndpoint.set_port(GetAssignedPort(fakeSocket)); + fakeEndpoint.set_location("fake" + std::to_string(i)); + pinger.AddEndpoint(fakeEndpoint, std::chrono::seconds(1)); + fakeSocket.Close(); } + + EXPECT_THROW(pinger.GetFastestLocation().GetValue(std::chrono::seconds(1)), std::runtime_error); + pinger.Reset(); } From daf74ed23074b42b499dc0ba6fa069502664d653 Mon Sep 17 00:00:00 2001 From: Ilya-Repin Date: Mon, 10 Nov 2025 05:34:35 +0000 Subject: [PATCH 8/9] fix --- .../impl/internal/local_dc_detector/local_dc_detector.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp index b8e3c9cffb..e9b042608a 100644 --- a/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp +++ b/src/client/impl/internal/local_dc_detector/local_dc_detector.cpp @@ -12,9 +12,9 @@ TLocalDCDetector::TLocalDCDetector(std::unique_ptr pinger) std::string TLocalDCDetector::DetectLocalDC(const Ydb::Discovery::ListEndpointsResult& endpointsList) { auto endpointsByLocation = GroupEndpointsByLocation(endpointsList); - SelectEndpoints(endpointsByLocation); if (endpointsByLocation.size() > 1) { + SelectEndpoints(endpointsByLocation); Location_ = FindNearestLocation(); Pinger_->Reset(); } else { From e8ef1d665890a775b81984a67fc9031dc4e215b6 Mon Sep 17 00:00:00 2001 From: Ilya-Repin Date: Mon, 10 Nov 2025 07:11:16 +0000 Subject: [PATCH 9/9] fix --- library/cpp/dns/CMakeLists.txt | 2 ++ library/cpp/neh/asio/CMakeLists.txt | 2 ++ .../local_dc_detector_ut.cpp | 23 +++++++++++++++---- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/library/cpp/dns/CMakeLists.txt b/library/cpp/dns/CMakeLists.txt index 634a5e96a9..821ed53f8c 100644 --- a/library/cpp/dns/CMakeLists.txt +++ b/library/cpp/dns/CMakeLists.txt @@ -15,3 +15,5 @@ target_link_libraries(dns PUBLIC yutil ) + +_ydb_sdk_install_targets(TARGETS dns) diff --git a/library/cpp/neh/asio/CMakeLists.txt b/library/cpp/neh/asio/CMakeLists.txt index b053959419..a4c63b701d 100644 --- a/library/cpp/neh/asio/CMakeLists.txt +++ b/library/cpp/neh/asio/CMakeLists.txt @@ -20,3 +20,5 @@ target_link_libraries(neh-asio PRIVATE enum_serialization_runtime ) + +_ydb_sdk_install_targets(TARGETS neh-asio) \ No newline at end of file diff --git a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp index 9b92aca62a..c7b9212a6a 100644 --- a/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp +++ b/tests/unit/client/local_dc_detector/local_dc_detector_ut.cpp @@ -106,14 +106,29 @@ INSTANTIATE_TEST_SUITE_P( "EmptyInput" }, LocalDCDetectionParams{ - {"X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8"}, - 1, - {{"X", 5}}, - "SingleLocationOverLimit" + {"A1", "B1", "C1"}, + 3, + {{"A", 1}, {"B", 1}, {"C", 1}}, + "SingleNodesInLocations" } ) ); +TEST (TLocalDCDetector, SingleLocation) { + TLocalDCDetector detector; + + Ydb::Discovery::ListEndpointsResult endpoints; + for (std::size_t i = 0; i < 3; ++i) { + auto& endpoint = *endpoints.add_endpoints(); + endpoint.set_location("A"); + endpoint.set_address(std::to_string(i)); + } + + std::string localDC = detector.DetectLocalDC(endpoints); + + EXPECT_EQ(localDC, ""); +} + ui16 GetAssignedPort(SOCKET s) { struct sockaddr_in6 bound_addr_sys; socklen_t addrLen = sizeof(bound_addr_sys);