From 28931cd00d5014c006bc904a8b205b4aff4e066a Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 23 Jul 2025 11:57:46 +0500 Subject: [PATCH] * Replaced malloc with std::vector for safer memory management. * Marked functions ref_count_client and ref_count_server as noexcept since standard exceptions are not being handled and application crashes are acceptable in such cases. --- include/PgSQL_PreparedStatement.h | 4 +- lib/PgSQL_PreparedStatement.cpp | 95 +++++++++++++++---------------- 2 files changed, 48 insertions(+), 51 deletions(-) diff --git a/include/PgSQL_PreparedStatement.h b/include/PgSQL_PreparedStatement.h index 7e167933f..d1f60890f 100644 --- a/include/PgSQL_PreparedStatement.h +++ b/include/PgSQL_PreparedStatement.h @@ -93,8 +93,8 @@ public: inline void rdlock() { pthread_rwlock_rdlock(&rwlock_); } inline void wrlock() { pthread_rwlock_wrlock(&rwlock_); } inline void unlock() { pthread_rwlock_unlock(&rwlock_); } - void ref_count_client(uint64_t _stmt, int _v, bool lock=true); - void ref_count_server(uint64_t _stmt, int _v, bool lock=true); + void ref_count_client(uint64_t _stmt, int _v, bool lock=true) noexcept; + void ref_count_server(uint64_t _stmt, int _v, bool lock=true) noexcept; PgSQL_STMT_Global_info* add_prepared_statement(char *user, char *database, char *query, unsigned int query_len, char *fc, Parse_Param_Types&& ppt, bool lock=true); void get_metrics(uint64_t *c_unique, uint64_t *c_total, uint64_t *stmt_max_stmt_id, uint64_t *cached, diff --git a/lib/PgSQL_PreparedStatement.cpp b/lib/PgSQL_PreparedStatement.cpp index cd688a7a6..460b2ea4e 100644 --- a/lib/PgSQL_PreparedStatement.cpp +++ b/lib/PgSQL_PreparedStatement.cpp @@ -40,7 +40,8 @@ static uint64_t stmt_compute_hash(const char *user, l += (param_types.size() * sizeof(uint32_t)); // add length for parameter types } - auto buf = (char *)malloc(l); + std::vector storage(l); + char* buf = storage.data(); l = 0; memcpy(buf + l, user, user_length); l += user_length; // write user memcpy(buf + l, DELIM1, delim1_length); l += delim1_length; // write delimiter1 @@ -54,7 +55,6 @@ static uint64_t stmt_compute_hash(const char *user, memcpy(buf + l, param_types.data(), size * sizeof(uint32_t)); l += (size * sizeof(uint32_t)); // write each parameter type } uint64_t hash = SpookyHash::Hash64(buf, l, 0); - free(buf); return hash; } @@ -297,7 +297,7 @@ PgSQL_STMT_Manager_v14::~PgSQL_STMT_Manager_v14() { } } -void PgSQL_STMT_Manager_v14::ref_count_client(uint64_t _stmt_id ,int _v, bool lock) { +void PgSQL_STMT_Manager_v14::ref_count_client(uint64_t _stmt_id ,int _v, bool lock) noexcept { if (lock) pthread_rwlock_wrlock(&rwlock_); @@ -312,60 +312,57 @@ void PgSQL_STMT_Manager_v14::ref_count_client(uint64_t _stmt_id ,int _v, bool lo } } stmt_info->ref_count_client += _v; - time_t ct = time(NULL); - uint64_t num_client_count_zero = __sync_add_and_fetch(&num_stmt_with_ref_client_count_zero, 0); - uint64_t num_server_count_zero = __sync_add_and_fetch(&num_stmt_with_ref_server_count_zero, 0); - - size_t map_size = map_stmt_id_to_info.size(); - if ( - (ct > last_purge_time+1) && - (map_size > (unsigned)mysql_thread___max_stmts_cache ) && - (num_client_count_zero > map_size/10) && - (num_server_count_zero > map_size/10) - ) { // purge only if there is at least 10% gain - last_purge_time = ct; - int max_purge = map_size ; - int i = -1; - uint64_t *torem = - (uint64_t *)malloc(max_purge * sizeof(uint64_t)); - for (auto it = map_stmt_id_to_info.begin(); it != map_stmt_id_to_info.end(); ++it) { - if ( (i == (max_purge - 1)) || (i == ((int)num_client_count_zero - 1)) ) { - break; // nothing left to clean up - } - PgSQL_STMT_Global_info *a = it->second; - if ((__sync_add_and_fetch(&a->ref_count_client, 0) == 0) && - (a->ref_count_server == 0) ) // this to avoid that IDs are incorrectly reused - { - uint64_t hash = a->hash; - if (auto s2 = map_stmt_hash_to_info.find(hash); s2 != map_stmt_hash_to_info.end()) { - map_stmt_hash_to_info.erase(s2); - } - __sync_sub_and_fetch(&num_stmt_with_ref_client_count_zero,1); - i++; - torem[i] = it->first; - } + time_t ct = time(NULL); + uint64_t num_client_count_zero = __sync_add_and_fetch(&num_stmt_with_ref_client_count_zero, 0); + uint64_t num_server_count_zero = __sync_add_and_fetch(&num_stmt_with_ref_server_count_zero, 0); + + size_t map_size = map_stmt_id_to_info.size(); + if ( + (ct > last_purge_time+1) && + (map_size > (unsigned)pgsql_thread___max_stmts_cache) && + (num_client_count_zero > map_size/10) && + (num_server_count_zero > map_size/10) + ) { // purge only if there is at least 10% gain + last_purge_time = ct; + int max_purge = map_size ; + std::vector torem; + torem.reserve(max_purge); + + for (auto it = map_stmt_id_to_info.begin(); it != map_stmt_id_to_info.end(); ++it) { + if (torem.size() >= std::min(static_cast(max_purge), + static_cast(num_client_count_zero))) { + break; } - while (i >= 0) { - uint64_t id = torem[i]; - auto s3 = map_stmt_id_to_info.find(id); - PgSQL_STMT_Global_info *a = s3->second; - if (a->ref_count_server == 0) { - __sync_sub_and_fetch(&num_stmt_with_ref_server_count_zero,1); - free_stmt_ids.push(id); - } - map_stmt_id_to_info.erase(s3); - statuses.s_total -= a->ref_count_server; - delete a; - i--; + PgSQL_STMT_Global_info *a = it->second; + if ((__sync_add_and_fetch(&a->ref_count_client, 0) == 0) && + (a->ref_count_server == 0) ) // this to avoid that IDs are incorrectly reused + { + uint64_t hash = a->hash; + map_stmt_hash_to_info.erase(hash); + __sync_sub_and_fetch(&num_stmt_with_ref_client_count_zero,1); + torem.emplace_back(it->first); + } + } + while (!torem.empty()) { + uint64_t id = torem.back(); + torem.pop_back(); + auto s3 = map_stmt_id_to_info.find(id); + PgSQL_STMT_Global_info *a = s3->second; + if (a->ref_count_server == 0) { + __sync_sub_and_fetch(&num_stmt_with_ref_server_count_zero,1); + free_stmt_ids.push(id); } - free(torem); + map_stmt_id_to_info.erase(s3); + statuses.s_total -= a->ref_count_server; + delete a; } + } } if (lock) pthread_rwlock_unlock(&rwlock_); } -void PgSQL_STMT_Manager_v14::ref_count_server(uint64_t _stmt_id ,int _v, bool lock) { +void PgSQL_STMT_Manager_v14::ref_count_server(uint64_t _stmt_id ,int _v, bool lock) noexcept { if (lock) pthread_rwlock_wrlock(&rwlock_); std::map::iterator s;