Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions include/elio/net/tcp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ class tcp_connect_awaitable {
int ret = ::connect(fd_, reinterpret_cast<struct sockaddr*>(&sa_), sa_len);
if (ret == 0) {
// Connected immediately (rare for TCP, but possible)
connect_in_progress_ = false;
result_ = io::io_result{0, 0};
return false; // Don't suspend, resume immediately
}
Expand All @@ -689,10 +690,11 @@ class tcp_connect_awaitable {
}

// Connection in progress, wait for socket to become writable
connect_in_progress_ = true;
auto& ctx = io::current_io_context();

io::io_request req{};
req.op = io::io_op::connect;
req.op = io::io_op::poll_write;
req.fd = fd_;
req.addr = reinterpret_cast<struct sockaddr*>(&sa_);
req.addrlen = &sa_len_;
Expand All @@ -709,11 +711,20 @@ class tcp_connect_awaitable {
}

std::optional<tcp_stream> await_resume() {
// If result wasn't set (async path completed), get from io_context
if (result_.result == 0 && fd_ >= 0) {
auto ctx_result = io::io_context::get_last_result();
if (ctx_result.result != 0 || ctx_result.flags != 0) {
result_ = ctx_result;
// Async path completion result comes from io_context.
if (connect_in_progress_ && fd_ >= 0) {
result_ = io::io_context::get_last_result();
}

// For non-blocking connect, writability means completion, not success.
// Use SO_ERROR to fetch the actual connect result.
if (connect_in_progress_ && result_.success() && fd_ >= 0) {
int so_error = 0;
socklen_t len = sizeof(so_error);
if (::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) {
result_ = io::io_result{-errno, 0};
} else if (so_error != 0) {
result_ = io::io_result{-so_error, 0};
}
}

Expand All @@ -740,6 +751,7 @@ class tcp_connect_awaitable {
struct sockaddr_storage sa_{};
socklen_t sa_len_ = sizeof(sa_);
int fd_ = -1;
bool connect_in_progress_ = false;
io::io_result result_{};
};

Expand Down
28 changes: 18 additions & 10 deletions include/elio/net/uds.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ class uds_connect_awaitable {
int ret = ::connect(fd_, reinterpret_cast<struct sockaddr*>(&sa_), sa_len_);
if (ret == 0) {
// Connected immediately (common for UDS)
connect_in_progress_ = false;
result_ = io::io_result{0, 0};
return false; // Don't suspend, resume immediately
}
Expand All @@ -478,10 +479,11 @@ class uds_connect_awaitable {
}

// Connection in progress, wait for socket to become writable
connect_in_progress_ = true;
auto& ctx = io::current_io_context();

io::io_request req{};
req.op = io::io_op::connect;
req.op = io::io_op::poll_write;
req.fd = fd_;
req.addr = reinterpret_cast<struct sockaddr*>(&sa_);
req.addrlen = &sa_len_;
Expand All @@ -498,16 +500,21 @@ class uds_connect_awaitable {
}

std::optional<uds_stream> await_resume() {
// For async completion (EINPROGRESS path), get result from io_context
// For immediate completion, result_ is already set
if (result_.result == 0 && result_.flags == 0 && fd_ >= 0) {
// This could be immediate success ({0,0}) or we need to check async result
auto ctx_result = io::io_context::get_last_result();
// Only use ctx_result if it looks like a real completion (not default)
if (ctx_result.result != 0 || ctx_result.flags != 0) {
result_ = ctx_result;
// Async path completion result comes from io_context.
if (connect_in_progress_ && fd_ >= 0) {
result_ = io::io_context::get_last_result();
}

// For non-blocking connect, writability means completion, not success.
// Use SO_ERROR to fetch the actual connect result.
if (connect_in_progress_ && result_.success() && fd_ >= 0) {
int so_error = 0;
socklen_t len = sizeof(so_error);
if (::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) {
result_ = io::io_result{-errno, 0};
} else if (so_error != 0) {
result_ = io::io_result{-so_error, 0};
}
// If ctx_result is also {0,0}, keep our result_ (immediate success)
}

if (!result_.success()) {
Expand All @@ -533,6 +540,7 @@ class uds_connect_awaitable {
struct sockaddr_un sa_{};
socklen_t sa_len_ = 0;
int fd_ = -1;
bool connect_in_progress_ = false;
io::io_result result_{};
};

Expand Down
48 changes: 48 additions & 0 deletions tests/unit/test_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,23 @@ TEST_CASE("UDS echo test", "[uds][echo]") {

#include <elio/net/tcp.hpp>

static task<void> tcp_connect_regression_attempt(
uint16_t port,
std::atomic<int>& connected,
std::atomic<int>& failed,
std::atomic<int>& first_error) {
auto stream = co_await tcp_connect(ipv6_address("::1", port));
if (stream) {
connected.fetch_add(1, std::memory_order_relaxed);
} else {
failed.fetch_add(1, std::memory_order_relaxed);
int err = errno;
int expected = 0;
first_error.compare_exchange_strong(expected, err);
}
co_return;
}

TEST_CASE("ipv4_address basic operations", "[tcp][address][ipv4]") {
SECTION("default constructor") {
ipv4_address addr;
Expand Down Expand Up @@ -1199,6 +1216,37 @@ TEST_CASE("TCP IPv6 listener and connect", "[tcp][ipv6][integration]") {
}
}

TEST_CASE("TCP connect regression avoids double connect", "[tcp][connect][regression]") {
auto listener = tcp_listener::bind(ipv6_address("::1", 0));
REQUIRE(listener.has_value());

uint16_t port = listener->local_address().port();
REQUIRE(port > 0);

std::atomic<int> connected{0};
std::atomic<int> failed{0};
std::atomic<int> first_error{0};

scheduler sched(2);
sched.start();

constexpr int kAttempts = 64;
for (int i = 0; i < kAttempts; ++i) {
auto t = tcp_connect_regression_attempt(port, connected, failed, first_error);
sched.spawn(t.release());
}

for (int i = 0; i < 500 && (connected + failed) < kAttempts; ++i) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}

sched.shutdown();

INFO("connect failures=" << failed.load() << ", first errno=" << first_error.load());
REQUIRE(connected == kAttempts);
REQUIRE(failed == 0);
}

TEST_CASE("socket_address with hostname resolution", "[tcp][address][dns]") {
// Test that socket_address can be constructed from "localhost"
// This tests the DNS resolution path
Expand Down
Loading