diff --git a/src/db_adapter/BUILD b/src/db_adapter/BUILD index 5ac293d5..b121a633 100644 --- a/src/db_adapter/BUILD +++ b/src/db_adapter/BUILD @@ -9,7 +9,8 @@ cc_library( ":context_loader", ":data_mapper", ":data_types", - ":db_wrapper", + ":database_connection", + ":database_wrapper", "//commons:commons_lib", "//commons/atoms:atoms_lib", "//db_adapter/postgres:postgres_lib", @@ -17,33 +18,36 @@ cc_library( ) cc_library( - name = "data_mapper", - srcs = ["DataMapper.cc"], - hdrs = ["DataMapper.h"], + name = "data_types", + hdrs = ["DataTypes.h"], includes = ["."], deps = [ - ":data_types", "//commons:commons_lib", "//commons/atoms:atoms_lib", ], ) cc_library( - name = "data_types", - hdrs = ["DataTypes.h"], + name = "data_mapper", + srcs = ["DataMapper.cc"], + hdrs = ["DataMapper.h"], includes = ["."], deps = [ + ":data_types", "//commons:commons_lib", "//commons/atoms:atoms_lib", ], ) cc_library( - name = "db_wrapper", - hdrs = ["DBWrapper.h"], + name = "database_wrapper", + srcs = ["DatabaseWrapper.cc"], + hdrs = ["DatabaseWrapper.h"], includes = ["."], deps = [ + ":data_mapper", ":data_types", + ":database_connection", "//commons:commons_lib", "//commons/atoms:atoms_lib", ], @@ -58,3 +62,16 @@ cc_library( "@nlohmann_json//:json", ], ) + +cc_library( + name = "database_connection", + srcs = ["DatabaseConnection.cc"], + hdrs = ["DatabaseConnection.h"], + includes = ["."], + deps = [ + ":data_types", + "//commons:commons_lib", + "//commons/atoms:atoms_lib", + "//commons/processor:processor_lib", + ], +) diff --git a/src/db_adapter/DatabaseConnection.cc b/src/db_adapter/DatabaseConnection.cc new file mode 100644 index 00000000..fc7f6c4d --- /dev/null +++ b/src/db_adapter/DatabaseConnection.cc @@ -0,0 +1,44 @@ +#include "DatabaseConnection.h" + +using namespace db_adapter; + +DatabaseConnection::DatabaseConnection(const string& id, const string& host, int port) : Processor(id) { + this->host = host; + this->port = port; + this->connected = false; + this->setup(); +} + +DatabaseConnection::~DatabaseConnection() {} + +void DatabaseConnection::setup() { + if (!this->is_setup()) { + Processor::setup(); + } +} + +void DatabaseConnection::start() { + if (this->is_running() || this->is_finished()) return; + + { + lock_guard lock(this->connection_mutex); + this->connect(); + this->connected = true; + } + + Processor::start(); +} + +void DatabaseConnection::stop() { + if (!this->is_running()) return; + + { + lock_guard lock(this->connection_mutex); + this->disconnect(); + this->connected = false; + } + + Processor::stop(); +} + +bool DatabaseConnection::is_connected() const { return this->connected; } diff --git a/src/db_adapter/DatabaseConnection.h b/src/db_adapter/DatabaseConnection.h new file mode 100644 index 00000000..adc23f7e --- /dev/null +++ b/src/db_adapter/DatabaseConnection.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include "Processor.h" + +using namespace std; +using namespace processor; + +namespace db_adapter { + +class DatabaseConnection : public Processor { + public: + DatabaseConnection(const string& id, const string& host, int port); + ~DatabaseConnection() override; + + virtual void setup() override; + virtual void start() override; + virtual void stop() override; + + virtual void connect() = 0; + virtual void disconnect() = 0; + + bool is_connected() const; + + protected: + string host; + int port; + + private: + bool connected; + mutex connection_mutex; +}; + +} // namespace db_adapter \ No newline at end of file diff --git a/src/db_adapter/DatabaseWrapper.cc b/src/db_adapter/DatabaseWrapper.cc new file mode 100644 index 00000000..aac94dfd --- /dev/null +++ b/src/db_adapter/DatabaseWrapper.cc @@ -0,0 +1,22 @@ +#include "DatabaseWrapper.h" + +namespace { +shared_ptr create_mapper(MAPPER_TYPE mapper_type) { + switch (mapper_type) { + case MAPPER_TYPE::SQL2METTA: + return make_shared(); + case MAPPER_TYPE::SQL2ATOMS: + return make_shared(); + default: + throw invalid_argument("Unknown mapper type"); + } +} +} // namespace + +SQLWrapper::SQLWrapper(MAPPER_TYPE mapper_type) + : DatabaseWrapper(create_mapper(mapper_type), mapper_type) {} + +DatabaseWrapper::DatabaseWrapper(shared_ptr mapper, MAPPER_TYPE mapper_type) + : mapper(move(mapper)), mapper_type(mapper_type) {} + +unsigned int DatabaseWrapper::mapper_handle_trie_size() { return this->mapper->handle_trie_size(); } diff --git a/src/db_adapter/DBWrapper.h b/src/db_adapter/DatabaseWrapper.h similarity index 55% rename from src/db_adapter/DBWrapper.h rename to src/db_adapter/DatabaseWrapper.h index f4e81297..f96b2d5c 100644 --- a/src/db_adapter/DBWrapper.h +++ b/src/db_adapter/DatabaseWrapper.h @@ -1,44 +1,34 @@ #pragma once #include +#include #include #include #include #include "DataMapper.h" +#include "DataTypes.h" +#include "DatabaseConnection.h" using namespace std; using namespace db_adapter; +using namespace commons; namespace db_adapter { /** * @class DatabaseWrapper * @brief Generic interface for a database connection wrapper. - * - * @tparam ConnT The underlying connection object type (e.g., pqxx::connection). */ -template class DatabaseWrapper { public: - explicit DatabaseWrapper(shared_ptr mapper, MAPPER_TYPE mapper_type) - : mapper(move(mapper)), mapper_type(mapper_type) {} + DatabaseWrapper(shared_ptr mapper, MAPPER_TYPE mapper_type); virtual ~DatabaseWrapper() = default; - /** - * @brief Closes the connection. - */ - virtual void disconnect() = 0; - - unsigned int mapper_handle_trie_size() { return mapper->handle_trie_size(); } + unsigned int mapper_handle_trie_size(); protected: - /** - * @brief Establishes connection to the database. - */ - virtual unique_ptr connect() = 0; - - unique_ptr db_client; + unique_ptr db_client; shared_ptr mapper; MAPPER_TYPE mapper_type; }; @@ -47,11 +37,9 @@ class DatabaseWrapper { * @class SQLWrapper * @brief Specialization of DatabaseWrapper for SQL-based databases. */ -template -class SQLWrapper : public DatabaseWrapper { +class SQLWrapper : public DatabaseWrapper { public: - explicit SQLWrapper(MAPPER_TYPE mapper_type) - : DatabaseWrapper(create_mapper(mapper_type), mapper_type) {} + explicit SQLWrapper(MAPPER_TYPE mapper_type); virtual ~SQLWrapper() = default; /** @@ -81,18 +69,5 @@ class SQLWrapper : public DatabaseWrapper { * @brief Executes a raw SQL query and maps the result. */ virtual void map_sql_query(const string& virtual_name, const string& raw_query) = 0; - - private: - // Factory method for creating the specific mapper strategy - static shared_ptr create_mapper(MAPPER_TYPE mapper_type) { - switch (mapper_type) { - case MAPPER_TYPE::SQL2METTA: - return make_shared(); - case MAPPER_TYPE::SQL2ATOMS: - return make_shared(); - default: - throw invalid_argument("Unknown mapper type"); - } - } }; } // namespace db_adapter \ No newline at end of file diff --git a/src/db_adapter/postgres/BUILD b/src/db_adapter/postgres/BUILD index 9e8450a2..70b02017 100644 --- a/src/db_adapter/postgres/BUILD +++ b/src/db_adapter/postgres/BUILD @@ -18,6 +18,7 @@ cc_library( deps = [ "//commons:commons_lib", "//db_adapter:data_mapper", - "//db_adapter:db_wrapper", + "//db_adapter:database_connection", + "//db_adapter:database_wrapper", ], ) diff --git a/src/db_adapter/postgres/PostgresWrapper.cc b/src/db_adapter/postgres/PostgresWrapper.cc index bfe1b703..f78383a2 100644 --- a/src/db_adapter/postgres/PostgresWrapper.cc +++ b/src/db_adapter/postgres/PostgresWrapper.cc @@ -13,46 +13,76 @@ using namespace std; -PostgresWrapper::PostgresWrapper(const string& host, - int port, - const string& database, - const string& user, - const string& password, - MAPPER_TYPE mapper_type) - : SQLWrapper(mapper_type), - host(host), - port(port), - database(database), - user(user), - password(password) { - this->db_client = this->connect(); +PostgresDatabaseConnection::PostgresDatabaseConnection(const string& id, + const string& host, + int port, + const string& database, + const string& user, + const string& password) + : DatabaseConnection(id, host, port), database(database), user(user), password(password) {} + +PostgresDatabaseConnection::~PostgresDatabaseConnection() { + if (this->is_running()) { + this->stop(); + } } -PostgresWrapper::~PostgresWrapper() { this->disconnect(); } - -unique_ptr PostgresWrapper::connect() { +void PostgresDatabaseConnection::connect() { try { - string conn_str = "host=" + host + " port=" + to_string(port) + " dbname=" + database; + string conn_str = "host=" + host + " port=" + std::to_string(port) + " dbname=" + database; if (!user.empty()) { conn_str += " user=" + user; } if (!password.empty()) { conn_str += " password=" + password; } - return make_unique(conn_str); - } catch (const pqxx::sql_error& e) { - throw runtime_error("Could not connect to database: " + string(e.what())); + this->conn = make_unique(conn_str); } catch (const exception& e) { throw runtime_error("Could not connect to database: " + string(e.what())); } } -void PostgresWrapper::disconnect() { - if (this->db_client) { - this->db_client->close(); +void PostgresDatabaseConnection::disconnect() { + if (this->conn) { + this->conn->close(); + this->conn.reset(); } } +pqxx::result PostgresDatabaseConnection::execute_query(const string& query) { + if (!this->conn || !this->conn->is_open()) { + Utils::error("Postgres connection is not open."); + } + + try { + pqxx::work transaction(*this->conn); + pqxx::result result = transaction.exec(query); + transaction.commit(); + return result; + } catch (const exception& e) { + Utils::error("Error during query execution: " + string(e.what())); + } + return pqxx::result{}; +} + +// =============================================================================================== +// PostgresWrapper implementation +// =============================================================================================== + +PostgresWrapper::PostgresWrapper(const string& host, + int port, + const string& database, + const string& user, + const string& password, + MAPPER_TYPE mapper_type) + : SQLWrapper(mapper_type) { + this->db_client = + make_unique("postgres-conn", host, port, database, user, password); + this->db_client->start(); +} + +PostgresWrapper::~PostgresWrapper() {} + Table PostgresWrapper::get_table(const string& name) { auto tables = this->list_tables(); for (const auto& table : tables) { @@ -135,7 +165,7 @@ vector PostgresWrapper::list_tables() { ORDER BY pg_total_relation_size(ti.table_name) ASC; )"; - auto result = this->execute_query(query); + auto result = pg_client().execute_query(query); vector
tables; tables.reserve(result.size()); @@ -263,8 +293,8 @@ vector PostgresWrapper::collect_fk_ids(const string& table_name, while (true) { string query = "SELECT " + column_name + " FROM " + table_name + " WHERE " + where_clause + - " LIMIT " + to_string(limit) + " OFFSET " + to_string(offset) + ";"; - pqxx::result rows = this->execute_query(query); + " LIMIT " + std::to_string(limit) + " OFFSET " + std::to_string(offset) + ";"; + pqxx::result rows = pg_client().execute_query(query); if (rows.empty()) break; @@ -367,8 +397,9 @@ void PostgresWrapper::fetch_rows_paginated(const Table& table, size_t limit = 10000; while (true) { - string paginated_query = query + " LIMIT " + to_string(limit) + " OFFSET " + to_string(offset); - pqxx::result rows = this->execute_query(paginated_query); + string paginated_query = + query + " LIMIT " + std::to_string(limit) + " OFFSET " + std::to_string(offset); + pqxx::result rows = pg_client().execute_query(paginated_query); LOG_DEBUG("Executing paginated query: " << paginated_query); LOG_DEBUG("Fetched " << rows.size() << " rows from table " << table.name); @@ -439,21 +470,3 @@ SqlRow PostgresWrapper::build_sql_row(const pqxx::row& row, const Table& table, } return sql_row; } - -pqxx::result PostgresWrapper::execute_query(const string& query) { - if (!this->db_client || !this->db_client->is_open()) { - Utils::error("Database connection is not open."); - } - - try { - pqxx::work transaction(*this->db_client); - pqxx::result result = transaction.exec(query); - transaction.commit(); - return result; - } catch (const pqxx::sql_error& e) { - Utils::error("SQL error during query execution: " + string(e.what())); - } catch (const exception& e) { - Utils::error("Error during query execution: " + string(e.what())); - } - return pqxx::result{}; -} \ No newline at end of file diff --git a/src/db_adapter/postgres/PostgresWrapper.h b/src/db_adapter/postgres/PostgresWrapper.h index 1f98deb4..2578c06f 100644 --- a/src/db_adapter/postgres/PostgresWrapper.h +++ b/src/db_adapter/postgres/PostgresWrapper.h @@ -9,7 +9,8 @@ #include #include -#include "DBWrapper.h" +#include "DatabaseConnection.h" +#include "DatabaseWrapper.h" #define MAX_VALUE_SIZE ((size_t) 1000) @@ -19,11 +20,32 @@ using namespace commons; namespace db_adapter { +class PostgresDatabaseConnection : public DatabaseConnection { + public: + PostgresDatabaseConnection(const string& id, + const string& host, + int port, + const string& database, + const string& user, + const string& password); + ~PostgresDatabaseConnection() override; + + void connect() override; + void disconnect() override; + pqxx::result execute_query(const string& query); + + protected: + unique_ptr conn; + string database; + string user; + string password; +}; + /** * @class PostgresWrapper * @brief Concrete implementation of SQLWrapper for PostgreSQL using libpqxx. */ -class PostgresWrapper : public SQLWrapper { +class PostgresWrapper : public SQLWrapper { public: /** * @brief Constructs a PostgresWrapper. @@ -44,7 +66,6 @@ class PostgresWrapper : public SQLWrapper { ~PostgresWrapper() override; - void disconnect() override; Table get_table(const string& name) override; vector
list_tables() override; void map_table(const Table& table, @@ -52,20 +73,15 @@ class PostgresWrapper : public SQLWrapper { const vector& skip_columns = {}, bool second_level = false) override; void map_sql_query(const string& virtual_name, const string& raw_query) override; - pqxx::result execute_query(const string& query); protected: - unique_ptr connect() override; // Regex for parsing alias patterns (e.g., "AS public_feature__uniquename") const string alias_pattern_regex = R"(\bAS\s+([a-zA-Z_][a-zA-Z0-9_]*)__([a-zA-Z_][a-zA-Z0-9_]*))"; + PostgresDatabaseConnection& pg_client() { + return static_cast(*db_client); + } private: - string host; - int port; - string database; - string user; - string password; - // Store tables in cache to avoid repeated database queries. optional> tables_cache; vector build_columns_to_map(const Table& table, const vector& skip_columns = {}); diff --git a/src/tests/cpp/db_adapter_test.cc b/src/tests/cpp/db_adapter_test.cc index 30b8531d..ada0a955 100644 --- a/src/tests/cpp/db_adapter_test.cc +++ b/src/tests/cpp/db_adapter_test.cc @@ -25,6 +25,48 @@ class PostgresWrapperTestEnvironment : public ::testing::Environment { void TearDown() override {} }; +class PostgresDatabaseConnectionTest : public ::testing::Test { + protected: + string TEST_HOST = "localhost"; + int TEST_PORT = 5433; + string TEST_DB = "postgres_wrapper_test"; + string TEST_USER = "postgres"; + string TEST_PASSWORD = "test"; + + string INVALID_HOST = "invalid.host"; + int INVALID_PORT = 99999; + string INVALID_DB = "database_xyz"; + + string FEATURE_TABLE = "public.feature"; + string ORGANISM_TABLE = "public.organism"; + string CVTERM_TABLE = "public.cvterm"; + string FEATURE_PK = "feature_id"; + string ORGANISM_PK = "organism_id"; + string CVTERM_PK = "cvterm_id"; + + int DROSOPHILA_ORGANISM_ID = 1; + int HUMAN_ORGANISM_ID = 2; + + int WHITE_GENE_ID = 1; + string WHITE_GENE_NAME = "white"; + string WHITE_GENE_UNIQUENAME = "FBgn0003996"; + + int TOTAL_ROWS_ORGANISMS = 5; + int TOTAL_ROWS_CVTERMS = 10; + int TOTAL_ROWS_FEATURES = 26; + + void SetUp() override {} + + void TearDown() override {} + + shared_ptr create_db_connection() { + auto conn = make_shared( + "test-conn", TEST_HOST, TEST_PORT, TEST_DB, TEST_USER, TEST_PASSWORD); + conn->start(); + return conn; + } +}; + class PostgresWrapperTest : public ::testing::Test { protected: string TEST_HOST = "localhost"; @@ -92,26 +134,110 @@ class PostgresWrapperTest : public ::testing::Test { } string temp_file_path; + + shared_ptr create_db_connection() { + auto conn = make_shared( + "test-conn", TEST_HOST, TEST_PORT, TEST_DB, TEST_USER, TEST_PASSWORD); + conn->start(); + return conn; + } }; -TEST_F(PostgresWrapperTest, Connection) { - auto wrapper = create_wrapper(); +TEST_F(PostgresDatabaseConnectionTest, Connection) { + auto conn = create_db_connection(); + + EXPECT_TRUE(conn->is_connected()); - auto result = wrapper->execute_query("SELECT 1"); + auto result = conn->execute_query("SELECT 1"); EXPECT_FALSE(result.empty()); EXPECT_EQ(result[0][0].as(), 1); - EXPECT_THROW({ PostgresWrapper("invalid.host", TEST_PORT, TEST_DB, TEST_USER, TEST_PASSWORD); }, - std::runtime_error); - EXPECT_THROW({ PostgresWrapper(TEST_HOST, 99999, TEST_DB, TEST_USER, TEST_PASSWORD); }, - std::runtime_error); - EXPECT_THROW({ PostgresWrapper(TEST_HOST, TEST_PORT, "non_existent_db", TEST_USER, TEST_PASSWORD); }, - std::runtime_error); + conn->stop(); - wrapper->disconnect(); + EXPECT_FALSE(conn->is_connected()); - EXPECT_THROW(wrapper->execute_query("SELECT 1"), std::runtime_error); + auto conn1 = new PostgresDatabaseConnection( + "test-conn1", INVALID_HOST, TEST_PORT, TEST_DB, TEST_USER, TEST_PASSWORD); + EXPECT_THROW(conn1->connect(), std::runtime_error); + + auto conn2 = new PostgresDatabaseConnection( + "test-conn2", TEST_HOST, INVALID_PORT, TEST_DB, TEST_USER, TEST_PASSWORD); + EXPECT_THROW(conn2->connect(), std::runtime_error); + + auto conn3 = new PostgresDatabaseConnection( + "test-conn3", TEST_HOST, TEST_PORT, INVALID_DB, TEST_USER, TEST_PASSWORD); + EXPECT_THROW(conn3->connect(), std::runtime_error); +} + +TEST_F(PostgresDatabaseConnectionTest, ConcurrentConnection) { + const int num_threads = 100; + vector threads; + atomic count_threads{0}; + + auto worker = [&](int thread_id) { + try { + string thread_id_str = "conn-" + to_string(thread_id); + auto conn = new PostgresDatabaseConnection( + thread_id_str, TEST_HOST, TEST_PORT, TEST_DB, TEST_USER, TEST_PASSWORD); + + EXPECT_FALSE(conn->is_connected()); + + conn->start(); + + EXPECT_TRUE(conn->is_connected()); + + conn->execute_query("SELECT 1"); + + count_threads++; + + conn->stop(); + + EXPECT_FALSE(conn->is_connected()); + } catch (const exception& e) { + cout << "Thread " << thread_id << " failed with error: " << e.what() << endl; + } + }; + + for (int i = 0; i < num_threads; ++i) threads.emplace_back(worker, i); + + for (auto& t : threads) t.join(); + + EXPECT_EQ(count_threads, num_threads); +} + +TEST_F(PostgresDatabaseConnectionTest, CheckData) { + auto conn = create_db_connection(); + + auto result = conn->execute_query( + "SELECT organism_id, genus, species, common_name FROM organism WHERE organism_id = 1"); + + ASSERT_EQ(result.size(), 1); + EXPECT_EQ(result[0]["organism_id"].as(), 1); + EXPECT_EQ(result[0]["genus"].as(), "Drosophila"); + EXPECT_EQ(result[0]["species"].as(), "melanogaster"); + EXPECT_EQ(result[0]["common_name"].as(), "fruit fly"); + + auto result2 = + conn->execute_query("SELECT feature_id, name, uniquename FROM feature WHERE feature_id = " + + to_string(WHITE_GENE_ID)); + + ASSERT_EQ(result2.size(), 1); + EXPECT_EQ(result2[0]["feature_id"].as(), WHITE_GENE_ID); + EXPECT_EQ(result2[0]["name"].as(), WHITE_GENE_NAME); + EXPECT_EQ(result2[0]["uniquename"].as(), WHITE_GENE_UNIQUENAME); + + auto result3 = conn->execute_query("SELECT COUNT(*) FROM organism"); + + EXPECT_EQ(result3[0][0].as(), TOTAL_ROWS_ORGANISMS); + + auto result4 = conn->execute_query("SELECT COUNT(*) FROM cvterm"); + + EXPECT_EQ(result4[0][0].as(), TOTAL_ROWS_CVTERMS); + + auto result5 = conn->execute_query("SELECT COUNT(*) FROM feature"); + + EXPECT_EQ(result5[0][0].as(), TOTAL_ROWS_FEATURES); } TEST_F(PostgresWrapperTest, GetTable) { @@ -206,40 +332,6 @@ TEST_F(PostgresWrapperTest, TablesStructure) { EXPECT_TRUE(has_type_fk); } -TEST_F(PostgresWrapperTest, CheckData) { - auto wrapper = create_wrapper(); - - auto result = wrapper->execute_query( - "SELECT organism_id, genus, species, common_name FROM organism WHERE organism_id = 1"); - - ASSERT_EQ(result.size(), 1); - EXPECT_EQ(result[0]["organism_id"].as(), 1); - EXPECT_EQ(result[0]["genus"].as(), "Drosophila"); - EXPECT_EQ(result[0]["species"].as(), "melanogaster"); - EXPECT_EQ(result[0]["common_name"].as(), "fruit fly"); - - auto result2 = - wrapper->execute_query("SELECT feature_id, name, uniquename FROM feature WHERE feature_id = " + - to_string(WHITE_GENE_ID)); - - ASSERT_EQ(result2.size(), 1); - EXPECT_EQ(result2[0]["feature_id"].as(), WHITE_GENE_ID); - EXPECT_EQ(result2[0]["name"].as(), WHITE_GENE_NAME); - EXPECT_EQ(result2[0]["uniquename"].as(), WHITE_GENE_UNIQUENAME); - - auto result3 = wrapper->execute_query("SELECT COUNT(*) FROM organism"); - - EXPECT_EQ(result3[0][0].as(), TOTAL_ROWS_ORGANISMS); - - auto result4 = wrapper->execute_query("SELECT COUNT(*) FROM cvterm"); - - EXPECT_EQ(result4[0][0].as(), TOTAL_ROWS_CVTERMS); - - auto result5 = wrapper->execute_query("SELECT COUNT(*) FROM feature"); - - EXPECT_EQ(result5[0][0].as(), TOTAL_ROWS_FEATURES); -} - // map_table - SQL2ATOMS TEST_F(PostgresWrapperTest, MapTablesFirstRowAtoms) { auto wrapper = create_wrapper();