diff --git a/include/mgclient.h b/include/mgclient.h index be540ad..20e67ae 100644 --- a/include/mgclient.h +++ b/include/mgclient.h @@ -1244,6 +1244,8 @@ MGCLIENT_EXPORT void mg_session_params_set_host(mg_session_params *, const char *host); MGCLIENT_EXPORT void mg_session_params_set_port(mg_session_params *, uint16_t port); +MGCLIENT_EXPORT void mg_session_params_set_scheme(mg_session_params *, + const char *scheme); MGCLIENT_EXPORT void mg_session_params_set_username(mg_session_params *, const char *username); MGCLIENT_EXPORT void mg_session_params_set_password(mg_session_params *, diff --git a/mgclient_cpp/include/mgclient.hpp b/mgclient_cpp/include/mgclient.hpp index 241fee6..4b0a79b 100644 --- a/mgclient_cpp/include/mgclient.hpp +++ b/mgclient_cpp/include/mgclient.hpp @@ -57,6 +57,7 @@ class Client { struct Params { std::string host = "127.0.0.1"; uint16_t port = 7687; + std::string scheme = "none"; std::string username = ""; std::string password = ""; bool use_ssl = false; @@ -148,13 +149,24 @@ inline std::unique_ptr Client::Connect(const Client::Params ¶ms) { if (!mg_params) { return nullptr; } - mg_session_params_set_host(mg_params, params.host.c_str()); - mg_session_params_set_port(mg_params, params.port); + if (!params.host.empty()) { + mg_session_params_set_host(mg_params, params.host.c_str()); + } + if (params.port != 0) { + mg_session_params_set_port(mg_params, params.port); + } + if (!params.scheme.empty()) { + mg_session_params_set_scheme(mg_params, params.scheme.c_str()); + } if (!params.username.empty()) { mg_session_params_set_username(mg_params, params.username.c_str()); + } + if (!params.password.empty()) { mg_session_params_set_password(mg_params, params.password.c_str()); } - mg_session_params_set_user_agent(mg_params, params.user_agent.c_str()); + if (!params.user_agent.empty()) { + mg_session_params_set_user_agent(mg_params, params.user_agent.c_str()); + } mg_session_params_set_sslmode( mg_params, params.use_ssl ? MG_SSLMODE_REQUIRE : MG_SSLMODE_DISABLE); diff --git a/src/mgclient.c b/src/mgclient.c index 58d8284..c346d91 100644 --- a/src/mgclient.c +++ b/src/mgclient.c @@ -67,6 +67,7 @@ typedef struct mg_session_params { const char *address; const char *host; uint16_t port; + const char *scheme; const char *username; const char *password; const char *user_agent; @@ -118,6 +119,11 @@ void mg_session_params_set_port(mg_session_params *params, uint16_t port) { params->port = port; } +void mg_session_params_set_scheme(mg_session_params *params, + const char *scheme) { + params->scheme = scheme; +} + void mg_session_params_set_username(mg_session_params *params, const char *username) { params->username = username; @@ -364,8 +370,8 @@ int mg_bolt_init_v1(mg_session *session, const mg_session_params *params) { return status; } -static mg_map *build_hello_extra(const char *user_agent, const char *username, - const char *password) { +static mg_map *build_hello_extra(const char *user_agent, const char *scheme, + const char *username, const char *password) { mg_map *extra = mg_map_make_empty(4); if (!extra) { return NULL; @@ -379,40 +385,53 @@ static mg_map *build_hello_extra(const char *user_agent, const char *username, } } - assert((username && password) || (!username && !password)); - if (username) { - mg_value *scheme = mg_value_make_string("basic"); - if (!scheme || mg_map_insert_unsafe(extra, "scheme", scheme) != 0) { + // The "basic" scheme requires a username and a password/credential within the + // HELLO message. Other schemes (save for "kerberos", which is not supported + // by Memgraph) do not have such requirements: + // https://neo4j.com/docs/bolt/current/bolt/message/#messages-hello + // https://neo4j.com/docs/bolt/current/bolt/message/#messages-logon + // NOTE: HELLO message does NOT contain schema after Bolt 5.0. + if (scheme && strcmp(scheme, "basic") == 0) { + assert(username && password); + } + + if (!username && !password) { + mg_value *scheme_ = mg_value_make_string("none"); + if (!scheme_ || mg_map_insert_unsafe(extra, "scheme", scheme_) != 0) { goto cleanup; } + return extra; + } + + mg_value *scheme_ = mg_value_make_string(scheme ? scheme : "none"); // NOTE: Makes none default. + if (!scheme_ || mg_map_insert_unsafe(extra, "scheme", scheme_) != 0) { + goto cleanup; + } + if (username) { mg_value *principal = mg_value_make_string(username); if (!principal || mg_map_insert_unsafe(extra, "principal", principal)) { goto cleanup; } + } + if (password) { mg_value *credentials = mg_value_make_string(password); if (!credentials || mg_map_insert_unsafe(extra, "credentials", credentials)) { goto cleanup; } - } else { - mg_value *scheme = mg_value_make_string("none"); - if (!scheme || mg_map_insert_unsafe(extra, "scheme", scheme) != 0) { - goto cleanup; - } } return extra; - cleanup: mg_map_destroy(extra); return NULL; } int mg_bolt_init_v4(mg_session *session, const mg_session_params *params) { - mg_map *extra = - build_hello_extra(params->user_agent, params->username, params->password); + mg_map *extra = build_hello_extra(params->user_agent, params->scheme, + params->username, params->password); if (!extra) { return MG_ERROR_OOM; } diff --git a/tests/client.cpp b/tests/client.cpp index 59c3839..4f48046 100644 --- a/tests/client.cpp +++ b/tests/client.cpp @@ -19,7 +19,6 @@ #include #include "mgclient.h" -#include "mgcommon.h" #include "mgsession.h" #include "mgsocket.h" @@ -508,76 +507,73 @@ TEST_F(ConnectTest, Success) { ASSERT_MEMORY_OK(); } -TEST_F(ConnectTest, Success_v4) { - RunServer([](int sockfd) { - // Perform handshake. - { - char handshake[20]; - ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); - ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); +auto run_v4_server_success = [](int sockfd) { + // Perform handshake. + { + char handshake[20]; + ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); + ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); - uint32_t version = htobe32(0x0104); - ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); - } + uint32_t version = htobe32(1); + ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); + } - mg_session *session = mg_session_init(&mg_system_allocator); - ASSERT_TRUE(session); - session->version = 4; - mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, - &mg_system_allocator); + mg_session *session = mg_session_init(&mg_system_allocator); + ASSERT_TRUE(session); + session->version = 1; + mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, + &mg_system_allocator); - // Read HELLO message. + // Read INIT message. + { + mg_message *message; + ASSERT_EQ(mg_session_receive_message(session), 0); + ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0); + ASSERT_EQ(message->type, MG_MESSAGE_TYPE_INIT); + + mg_message_init *msg_init = message->init_v; + EXPECT_EQ( + std::string(msg_init->client_name->data, msg_init->client_name->size), + MG_USER_AGENT); { - mg_message *message; - ASSERT_EQ(mg_session_receive_message(session), 0); - ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0); - ASSERT_EQ(message->type, MG_MESSAGE_TYPE_HELLO); - - mg_message_hello *msg_hello = message->hello_v; - { - ASSERT_EQ(mg_map_size(msg_hello->extra), 4u); - - const mg_value *user_agent_val = - mg_map_at(msg_hello->extra, "user_agent"); - ASSERT_TRUE(user_agent_val); - ASSERT_EQ(mg_value_get_type(user_agent_val), MG_VALUE_TYPE_STRING); - const mg_string *user_agent = mg_value_string(user_agent_val); - ASSERT_EQ(std::string(user_agent->data, user_agent->size), - MG_USER_AGENT); - - const mg_value *scheme_val = mg_map_at(msg_hello->extra, "scheme"); - ASSERT_TRUE(scheme_val); - ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING); - const mg_string *scheme = mg_value_string(scheme_val); - ASSERT_EQ(std::string(scheme->data, scheme->size), "basic"); - - const mg_value *principal_val = - mg_map_at(msg_hello->extra, "principal"); - ASSERT_TRUE(principal_val); - ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING); - const mg_string *principal = mg_value_string(principal_val); - ASSERT_EQ(std::string(principal->data, principal->size), "user"); + ASSERT_EQ(mg_map_size(msg_init->auth_token), 3u); + + const mg_value *scheme_val = mg_map_at(msg_init->auth_token, "scheme"); + ASSERT_TRUE(scheme_val); + ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING); + const mg_string *scheme = mg_value_string(scheme_val); + ASSERT_EQ(std::string(scheme->data, scheme->size), "basic"); + + const mg_value *principal_val = + mg_map_at(msg_init->auth_token, "principal"); + ASSERT_TRUE(principal_val); + ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING); + const mg_string *principal = mg_value_string(principal_val); + ASSERT_EQ(std::string(principal->data, principal->size), "user"); + + const mg_value *credentials_val = + mg_map_at(msg_init->auth_token, "credentials"); + ASSERT_TRUE(credentials_val); + ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING); + const mg_string *credentials = mg_value_string(credentials_val); + ASSERT_EQ(std::string(credentials->data, credentials->size), "pass"); + } + + mg_message_destroy_ca(message, session->decoder_allocator); + } - const mg_value *credentials_val = - mg_map_at(msg_hello->extra, "credentials"); - ASSERT_TRUE(credentials_val); - ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING); - const mg_string *credentials = mg_value_string(credentials_val); - ASSERT_EQ(std::string(credentials->data, credentials->size), "pass"); - } + // Send SUCCESS message. + ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0); - mg_message_destroy_ca(message, session->decoder_allocator); - } - - // Send SUCCESS message. - ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0); + mg_session_destroy(session); +}; - mg_session_destroy(session); - }); +TEST_F(ConnectTest, Success_v4) { + RunServer(run_v4_server_success); mg_session_params *params = mg_session_params_make(); mg_session_params_set_host(params, "127.0.0.1"); mg_session_params_set_port(params, port); @@ -592,70 +588,7 @@ TEST_F(ConnectTest, Success_v4) { } TEST_F(ConnectTest, SuccessWithSSL) { - RunServer([](int sockfd) { - // Perform handshake. - { - char handshake[20]; - ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); - ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); - - uint32_t version = htobe32(1); - ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); - } - - mg_session *session = mg_session_init(&mg_system_allocator); - ASSERT_TRUE(session); - session->version = 1; - mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, - &mg_system_allocator); - - // Read INIT message. - { - mg_message *message; - ASSERT_EQ(mg_session_receive_message(session), 0); - ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0); - ASSERT_EQ(message->type, MG_MESSAGE_TYPE_INIT); - - mg_message_init *msg_init = message->init_v; - EXPECT_EQ( - std::string(msg_init->client_name->data, msg_init->client_name->size), - MG_USER_AGENT); - { - ASSERT_EQ(mg_map_size(msg_init->auth_token), 3u); - - const mg_value *scheme_val = mg_map_at(msg_init->auth_token, "scheme"); - ASSERT_TRUE(scheme_val); - ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING); - const mg_string *scheme = mg_value_string(scheme_val); - ASSERT_EQ(std::string(scheme->data, scheme->size), "basic"); - - const mg_value *principal_val = - mg_map_at(msg_init->auth_token, "principal"); - ASSERT_TRUE(principal_val); - ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING); - const mg_string *principal = mg_value_string(principal_val); - ASSERT_EQ(std::string(principal->data, principal->size), "user"); - - const mg_value *credentials_val = - mg_map_at(msg_init->auth_token, "credentials"); - ASSERT_TRUE(credentials_val); - ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING); - const mg_string *credentials = mg_value_string(credentials_val); - ASSERT_EQ(std::string(credentials->data, credentials->size), "pass"); - } - - mg_message_destroy_ca(message, session->decoder_allocator); - } - - // Send SUCCESS message. - ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0); - - mg_session_destroy(session); - }); + RunServer(run_v4_server_success); mg_secure_transport_init_called = 0; trust_callback_ok = 0; @@ -681,6 +614,22 @@ TEST_F(ConnectTest, SuccessWithSSL) { ASSERT_MEMORY_OK(); } +TEST_F(ConnectTest, CustomScheme) { + RunServer(run_v4_server_success); + mg_session_params *params = mg_session_params_make(); + mg_session_params_set_host(params, "127.0.0.1"); + mg_session_params_set_port(params, port); + mg_session_params_set_scheme(params, "custom_scheme"); + mg_session_params_set_username(params, "user"); + mg_session_params_set_password(params, "pass"); + mg_session *session; + ASSERT_EQ(mg_connect_ca(params, &session, (mg_allocator *)&allocator), 0); + EXPECT_EQ(mg_session_status(session), MG_SESSION_READY); + mg_session_params_destroy(params); + mg_session_destroy(session); + ASSERT_MEMORY_OK(); +} + class RunTest : public ::testing::Test { protected: virtual void SetUp() override { diff --git a/tests/integration/basic_cpp.cpp b/tests/integration/basic_cpp.cpp index b5ff5a8..adc9264 100644 --- a/tests/integration/basic_cpp.cpp +++ b/tests/integration/basic_cpp.cpp @@ -35,7 +35,7 @@ class MemgraphConnection : public ::testing::Test { client = mg::Client::Connect( {GetEnvOrDefault("MEMGRAPH_HOST", "127.0.0.1"), - GetEnvOrDefault("MEMGRAPH_PORT", 7687), "", "", + GetEnvOrDefault("MEMGRAPH_PORT", 7687), "basic", "", "", GetEnvOrDefault("MEMGRAPH_SSLMODE", false), ""}); ASSERT_TRUE(client);