Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# include <netdb.h>
# include <netinet/tcp.h>
# include <signal.h>
# include <sys/un.h>
# include <unistd.h>
#endif

Expand Down Expand Up @@ -193,6 +194,69 @@ struct SocketRAIIWrapper {
}
};

} // namespace

#if defined(_unix_)
namespace {
SOCKET SocketConnectUnix(const std::string& socket_path, const SocketTimeoutParams& timeout_params) {
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;

if (socket_path.size() >= sizeof(addr.sun_path)) {
throw std::system_error(EINVAL, std::system_category(), "UNIX socket path too long");
}

strncpy(addr.sun_path, socket_path.c_str(), sizeof(addr.sun_path) - 1);

SocketRAIIWrapper s{socket(AF_UNIX, SOCK_STREAM, 0)};

if (*s == INVALID_SOCKET) {
throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to create UNIX socket");
}

SetNonBlock(*s, true);
SetTimeout(*s, timeout_params);

if (connect(*s, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != 0) {
int err = getSocketErrorCode();
if (err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK) {
pollfd fd;
fd.fd = *s;
fd.events = POLLOUT;
fd.revents = 0;
ssize_t rval = Poll(&fd, 1, static_cast<int>(timeout_params.connect_timeout.count()));

if (rval == -1) {
throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to connect to UNIX socket");
}
if (rval == 0) {
throw std::system_error(ETIMEDOUT, getErrorCategory(), "timeout connecting to UNIX socket");
}
if (rval > 0) {
socklen_t len = sizeof(err);
getsockopt(*s, SOL_SOCKET, SO_ERROR, &err, &len);

if (err) {
throw std::system_error(err, getErrorCategory(), "fail to connect to UNIX socket");
}
SetNonBlock(*s, false);
return s.release();
}
// Should not reach here, but ensure we return a value
throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to connect to UNIX socket");
} else {
throw std::system_error(err, getErrorCategory(), "fail to connect to UNIX socket");
}
} else {
SetNonBlock(*s, false);
return s.release();
}
}
} // namespace
#endif

namespace {
SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) {
int last_err = 0;
for (auto res = addr.Info(); res != nullptr; res = res->ai_next) {
Expand Down Expand Up @@ -324,6 +388,10 @@ Socket::Socket(const NetworkAddress & addr)
: handle_(SocketConnect(addr, SocketTimeoutParams{}))
{}

Socket::Socket(SOCKET handle)
: handle_(handle)
{}

Socket::Socket(Socket&& other) noexcept
: handle_(other.handle_)
{
Expand Down Expand Up @@ -391,6 +459,15 @@ std::unique_ptr<OutputStream> Socket::makeOutputStream() const {
NonSecureSocketFactory::~NonSecureSocketFactory() {}

std::unique_ptr<SocketBase> NonSecureSocketFactory::connect(const ClientOptions &opts, const Endpoint& endpoint) {
#if defined(_unix_)
// Check if UNIX domain socket path is provided
if (!endpoint.socket_path.empty()) {
SocketTimeoutParams timeout_params { opts.connection_connect_timeout, opts.connection_recv_timeout, opts.connection_send_timeout };
SOCKET handle = SocketConnectUnix(endpoint.socket_path, timeout_params);
// Skip TCP-specific options for UNIX sockets
return std::make_unique<Socket>(handle);
}
#endif

const auto address = NetworkAddress(endpoint.host, std::to_string(endpoint.port));
auto socket = doConnect(address, opts);
Expand Down
1 change: 1 addition & 0 deletions clickhouse/base/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Socket : public SocketBase {
public:
Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params);
Socket(const NetworkAddress& addr);
explicit Socket(SOCKET handle);
Socket(Socket&& other) noexcept;
Socket& operator=(Socket&& other) noexcept;

Expand Down
16 changes: 12 additions & 4 deletions clickhouse/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct ClientInfo {
};

std::ostream& operator<<(std::ostream& os, const Endpoint& endpoint) {
if (!endpoint.socket_path.empty()) {
return os << "unix://" << endpoint.socket_path;
}
return os << endpoint.host << ":" << endpoint.port;
}

Expand All @@ -75,7 +78,7 @@ std::ostream& operator<<(std::ostream& os, const ClientOptions& opt) {

if (!opt.host.empty()) {
extra_endpoints = 1;
os << opt.user << '@' << Endpoint{opt.host, opt.port};
os << opt.user << '@' << Endpoint{opt.host, opt.port, opt.socket_path};

if (opt.endpoints.size())
os << ", ";
Expand Down Expand Up @@ -255,10 +258,11 @@ class Client::Impl {

ClientOptions modifyClientOptions(ClientOptions opts)
{
if (opts.host.empty())
if (opts.host.empty() && opts.socket_path.empty()) {
return opts;
}

Endpoint default_endpoint({opts.host, opts.port});
Endpoint default_endpoint({opts.host, opts.port, opts.socket_path});
opts.endpoints.emplace(opts.endpoints.begin(), default_endpoint);
return opts;
}
Expand Down Expand Up @@ -431,7 +435,11 @@ void Client::Impl::ResetConnection() {
InitializeStreams(socket_factory_->connect(options_, current_endpoint_.value()));

if (!Handshake()) {
throw ProtocolError("fail to connect to " + options_.host);
const auto& endpoint = current_endpoint_.value();
std::string connection_target = endpoint.socket_path.empty()
? (options_.host.empty() ? endpoint.host : options_.host)
: endpoint.socket_path;
throw ProtocolError("fail to connect to " + connection_target);
}
}

Expand Down
6 changes: 5 additions & 1 deletion clickhouse/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ enum class CompressionMethod : int8_t {
struct Endpoint {
std::string host;
uint16_t port = 9000;
/// UNIX domain socket path. If set, this takes precedence over host/port.
std::string socket_path = std::string();
inline bool operator==(const Endpoint& right) const {
return host == right.host && port == right.port;
return host == right.host && port == right.port && socket_path == right.socket_path;
}
};

Expand All @@ -72,6 +74,8 @@ struct ClientOptions {
DECLARE_FIELD(host, std::string, SetHost, std::string());
/// Service port.
DECLARE_FIELD(port, uint16_t, SetPort, 9000);
/// UNIX domain socket path. If set, this takes precedence over host/port.
DECLARE_FIELD(socket_path, std::string, SetSocketPath, std::string());

/** Set endpoints (host+port), only one is used.
* Client tries to connect to those endpoints one by one, on the round-robin basis:
Expand Down
2 changes: 1 addition & 1 deletion tests/simple/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ int main() {
.SetHost( getEnvOrDefault("CLICKHOUSE_HOST", "localhost"))
.SetPort( getEnvOrDefault<size_t>("CLICKHOUSE_PORT", "9000"))
.SetEndpoints({ {"asasdasd", 9000}
,{"localhost"}
,{"localhost", 9000}
,{"noalocalhost", 9000}
})
.SetUser( getEnvOrDefault("CLICKHOUSE_USER", "default"))
Expand Down
5 changes: 5 additions & 0 deletions ut/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ SET ( clickhouse-cpp-ut-src
low_cardinality_nullable_tests.cpp
)

# Add UNIX socket server for Unix-like systems
IF (NOT WIN32)
LIST (APPEND clickhouse-cpp-ut-src unix_socket_server.cpp)
ENDIF ()

IF (WITH_OPENSSL)
LIST (APPEND clickhouse-cpp-ut-src ssl_ut.cpp)
ENDIF ()
Expand Down
8 changes: 4 additions & 4 deletions ut/client_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,9 +1315,9 @@ INSTANTIATE_TEST_SUITE_P(ClientMultipleEndpointsWithDefaultPort, ConnectionSucce
::testing::Values(ClientCase::ParamType{
ClientOptions()
.SetEndpoints({
{"somedeadhost"}
{"somedeadhost", 9000}
, {"deadaginghost", 1245}
, {"localhost"}
, {"localhost", 9000}
, {"noalocalhost", 6784}
})
.SetUser( getEnvOrDefault("CLICKHOUSE_USER", "default"))
Expand Down Expand Up @@ -1401,10 +1401,10 @@ INSTANTIATE_TEST_SUITE_P(ResetConnectionClientTest, ResetConnectionTestCase,
::testing::Values(ResetConnectionTestCase::ParamType {
ClientOptions()
.SetEndpoints({
{"localhost", 9000}
{"localhost"}
,{"somedeadhost", 1245}
,{"noalocalhost", 6784}
,{"127.0.0.1", 9000}
,{"127.0.0.1"}
})
.SetUser( getEnvOrDefault("CLICKHOUSE_USER", "default"))
.SetPassword( getEnvOrDefault("CLICKHOUSE_PASSWORD", ""))
Expand Down
107 changes: 107 additions & 0 deletions ut/socket_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,110 @@ TEST(Socketcase, connecttimeout) {
// auto input = socket.makeInputStream();
// input->Read(buffer, sizeof(buffer));
//}

#if defined(_unix_)

#include "unix_socket_server.h"
#include <clickhouse/client.h>

TEST(Socketcase, UnixSocketConnect) {
const std::string socket_path = "/tmp/test_clickhouse_cpp_unix_socket.sock";
LocalUnixSocketServer server(socket_path);
server.start();

std::this_thread::sleep_for(std::chrono::milliseconds(100));

try {
// Test connection via NonSecureSocketFactory
ClientOptions opts;
opts.SetSocketPath(socket_path);
Endpoint endpoint;
endpoint.socket_path = socket_path;

NonSecureSocketFactory factory;
auto socket_base = factory.connect(opts, endpoint);
EXPECT_NE(nullptr, socket_base);
SUCCEED();
} catch (const std::system_error& e) {
FAIL() << "Failed to connect to UNIX socket: " << e.what();
}

std::this_thread::sleep_for(std::chrono::milliseconds(100));
server.stop();
}

TEST(Socketcase, UnixSocketConnectError) {
const std::string socket_path = "/tmp/test_clickhouse_cpp_unix_socket_nonexistent.sock";

try {
ClientOptions opts;
opts.SetSocketPath(socket_path);
opts.SetConnectionConnectTimeout(std::chrono::milliseconds(100));
Endpoint endpoint;
endpoint.socket_path = socket_path;

NonSecureSocketFactory factory;
auto socket_base = factory.connect(opts, endpoint);
FAIL() << "Should have failed to connect to non-existent UNIX socket";
} catch (const std::system_error& e) {
// Expected to fail
EXPECT_TRUE(e.code().value() == ECONNREFUSED || e.code().value() == ENOENT);
}
}

TEST(Socketcase, UnixSocketPathTooLong) {
const std::string long_path(200, 'a'); // Longer than UNIX_PATH_MAX (typically 108)

try {
ClientOptions opts;
opts.SetSocketPath(long_path);
Endpoint endpoint;
endpoint.socket_path = long_path;

NonSecureSocketFactory factory;
auto socket_base = factory.connect(opts, endpoint);
FAIL() << "Should have failed with path too long error";
} catch (const std::system_error& e) {
EXPECT_EQ(EINVAL, e.code().value());
}
}

TEST(ClientUnixSocket, UnixSocketEndpoint) {
// This test requires a real ClickHouse server listening on a UNIX socket
// For now, we'll just test that the endpoint structure works correctly
Endpoint endpoint;
endpoint.host = "localhost";
endpoint.port = 9000;
endpoint.socket_path = "/tmp/test.sock";

Endpoint endpoint2;
endpoint2.host = "localhost";
endpoint2.port = 9000;
endpoint2.socket_path = "/tmp/test.sock";

EXPECT_EQ(endpoint, endpoint2);

Endpoint endpoint3;
endpoint3.host = "localhost";
endpoint3.port = 9000;
endpoint3.socket_path = "/tmp/different.sock";

EXPECT_FALSE(endpoint == endpoint3);
}

TEST(ClientUnixSocket, UnixSocketClientOptions) {
ClientOptions opts;
opts.SetSocketPath("/tmp/test.sock");
EXPECT_EQ("/tmp/test.sock", opts.socket_path);

opts.SetHost("localhost");
opts.SetPort(9000);
opts.SetSocketPath("/tmp/test.sock");

// socket_path should take precedence
EXPECT_EQ("/tmp/test.sock", opts.socket_path);
EXPECT_EQ("localhost", opts.host);
EXPECT_EQ(9000u, opts.port);
}

#endif // defined(_unix_)
Loading