diff --git a/include/elio/net/tcp.hpp b/include/elio/net/tcp.hpp index f04338a..126161c 100644 --- a/include/elio/net/tcp.hpp +++ b/include/elio/net/tcp.hpp @@ -676,6 +676,7 @@ class tcp_connect_awaitable { int ret = ::connect(fd_, reinterpret_cast(&sa_), sa_len); if (ret == 0) { // Connected immediately (rare for TCP, but possible) + connect_in_progress_ = false; result_ = io::io_result{0, 0}; return false; // Don't suspend, resume immediately } @@ -689,10 +690,11 @@ class tcp_connect_awaitable { } // Connection in progress, wait for socket to become writable + connect_in_progress_ = true; auto& ctx = io::current_io_context(); io::io_request req{}; - req.op = io::io_op::connect; + req.op = io::io_op::poll_write; req.fd = fd_; req.addr = reinterpret_cast(&sa_); req.addrlen = &sa_len_; @@ -709,11 +711,20 @@ class tcp_connect_awaitable { } std::optional await_resume() { - // If result wasn't set (async path completed), get from io_context - if (result_.result == 0 && fd_ >= 0) { - auto ctx_result = io::io_context::get_last_result(); - if (ctx_result.result != 0 || ctx_result.flags != 0) { - result_ = ctx_result; + // Async path completion result comes from io_context. + if (connect_in_progress_ && fd_ >= 0) { + result_ = io::io_context::get_last_result(); + } + + // For non-blocking connect, writability means completion, not success. + // Use SO_ERROR to fetch the actual connect result. + if (connect_in_progress_ && result_.success() && fd_ >= 0) { + int so_error = 0; + socklen_t len = sizeof(so_error); + if (::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) { + result_ = io::io_result{-errno, 0}; + } else if (so_error != 0) { + result_ = io::io_result{-so_error, 0}; } } @@ -740,6 +751,7 @@ class tcp_connect_awaitable { struct sockaddr_storage sa_{}; socklen_t sa_len_ = sizeof(sa_); int fd_ = -1; + bool connect_in_progress_ = false; io::io_result result_{}; }; diff --git a/include/elio/net/uds.hpp b/include/elio/net/uds.hpp index c9e6304..6451e7d 100644 --- a/include/elio/net/uds.hpp +++ b/include/elio/net/uds.hpp @@ -465,6 +465,7 @@ class uds_connect_awaitable { int ret = ::connect(fd_, reinterpret_cast(&sa_), sa_len_); if (ret == 0) { // Connected immediately (common for UDS) + connect_in_progress_ = false; result_ = io::io_result{0, 0}; return false; // Don't suspend, resume immediately } @@ -478,10 +479,11 @@ class uds_connect_awaitable { } // Connection in progress, wait for socket to become writable + connect_in_progress_ = true; auto& ctx = io::current_io_context(); io::io_request req{}; - req.op = io::io_op::connect; + req.op = io::io_op::poll_write; req.fd = fd_; req.addr = reinterpret_cast(&sa_); req.addrlen = &sa_len_; @@ -498,16 +500,21 @@ class uds_connect_awaitable { } std::optional await_resume() { - // For async completion (EINPROGRESS path), get result from io_context - // For immediate completion, result_ is already set - if (result_.result == 0 && result_.flags == 0 && fd_ >= 0) { - // This could be immediate success ({0,0}) or we need to check async result - auto ctx_result = io::io_context::get_last_result(); - // Only use ctx_result if it looks like a real completion (not default) - if (ctx_result.result != 0 || ctx_result.flags != 0) { - result_ = ctx_result; + // Async path completion result comes from io_context. + if (connect_in_progress_ && fd_ >= 0) { + result_ = io::io_context::get_last_result(); + } + + // For non-blocking connect, writability means completion, not success. + // Use SO_ERROR to fetch the actual connect result. + if (connect_in_progress_ && result_.success() && fd_ >= 0) { + int so_error = 0; + socklen_t len = sizeof(so_error); + if (::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) { + result_ = io::io_result{-errno, 0}; + } else if (so_error != 0) { + result_ = io::io_result{-so_error, 0}; } - // If ctx_result is also {0,0}, keep our result_ (immediate success) } if (!result_.success()) { @@ -533,6 +540,7 @@ class uds_connect_awaitable { struct sockaddr_un sa_{}; socklen_t sa_len_ = 0; int fd_ = -1; + bool connect_in_progress_ = false; io::io_result result_{}; }; diff --git a/tests/unit/test_io.cpp b/tests/unit/test_io.cpp index 7cabd9a..cc3bc11 100644 --- a/tests/unit/test_io.cpp +++ b/tests/unit/test_io.cpp @@ -994,6 +994,23 @@ TEST_CASE("UDS echo test", "[uds][echo]") { #include +static task tcp_connect_regression_attempt( + uint16_t port, + std::atomic& connected, + std::atomic& failed, + std::atomic& first_error) { + auto stream = co_await tcp_connect(ipv6_address("::1", port)); + if (stream) { + connected.fetch_add(1, std::memory_order_relaxed); + } else { + failed.fetch_add(1, std::memory_order_relaxed); + int err = errno; + int expected = 0; + first_error.compare_exchange_strong(expected, err); + } + co_return; +} + TEST_CASE("ipv4_address basic operations", "[tcp][address][ipv4]") { SECTION("default constructor") { ipv4_address addr; @@ -1199,6 +1216,37 @@ TEST_CASE("TCP IPv6 listener and connect", "[tcp][ipv6][integration]") { } } +TEST_CASE("TCP connect regression avoids double connect", "[tcp][connect][regression]") { + auto listener = tcp_listener::bind(ipv6_address("::1", 0)); + REQUIRE(listener.has_value()); + + uint16_t port = listener->local_address().port(); + REQUIRE(port > 0); + + std::atomic connected{0}; + std::atomic failed{0}; + std::atomic first_error{0}; + + scheduler sched(2); + sched.start(); + + constexpr int kAttempts = 64; + for (int i = 0; i < kAttempts; ++i) { + auto t = tcp_connect_regression_attempt(port, connected, failed, first_error); + sched.spawn(t.release()); + } + + for (int i = 0; i < 500 && (connected + failed) < kAttempts; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + sched.shutdown(); + + INFO("connect failures=" << failed.load() << ", first errno=" << first_error.load()); + REQUIRE(connected == kAttempts); + REQUIRE(failed == 0); +} + TEST_CASE("socket_address with hostname resolution", "[tcp][address][dns]") { // Test that socket_address can be constructed from "localhost" // This tests the DNS resolution path