From ae1f3126a2528a914084aa13939b885366e46498 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 13:26:10 +0500 Subject: [PATCH 01/16] Added explicit transaction state manager --- include/PgSQL_ExplicitTxnStateMgr.h | 112 +++++++++ lib/PgSQL_ExplicitTxnStateMgr.cpp | 377 ++++++++++++++++++++++++++++ 2 files changed, 489 insertions(+) create mode 100644 include/PgSQL_ExplicitTxnStateMgr.h create mode 100644 lib/PgSQL_ExplicitTxnStateMgr.cpp diff --git a/include/PgSQL_ExplicitTxnStateMgr.h b/include/PgSQL_ExplicitTxnStateMgr.h new file mode 100644 index 000000000..1d91a4db0 --- /dev/null +++ b/include/PgSQL_ExplicitTxnStateMgr.h @@ -0,0 +1,112 @@ +#ifndef PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H +#define PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H +#include +#include +#include +#include "proxysql.h" +#include "cpp.h" +#include "PgSQL_Connection.h" + +#ifndef PROXYJSON +#define PROXYJSON +#include "../deps/json/json_fwd.hpp" +#endif // PROXYJSON + +/** + * @struct PgSQL_Variable_Snapshot + * @brief Represents a snapshot of PostgreSQL variables during a transaction. + * + * This structure is used to store the state of PostgreSQL variables, including + * their values and hash representations, at a specific point in time during a transaction. + */ +typedef struct PgSQL_Variable_Snapshot { + char* var_value[PGSQL_NAME_LAST_HIGH_WM] = {}; + uint32_t var_hash[PGSQL_NAME_LAST_HIGH_WM] = {}; +} PgSQL_Variable_Snapshot_t; + +/** + * @struct TxnCmd + * @brief Represents a transaction command type begin executed and its associated metadata. + * + */ +struct TxnCmd { + /** + * @enum Type + * @brief Enumerates the types of transaction commands. + */ + enum Type { + UNKNOWN = -1, + BEGIN, + COMMIT, + ROLLBACK, + SAVEPOINT, + RELEASE, + ROLLBACK_TO + } type = Type::UNKNOWN; + std::string savepoint; //< The name of the savepoint, if applicable. +}; + +/** + * @class PgSQL_TxnCmdParser + * @brief Parses transaction-related commands for PostgreSQL. + * + * This class is responsible for tokenizing and interpreting transaction-related + * commands such as BEGIN, COMMIT, ROLLBACK, SAVEPOINT, etc. + */ +class PgSQL_TxnCmdParser { +public: + TxnCmd parse(std::string_view input, bool in_transaction_mode) noexcept; + +private: + std::vector tokens; + + TxnCmd parse_rollback(size_t& pos) noexcept; + TxnCmd parse_savepoint(size_t& pos) noexcept; + TxnCmd parse_release(size_t& pos) noexcept; + + // Helpers + static std::string to_lower(std::string_view s) noexcept { + std::string s_copy(s); + std::transform(s_copy.begin(), s_copy.end(), s_copy.begin(), ::tolower); + return s_copy; + } + + inline static bool contains(std::vector&& list, std::string_view value) noexcept { + for (const auto& item : list) if (item == value) return true; + return false; + } +}; + +/** + * @class PgSQL_ExplicitTxnStateMgr + * @brief Manages the state of explicit transactions in PostgreSQL. + * + * This class is responsible for handling explicit transaction commands such as + * BEGIN, COMMIT, ROLLBACK, SAVEPOINT, and managing the associated state. + */ +class PgSQL_ExplicitTxnStateMgr { +public: + PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess); + ~PgSQL_ExplicitTxnStateMgr(); + + bool handle_transaction(std::string_view input); + int get_savepoint_count() const { return savepoint.size(); } + void fill_internal_session(nlohmann::json& j); + +private: + PgSQL_Session* session; + std::vector transaction_state; + std::vector savepoint; + PgSQL_TxnCmdParser tx_parser; + + void start_transaction(); + void commit(); + void rollback(); + bool add_savepoint(std::string_view name); + bool rollback_to_savepoint(std::string_view name); + bool release_savepoint(std::string_view name); + + static void reset_variable_snapshot(PgSQL_Variable_Snapshot_t& var_snapshot) noexcept; +}; + +#endif // PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H \ No newline at end of file diff --git a/lib/PgSQL_ExplicitTxnStateMgr.cpp b/lib/PgSQL_ExplicitTxnStateMgr.cpp new file mode 100644 index 000000000..f29cb1cc9 --- /dev/null +++ b/lib/PgSQL_ExplicitTxnStateMgr.cpp @@ -0,0 +1,377 @@ +#include "PgSQL_ExplicitTxnStateMgr.h" +#include "proxysql.h" +#include "PgSQL_Session.h" +#include "PgSQL_Data_Stream.h" +#include "PgSQL_Connection.h" + +extern class PgSQL_Variables pgsql_variables; + +PgSQL_ExplicitTxnStateMgr::PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess) : session(sess) { + +} + +PgSQL_ExplicitTxnStateMgr::~PgSQL_ExplicitTxnStateMgr() { + for (auto& tran_state : transaction_state) { + reset_variable_snapshot(tran_state); + } + transaction_state.clear(); + savepoint.clear(); +} + +void PgSQL_ExplicitTxnStateMgr::reset_variable_snapshot(PgSQL_Variable_Snapshot_t& var_snapshot) noexcept { + for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { + if (var_snapshot.var_value[idx]) { + free(var_snapshot.var_value[idx]); + var_snapshot.var_value[idx] = nullptr; + } + var_snapshot.var_hash[idx] = 0; + } +} + +void verify_server_variables(PgSQL_Session* session) { +#ifdef DEBUG + for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) { + const char* conn_param_status = session->mybe->server_myds->myconn->get_pg_parameter_status(pgsql_tracked_variables[idx].set_variable_name); + const char* param_value = session->mybe->server_myds->myconn->variables[idx].value; + if (conn_param_status && param_value) { + assert(strcmp(conn_param_status, param_value) == 0); + } + } +#endif +} + +void PgSQL_ExplicitTxnStateMgr::start_transaction() { + if (transaction_state.empty() == false) { + // Transaction already started, do nothing and return + proxy_warning("Received BEGIN command. There is already a transaction in progress\n"); + assert(session->NumActiveTransactions() > 0); + return; + } + + assert(session->client_myds && session->client_myds->myconn); + + PgSQL_Variable_Snapshot_t var_snapshot{}; + + // check if already in transaction, if yes then do nothing + for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { + + uint32_t hash = pgsql_variables.client_get_hash(session, idx); + if (hash != 0) { + var_snapshot.var_hash[idx] = hash; + var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx)); + } else { + assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null + // no need to store the value + //var_snapshot.var_hash[idx] = 0; + //var_snapshot.var_value[idx] = NULL; + } + } + transaction_state.emplace_back(std::move(var_snapshot)); +} + +void PgSQL_ExplicitTxnStateMgr::commit() { + if (transaction_state.empty()) { + proxy_warning("Received COMMIT command. There is no transaction in progress\n"); + assert(session->NumActiveTransactions() == 0); + return; + } + + assert(session->client_myds && session->client_myds->myconn); + + for (auto& tran_state : transaction_state) { + reset_variable_snapshot(tran_state); + } + transaction_state.clear(); + savepoint.clear(); + verify_server_variables(session); +} + +void PgSQL_ExplicitTxnStateMgr::rollback() { + + if (transaction_state.empty()) { + proxy_warning("Received ROLLBACK command. There is no transaction in progress\n"); + assert(session->NumActiveTransactions() == 0); + return; + } + + assert(session->client_myds && session->client_myds->myconn); + + const PgSQL_Variable_Snapshot_t& var_snapshot = transaction_state.front(); + + for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { + uint32_t hash = var_snapshot.var_hash[idx]; + if (hash != 0) { + uint32_t client_hash = pgsql_variables.client_get_hash(session, idx); + uint32_t server_hash = pgsql_variables.server_get_hash(session, idx); + + assert(client_hash == server_hash); + if (hash == client_hash) + continue; + + pgsql_variables.client_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash); + pgsql_variables.server_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash); + } else { + assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null + pgsql_variables.client_reset_value(session, idx, false); + pgsql_variables.server_reset_value(session, idx, false); + } + } + // reuse of connection that has extra param set in connection + session->client_myds->myconn->reorder_dynamic_variables_idx(); + if (session->mybe) { + session->mybe->server_myds->myconn->reorder_dynamic_variables_idx(); + + verify_server_variables(session); + } + + // Clear savepoints and reset the initial snapshot + for (auto& tran_state : transaction_state) { + reset_variable_snapshot(tran_state); + } + transaction_state.clear(); + savepoint.clear(); +} + +bool PgSQL_ExplicitTxnStateMgr::rollback_to_savepoint(std::string_view name) { + + if (transaction_state.empty()) { + proxy_warning("Received ROLLBACK TO SAVEPOINT '%s' command. There is no transaction in progress\n", name.data()); + assert(session->NumActiveTransactions() == 0); + return false; + } + + assert(session->client_myds && session->client_myds->myconn); + + int tran_state_idx = -1; + + for (size_t idx = 0; idx < savepoint.size(); idx++) { + if (savepoint[idx].size() == name.size() && + strncasecmp(savepoint[idx].c_str(), name.data(), name.size()) == 0) { + tran_state_idx = idx; + break; + } + } + + if (tran_state_idx == -1) { + proxy_warning("Savepoint '%s' not found.\n", name.data()); + return false; + }; + + assert(tran_state_idx + 1 < (int)transaction_state.size()); + + PgSQL_Variable_Snapshot_t& var_snapshot = transaction_state[tran_state_idx+1]; + for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { + uint32_t hash = var_snapshot.var_hash[idx]; + if (hash != 0) { + uint32_t client_hash = pgsql_variables.client_get_hash(session, idx); + uint32_t server_hash = pgsql_variables.server_get_hash(session, idx); + assert(client_hash == server_hash); + if (hash == client_hash) + continue; + pgsql_variables.client_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash); + pgsql_variables.server_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash); + } + else { + assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null + pgsql_variables.client_reset_value(session, idx, false); + pgsql_variables.server_reset_value(session, idx, false); + } + } + for (size_t idx = tran_state_idx + 1; idx < transaction_state.size(); idx++) { + reset_variable_snapshot(transaction_state[idx]); + } + transaction_state.resize(tran_state_idx + 1); + savepoint.resize(tran_state_idx); + + return true; +} + +bool PgSQL_ExplicitTxnStateMgr::release_savepoint(std::string_view name) { + + if (transaction_state.empty()) { + proxy_warning("Received RELEASE SAVEPOINT '%s' command. There is no transaction in progress\n", name.data()); + assert(session->NumActiveTransactions() == 0); + return false; + } + + assert(session->client_myds && session->client_myds->myconn); + + int tran_state_idx = -1; + + for (size_t idx = 0; idx < savepoint.size(); idx++) { + if (savepoint[idx].size() == name.size() && + strncasecmp(savepoint[idx].c_str(), name.data(), name.size()) == 0) { + tran_state_idx = idx; + break; + } + } + + if (tran_state_idx == -1) { + proxy_warning("SAVEPOINT '%s' not found.\n", name.data()); + return false; + }; + + for (size_t idx = tran_state_idx + 1; idx < transaction_state.size(); idx++) { + reset_variable_snapshot(transaction_state[idx]); + } + transaction_state.resize(tran_state_idx + 1); + savepoint.resize(tran_state_idx); + + return true; +} + +bool PgSQL_ExplicitTxnStateMgr::add_savepoint(std::string_view name) { + + if (transaction_state.empty()) { + proxy_warning("Received SAVEPOINT '%s' command. There is no transaction in progress\n", name.data()); + assert(session->NumActiveTransactions() == 0); + return false; + } + + assert(session->client_myds && session->client_myds->myconn); + + auto it = std::find_if(savepoint.begin(), savepoint.end(), [name](std::string_view sp) { + return sp.size() == name.size() && + strncasecmp(sp.data(), name.data(), name.size()) == 0; + }); + if (it != savepoint.end()) return false; + + PgSQL_Variable_Snapshot_t var_snapshot{}; + + for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { + uint32_t hash = pgsql_variables.client_get_hash(session, idx); + if (hash != 0) { + var_snapshot.var_hash[idx] = hash; + var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx)); + } + } + transaction_state.emplace_back(std::move(var_snapshot)); + savepoint.emplace_back(name); + assert((transaction_state.size() - 1) == savepoint.size()); + + return true; +} + +void PgSQL_ExplicitTxnStateMgr::fill_internal_session(nlohmann::json& j) { + if (transaction_state.empty()) return; + + auto& initial_state = j["initial_state"]; + + for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { + uint32_t hash = transaction_state[0].var_hash[idx]; + if (hash != 0) { + initial_state[pgsql_tracked_variables[idx].set_variable_name] = transaction_state[0].var_value[idx]; + } + } + + if (savepoint.empty()) return; + + for (size_t idx = 0; idx < savepoint.size(); idx++) { + auto& savepoint_json = j["savepoints"][savepoint[idx]]; + int tran_state_idx = idx + 1; + for (int idx2 = 0; idx2 < PGSQL_NAME_LAST_HIGH_WM; idx2++) { + uint32_t hash = transaction_state[tran_state_idx].var_hash[idx2]; + if (hash != 0) { + savepoint_json[pgsql_tracked_variables[idx2].set_variable_name] = transaction_state[tran_state_idx].var_value[idx2]; + } + } + } +} + +bool PgSQL_ExplicitTxnStateMgr::handle_transaction(std::string_view input) { + TxnCmd cmd = tx_parser.parse(input, (session->active_transactions > 0)); + switch (cmd.type) { + case TxnCmd::BEGIN: + start_transaction(); + break; + case TxnCmd::COMMIT: + commit(); + break; + case TxnCmd::ROLLBACK: + rollback(); + break; + case TxnCmd::SAVEPOINT: + return add_savepoint(cmd.savepoint); + case TxnCmd::RELEASE: + return release_savepoint(cmd.savepoint); + case TxnCmd::ROLLBACK_TO: + return rollback_to_savepoint(cmd.savepoint); + default: + // Unknown command + return false; + } + return true; +} + + +TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mode) noexcept { + tokens.clear(); + TxnCmd cmd; + bool in_quote = false; + size_t start = 0; + char quote_char = 0; + + // Tokenize with quote handling + for (size_t i = 0; i <= input.size(); ++i) { + const bool at_end = i == input.size(); + const char c = at_end ? 0 : input[i]; + + if (in_quote) { + if (c == quote_char || at_end) { + tokens.emplace_back(input.substr(start + 1, i - start - 1)); + in_quote = false; + } + continue; + } + + if (c == '"' || c == '\'') { + in_quote = true; + quote_char = c; + start = i; + } + else if (isspace(c) || c == ';' || at_end) { + if (start < i) tokens.emplace_back(input.substr(start, i - start)); + start = i + 1; + } + } + + if (tokens.empty()) return cmd; + + size_t pos = 0; + const std::string first = to_lower(tokens[pos++]); + + if (in_transaction_mode == true) { + if (first == "begin") cmd.type = TxnCmd::BEGIN; + else if (first == "savepoint") cmd = parse_savepoint(pos); + else if (first == "release") cmd = parse_release(pos); + } else { + if (first == "commit") cmd.type = TxnCmd::COMMIT; + else if (first == "rollback" || (first == "abort")) cmd = parse_rollback(pos); + } + return cmd; +} + +TxnCmd PgSQL_TxnCmdParser::parse_rollback(size_t& pos) noexcept { + TxnCmd cmd{ TxnCmd::ROLLBACK }; + while (pos < tokens.size() && contains({ "work", "transaction" }, to_lower(tokens[pos]))) pos++; + + if (pos < tokens.size() && to_lower(tokens[pos]) == "to") { + cmd.type = TxnCmd::ROLLBACK_TO; + if (++pos < tokens.size() && to_lower(tokens[pos]) == "savepoint") pos++; + if (pos < tokens.size()) cmd.savepoint = tokens[pos++]; + } + return cmd; +} + +TxnCmd PgSQL_TxnCmdParser::parse_savepoint(size_t& pos) noexcept { + TxnCmd cmd{ TxnCmd::SAVEPOINT }; + if (pos < tokens.size()) cmd.savepoint = tokens[pos++]; + return cmd; +} + +TxnCmd PgSQL_TxnCmdParser::parse_release(size_t& pos) noexcept { + TxnCmd cmd{ TxnCmd::RELEASE }; + if (pos < tokens.size() && to_lower(tokens[pos]) == "savepoint") pos++; + if (pos < tokens.size()) cmd.savepoint = tokens[pos++]; + return cmd; +} From 2f9bb83dda8c8e56ad3ece73cbe2c82d5ee429c5 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 13:28:29 +0500 Subject: [PATCH 02/16] Added Abort (alias of Rollback) --- include/proxysql_structs.h | 1 + lib/PgSQL_Query_Processor.cpp | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 983b121fa..eb6e04c6f 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -627,6 +627,7 @@ enum PGSQL_QUERY_command { PGSQL_QUERY_BEGIN, PGSQL_QUERY_COMMIT, PGSQL_QUERY_ROLLBACK, + PGSQL_QUERY_ABORT, PGSQL_QUERY_DECLARE_CURSOR, PGSQL_QUERY_CLOSE_CURSOR, PGSQL_QUERY_DISCARD, diff --git a/lib/PgSQL_Query_Processor.cpp b/lib/PgSQL_Query_Processor.cpp index 9165a561f..257639f86 100644 --- a/lib/PgSQL_Query_Processor.cpp +++ b/lib/PgSQL_Query_Processor.cpp @@ -102,6 +102,7 @@ static char* commands_counters_desc[PGSQL_QUERY___NONE] = { [PGSQL_QUERY_BEGIN] = (char*)"BEGIN", [PGSQL_QUERY_COMMIT] = (char*)"COMMIT", [PGSQL_QUERY_ROLLBACK] = (char*)"ROLLBACK", + [PGSQL_QUERY_ABORT] = (char*)"ABORT", [PGSQL_QUERY_DECLARE_CURSOR] = (char*)"DECLARE_CURSOR", [PGSQL_QUERY_CLOSE_CURSOR] = (char*)"CLOSE_CURSOR", [PGSQL_QUERY_DISCARD] = (char*)"DISCARD", @@ -664,7 +665,7 @@ enum PGSQL_QUERY_command PgSQL_Query_Processor::query_parser_command_type(SQP_pa char c1; tokenizer_t tok; - tokenizer(&tok, text, " ", TOKENIZER_NO_EMPTIES); + tokenizer(&tok, text, " ;", TOKENIZER_NO_EMPTIES); char* token = NULL; __get_token: token = (char*)tokenize(&tok); @@ -761,6 +762,10 @@ __remove_parenthesis: ret = PGSQL_QUERY_ANALYZE; break; } + if (!strcasecmp("ABORT", token)) { + ret = PGSQL_QUERY_ROLLBACK;; + break; + } break; case 'b': case 'B': From d0952275adc44d4e7825d6409efc935fbb9e47a6 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 13:29:15 +0500 Subject: [PATCH 03/16] Added PgSQL_ExplicitTxnStateMgr in MakeFile --- lib/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Makefile b/lib/Makefile index 916730260..33fc0fa5c 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -164,7 +164,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo PgSQL_Protocol.oo PgSQL_Thread.oo PgSQL_Data_Stream.oo PgSQL_Session.oo PgSQL_Variables.oo PgSQL_HostGroups_Manager.oo PgSQL_Connection.oo PgSQL_Backend.oo PgSQL_Logger.oo PgSQL_Authentication.oo PgSQL_Error_Helper.oo \ MySQL_Query_Cache.oo PgSQL_Query_Cache.oo PgSQL_Monitor.oo \ MySQL_Set_Stmt_Parser.oo PgSQL_Set_Stmt_Parser.oo \ - PgSQL_Variables_Validator.oo + PgSQL_Variables_Validator.oo PgSQL_ExplicitTxnStateMgr.oo OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX)) HEADERS := ../include/*.h ../include/*.hpp From 33b8b2bd3d2f124edf53713ddfc39d20fb9c1808 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 14:50:20 +0500 Subject: [PATCH 04/16] Added a static_assert to enforce a maximum thread name length of 16 characters at compile time --- include/proxysql_utils.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/proxysql_utils.h b/include/proxysql_utils.h index e03b01bbc..2d2ac3e1a 100644 --- a/include/proxysql_utils.h +++ b/include/proxysql_utils.h @@ -307,7 +307,9 @@ struct free_deleter { template using mf_unique_ptr = std::unique_ptr; -static inline void set_thread_name(const char name[16], const bool en = true) { +template +static inline void set_thread_name(const char(&name)[LEN], const bool en = true) { + static_assert(LEN < 17, "Thread name must not exceed 16 characters"); if (en == false) { return; } From c6fd3cef9c5ae4a4d2f78804def841289ceffdf9 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 14:53:33 +0500 Subject: [PATCH 05/16] Added reorder_dynamic_variables_idx flag for server connection variables --- include/PgSQL_Variables.h | 6 +++--- lib/PgSQL_Variables.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/PgSQL_Variables.h b/include/PgSQL_Variables.h index 7e2daec8c..04768bfa0 100644 --- a/include/PgSQL_Variables.h +++ b/include/PgSQL_Variables.h @@ -31,15 +31,15 @@ public: bool client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx); bool client_set_hash_and_value(PgSQL_Session* session, int idx, const std::string& value, uint32_t hash); - void client_reset_value(PgSQL_Session* session, int idx); + void client_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx); const char* client_get_value(PgSQL_Session* session, int idx) const; uint32_t client_get_hash(PgSQL_Session* session, int idx) const; void server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx); void server_set_hash_and_value(PgSQL_Session* session, int idx, const char* value, uint32_t hash); - void server_reset_value(PgSQL_Session* session, int idx); + void server_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx); const char* server_get_value(PgSQL_Session* session, int idx) const; - inline uint32_t server_get_hash(PgSQL_Session* session, int idx) const; + uint32_t server_get_hash(PgSQL_Session* session, int idx) const; bool verify_variable(PgSQL_Session* session, int idx) const; bool update_variable(PgSQL_Session* session, session_status status, int &_rc); diff --git a/lib/PgSQL_Variables.cpp b/lib/PgSQL_Variables.cpp index 7285436c6..801add653 100644 --- a/lib/PgSQL_Variables.cpp +++ b/lib/PgSQL_Variables.cpp @@ -85,7 +85,7 @@ bool PgSQL_Variables::client_set_hash_and_value(PgSQL_Session* session, int idx, return true; } -void PgSQL_Variables::client_reset_value(PgSQL_Session* session, int idx) { +void PgSQL_Variables::client_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx) { if (!session || !session->client_myds || !session->client_myds->myconn) { proxy_warning("Session validation failed\n"); return; @@ -99,7 +99,7 @@ void PgSQL_Variables::client_reset_value(PgSQL_Session* session, int idx) { free(client_conn->variables[idx].value); client_conn->variables[idx].value = NULL; } - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx client_conn->reorder_dynamic_variables_idx(); } @@ -170,7 +170,7 @@ void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const ch } } -void PgSQL_Variables::server_reset_value(PgSQL_Session* session, int idx) { +void PgSQL_Variables::server_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx) { assert(session); assert(session->mybe); assert(session->mybe->server_myds); @@ -184,7 +184,7 @@ void PgSQL_Variables::server_reset_value(PgSQL_Session* session, int idx) { free(backend_conn->variables[idx].value); backend_conn->variables[idx].value = NULL; } - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx backend_conn->reorder_dynamic_variables_idx(); } From d90bf4ac73de74ff8fdbd8c75c91c2f39d11ed4e Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 14:55:01 +0500 Subject: [PATCH 06/16] Optimized savepoint detection Added get_pg_parameter_status --- include/PgSQL_Connection.h | 3 ++- lib/PgSQL_Connection.cpp | 42 +++++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index 470278628..252b917eb 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -363,7 +363,7 @@ public: bool set_single_row_mode(); void update_bytes_recv(uint64_t bytes_recv); void update_bytes_sent(uint64_t bytes_sent); - void ProcessQueryAndSetStatusFlags(char* query_digest_text); + void ProcessQueryAndSetStatusFlags(char* query_digest_text, int savepoint_count); inline const PGconn* get_pg_connection() const { return pgsql_conn; } inline int get_pg_server_version() { return PQserverVersion(pgsql_conn); } @@ -388,6 +388,7 @@ public: inline int get_pg_is_threadsafe() { return PQisthreadsafe(); } inline const char* get_pg_error_message() { return PQerrorMessage(pgsql_conn); } inline SSL* get_pg_ssl_object() { return (SSL*)PQsslStruct(pgsql_conn, "OpenSSL"); } + inline const char* get_pg_parameter_status(const char* param) { return PQparameterStatus(pgsql_conn, param); } const char* get_pg_server_version_str(char* buff, int buff_size); const char* get_pg_connection_status_str(); const char* get_pg_transaction_status_str(); diff --git a/lib/PgSQL_Connection.cpp b/lib/PgSQL_Connection.cpp index cb1517f9e..a105058a9 100644 --- a/lib/PgSQL_Connection.cpp +++ b/lib/PgSQL_Connection.cpp @@ -812,7 +812,7 @@ void PgSQL_Connection::connect_start() { const std::string& conninfo_str = conninfo.str(); pgsql_conn = PQconnectStart(conninfo_str.c_str()); - //PQsetErrorVerbosity(pgsql_conn, PQERRORS_VERBOSE); + //PQsetErrorVerbosity(pgsql_conn, PQERRORS_SQLSTATE); //PQsetErrorContextVisibility(pgsql_conn, PQSHOW_CONTEXT_ERRORS); if (pgsql_conn == NULL || PQstatus(pgsql_conn) == CONNECTION_BAD) { @@ -1042,15 +1042,12 @@ int PgSQL_Connection::async_connect(short event) { async_state_machine = ASYNC_IDLE; myds->wait_until = 0; return 0; - break; case ASYNC_CONNECT_FAILED: return -1; - break; case ASYNC_CONNECT_TIMEOUT: return -2; - break; default: - return 1; + break; } return 1; } @@ -1416,8 +1413,8 @@ bool PgSQL_Connection::IsServerOffline() { } bool PgSQL_Connection::is_connection_in_reusable_state() const { - const PGTransactionStatusType txn_status = PQtransactionStatus(pgsql_conn); - const bool conn_usable = !(txn_status == PQTRANS_UNKNOWN || txn_status == PQTRANS_ACTIVE); + PGTransactionStatusType txn_status = PQtransactionStatus(pgsql_conn); + bool conn_usable = !(txn_status == PQTRANS_UNKNOWN || txn_status == PQTRANS_ACTIVE); assert(!(conn_usable == false && is_error_present() == false)); return conn_usable; } @@ -1678,7 +1675,7 @@ void PgSQL_Connection::unhandled_notice_cb(void* arg, const PGresult* result) { #endif } -void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) { +void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text, int savepoint_count) { if (query_digest_text == NULL) return; // unknown what to do with multiplex int mul = -1; @@ -1721,8 +1718,7 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) { default: break; } - } - else { + } else { if (mul != 2 && index(query_digest_text, '.')) { // mul = 2 has a special meaning : do not disable multiplex for variables in THIS QUERY ONLY if (!IsKeepMultiplexEnabledVariables(query_digest_text)) { set_status(true, STATUS_MYSQL_CONNECTION_USER_VARIABLE); @@ -1767,18 +1763,26 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) { } }*/ if (get_status(STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT) == false) { - - if (IsKnownActiveTransaction()) { - if (!strncasecmp(query_digest_text, "SAVEPOINT ", strlen("SAVEPOINT "))) { - set_status(true, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT); + if (savepoint_count > 0) { + set_status(true, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT); + } else if (savepoint_count == -1) { + if (IsKnownActiveTransaction()) { + if (!strncasecmp(query_digest_text, "SAVEPOINT ", strlen("SAVEPOINT "))) { + set_status(true, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT); + } } - } + } } else { - if ((IsKnownActiveTransaction() == false) || - (strcasecmp(query_digest_text, "COMMIT") == 0) || - (strcasecmp(query_digest_text, "ROLLBACK") == 0)) { + if (savepoint_count == 0) { set_status(false, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT); - } + } else if (savepoint_count == -1) { + if ((IsKnownActiveTransaction() == false) || + (strncasecmp(query_digest_text, "COMMIT", strlen("COMMIT")) == 0) || + (strncasecmp(query_digest_text, "ROLLBACK", strlen("ROLLBACK")) == 0) || + (strncasecmp(query_digest_text, "ABORT", strlen("ABORT")) == 0)) { + set_status(false, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT); + } + } } } From 190e3696587ac7a63b9194d33e26bff924449cef Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 14:57:15 +0500 Subject: [PATCH 07/16] Added Transaction State flag in SQLite3_to_Postgres --- include/PgSQL_Protocol.h | 4 ++-- lib/PgSQL_Protocol.cpp | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index a1d2a23c1..55a4a9e5b 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -279,7 +279,7 @@ private: std::vector pkt_offset; bool multiple_pkt_mode = false; bool ownership = true; - friend void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type); + friend void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type, char txn_state); }; class PgSQL_Protocol; @@ -1052,6 +1052,6 @@ private: friend void admin_session_handler(S* sess, void* _pa, PtrSize_t* pkt); }; -void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type); +void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type, char txn_state = 'I'); #endif // __POSTGRES_PROTOCOL_H diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 3dd2b79cf..4b606f267 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -197,7 +197,7 @@ void PG_pkt::write_RowDescription(const char *tupdesc, ...) { } -void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error, int affected_rows, const char *query_type) { +void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error, int affected_rows, const char *query_type, char txn_state) { assert(psa != NULL); const char *fs = strchr(query_type, ' '); int qtlen = strlen(query_type); @@ -257,7 +257,7 @@ void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error, pkt.write_CommandComplete(buf); } pkt.to_PtrSizeArray(psa); - pkt.write_ReadyForQuery(); + pkt.write_ReadyForQuery(txn_state); pkt.to_PtrSizeArray(psa); } else { // no resultset PG_pkt pkt(64); @@ -289,7 +289,7 @@ void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error, } } pkt.to_PtrSizeArray(psa); - pkt.write_ReadyForQuery(); + pkt.write_ReadyForQuery(txn_state); pkt.to_PtrSizeArray(psa); } } From 979b3a81f4dc639016ddb00a67e2869f8421896a Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 14:58:51 +0500 Subject: [PATCH 08/16] Added PgSQL_ExplicitTxnStateMgr in session Dump transaction state in proxysql internal session output --- include/PgSQL_Session.h | 2 ++ lib/PgSQL_Session.cpp | 38 ++++++++++++++++++++++---------------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 96975f608..4b2ada944 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -15,6 +15,7 @@ class PgSQL_Query_Result; +class PgSQL_ExplicitTxnStateMgr; //#include "../deps/json/json.hpp" //using json = nlohmann::json; @@ -389,6 +390,7 @@ public: PgSQL_Data_Stream* client_myds; #endif // 0 PgSQL_Data_Stream* server_myds; + PgSQL_ExplicitTxnStateMgr* transaction_state_manager; #if 0 /* * @brief Store the hostgroups that hold connections that have been flagged as 'expired' by the diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 5426a49ce..21767a99a 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -25,7 +25,7 @@ using json = nlohmann::json; #include "ProxySQL_Cluster.hpp" #include "PgSQL_Query_Cache.h" #include "PgSQL_Variables_Validator.h" - +#include "PgSQL_ExplicitTxnStateMgr.h" #include "libinjection.h" #include "libinjection_sqli.h" @@ -579,6 +579,7 @@ PgSQL_Session::PgSQL_Session() { last_HG_affected_rows = -1; // #1421 : advanced support for LAST_INSERT_ID() proxysql_node_address = NULL; use_ldap_auth = false; + transaction_state_manager = new PgSQL_ExplicitTxnStateMgr(this); } void PgSQL_Session::reset() { @@ -678,6 +679,8 @@ PgSQL_Session::~PgSQL_Session() { for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { reset_default_session_variable((enum pgsql_variable_name)i); } + if (transaction_state_manager) + delete transaction_state_manager; } bool PgSQL_Session::handler_CommitRollback(PtrSize_t* pkt) { @@ -777,6 +780,9 @@ void PgSQL_Session::generate_proxysql_internal_session_json(json& j) { j["qpo"]["max_lag_ms"] = qpo->max_lag_ms; j["user_attributes"] = (user_attributes ? user_attributes : ""); j["transaction_persistent"] = transaction_persistent; + + transaction_state_manager->fill_internal_session(j["transaction_state"]); + if (client_myds != NULL) { // only if client_myds is defined j["client"]["stream"]["pkts_recv"] = client_myds->pkts_recv; j["client"]["stream"]["pkts_sent"] = client_myds->pkts_sent; @@ -1820,7 +1826,7 @@ bool PgSQL_Session::handler_again___status_CONNECTING_SERVER(int* _rc) { NEXT_IMMEDIATE_NEW(CONNECTING_SERVER); } else { - __exit_handler_again___status_CONNECTING_SERVER_with_err: +__exit_handler_again___status_CONNECTING_SERVER_with_err: bool is_error_present = myconn->is_error_present(); if (is_error_present) { client_myds->myprot.generate_error_packet(true, true, myconn->error_info.message.c_str(), @@ -5463,26 +5469,26 @@ void PgSQL_Session::RequestEnd(PgSQL_Data_Stream* myds, const unsigned int myerr if (status != PROCESSING_STMT_EXECUTE) { qdt = CurrentQuery.get_digest_text(); - } - else { + } else { qdt = CurrentQuery.stmt_info->digest_text; } + // is savepoint currently present in transaction. + int savepoint_count = -1; // haven't checked yet + + // we do not maintain the transaction variable state if the session is locked on a hostgroup + // or is a Fast Forward session. + if (locked_on_hostgroup == -1 && session_fast_forward == SESSION_FORWARD_TYPE_NONE) { + transaction_state_manager->handle_transaction(qdt); + savepoint_count = transaction_state_manager->get_savepoint_count(); + } + if (qdt && myds && myds->myconn) { - myds->myconn->ProcessQueryAndSetStatusFlags(qdt); + myds->myconn->ProcessQueryAndSetStatusFlags(qdt, savepoint_count); } - switch (status) { - /*case PROCESSING_STMT_EXECUTE: - case PROCESSING_STMT_PREPARE: - // if a prepared statement is executed, LogQuery was already called - break; - */ - default: - if (session_fast_forward == SESSION_FORWARD_TYPE_NONE) { - LogQuery(myds); - } - break; + if (session_fast_forward == SESSION_FORWARD_TYPE_NONE) { + LogQuery(myds); } GloPgQPro->delete_QP_out(qpo); From 9910a18e844b4d088303e364a324f1724f086894 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 15:02:27 +0500 Subject: [PATCH 09/16] Some optimisation --- lib/Base_Session.cpp | 36 +++++++++++++++++++++--------------- lib/MySQL_Session.cpp | 1 - 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/lib/Base_Session.cpp b/lib/Base_Session.cpp index 8976ca84b..06170f5fb 100644 --- a/lib/Base_Session.cpp +++ b/lib/Base_Session.cpp @@ -330,7 +330,10 @@ void Base_Session::return_proxysql_internal(PtrSize_t* pkt) { char* pta[1]; pta[0] = (char*)s.c_str(); resultset->add_row(pta); - SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset, nullptr, 0, (const char*)pkt->ptr + 5); + + unsigned int nTxn = NumActiveTransactions(); + char txn_state = (nTxn ? 'T' : 'I'); + SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset, nullptr, 0, (const char*)pkt->ptr + 5, txn_state); delete resultset; l_free(pkt->size, pkt->ptr); return; @@ -599,12 +602,14 @@ unsigned int Base_Session::NumActiveTransactions(bool check_savepoint) if (_mybe->server_myds->myconn->IsActiveTransaction()) { ret++; } else { - // we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due - // to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to - // SAVEPOINT and autocommit=0 - if (check_savepoint) { - if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) { - ret++; + if constexpr (std::is_same_v) { + // we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due + // to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to + // SAVEPOINT and autocommit=0 + if (check_savepoint) { + if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) { + ret++; + } } } } @@ -663,17 +668,18 @@ int Base_Session::FindOneActiveTransaction(bool check_savepoint) { if (_mybe->server_myds->myconn) { if (_mybe->server_myds->myconn->IsKnownActiveTransaction()) { return (int)_mybe->server_myds->myconn->parent->myhgc->hid; - } - else if (_mybe->server_myds->myconn->IsActiveTransaction()) { + } else if (_mybe->server_myds->myconn->IsActiveTransaction()) { ret = (int)_mybe->server_myds->myconn->parent->myhgc->hid; } else { - // we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due - // to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to - // SAVEPOINT and autocommit=0 - if (check_savepoint) { - if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) { - return (int)_mybe->server_myds->myconn->parent->myhgc->hid; + if constexpr (std::is_same_v) { + // we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due + // to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to + // SAVEPOINT and autocommit=0 + if (check_savepoint) { + if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) { + return (int)_mybe->server_myds->myconn->parent->myhgc->hid; + } } } } diff --git a/lib/MySQL_Session.cpp b/lib/MySQL_Session.cpp index 888cb05de..d85a39cba 100644 --- a/lib/MySQL_Session.cpp +++ b/lib/MySQL_Session.cpp @@ -2956,7 +2956,6 @@ __exit_handler_again___status_CONNECTING_SERVER_with_err: sprintf(sqlstate,"%s",mysql_sqlstate(myconn->mysql)); client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,mysql_errno(myconn->mysql),sqlstate, errmsg.c_str(), true); } else { - char buf[256]; errmsg = "Max connect failure while reaching hostgroup " + to_string(current_hostgroup); client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,9002,(char *)"HY000", errmsg.c_str(), true); if (thread) { From afbd12767b2bcf576d118f73716e2bcb91f1afe8 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 30 Apr 2025 15:05:06 +0500 Subject: [PATCH 10/16] Added newline --- include/PgSQL_ExplicitTxnStateMgr.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/PgSQL_ExplicitTxnStateMgr.h b/include/PgSQL_ExplicitTxnStateMgr.h index 1d91a4db0..5594e24f3 100644 --- a/include/PgSQL_ExplicitTxnStateMgr.h +++ b/include/PgSQL_ExplicitTxnStateMgr.h @@ -109,4 +109,4 @@ private: static void reset_variable_snapshot(PgSQL_Variable_Snapshot_t& var_snapshot) noexcept; }; -#endif // PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H \ No newline at end of file +#endif // PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H From c71aa550df0282e526980589a6078bee013f4a69 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 6 May 2025 11:00:19 +0500 Subject: [PATCH 11/16] Added names (alias of client_encoding) --- lib/PgSQL_Session.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 21767a99a..a66cf63b1 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -81,8 +81,9 @@ static inline char is_normal_char(char c) { } */ -static const std::array pgsql_critical_variables = { +static const std::array pgsql_critical_variables = { "client_encoding", + "names", "datestyle", "intervalstyle", "standard_conforming_strings", From 9559703a022e80b0f31da25c5293a3a985327bb5 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 6 May 2025 11:01:03 +0500 Subject: [PATCH 12/16] Removed extra semicolon --- lib/PgSQL_Query_Processor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/PgSQL_Query_Processor.cpp b/lib/PgSQL_Query_Processor.cpp index 257639f86..68e373987 100644 --- a/lib/PgSQL_Query_Processor.cpp +++ b/lib/PgSQL_Query_Processor.cpp @@ -763,7 +763,7 @@ __remove_parenthesis: break; } if (!strcasecmp("ABORT", token)) { - ret = PGSQL_QUERY_ROLLBACK;; + ret = PGSQL_QUERY_ROLLBACK; break; } break; From 7c7cbd0bc1022d9cd5f78a8d49fddbae1ff194b1 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 6 May 2025 11:44:41 +0500 Subject: [PATCH 13/16] Few improvements --- include/PgSQL_ExplicitTxnStateMgr.h | 224 ++++++++++++++-------------- lib/PgSQL_ExplicitTxnStateMgr.cpp | 21 ++- 2 files changed, 127 insertions(+), 118 deletions(-) diff --git a/include/PgSQL_ExplicitTxnStateMgr.h b/include/PgSQL_ExplicitTxnStateMgr.h index 5594e24f3..ea50cb28c 100644 --- a/include/PgSQL_ExplicitTxnStateMgr.h +++ b/include/PgSQL_ExplicitTxnStateMgr.h @@ -1,112 +1,112 @@ -#ifndef PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H -#define PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H -#include -#include -#include -#include "proxysql.h" -#include "cpp.h" -#include "PgSQL_Connection.h" - -#ifndef PROXYJSON -#define PROXYJSON -#include "../deps/json/json_fwd.hpp" -#endif // PROXYJSON - -/** - * @struct PgSQL_Variable_Snapshot - * @brief Represents a snapshot of PostgreSQL variables during a transaction. - * - * This structure is used to store the state of PostgreSQL variables, including - * their values and hash representations, at a specific point in time during a transaction. - */ -typedef struct PgSQL_Variable_Snapshot { - char* var_value[PGSQL_NAME_LAST_HIGH_WM] = {}; - uint32_t var_hash[PGSQL_NAME_LAST_HIGH_WM] = {}; -} PgSQL_Variable_Snapshot_t; - -/** - * @struct TxnCmd - * @brief Represents a transaction command type begin executed and its associated metadata. - * - */ -struct TxnCmd { - /** - * @enum Type - * @brief Enumerates the types of transaction commands. - */ - enum Type { - UNKNOWN = -1, - BEGIN, - COMMIT, - ROLLBACK, - SAVEPOINT, - RELEASE, - ROLLBACK_TO - } type = Type::UNKNOWN; - std::string savepoint; //< The name of the savepoint, if applicable. -}; - -/** - * @class PgSQL_TxnCmdParser - * @brief Parses transaction-related commands for PostgreSQL. - * - * This class is responsible for tokenizing and interpreting transaction-related - * commands such as BEGIN, COMMIT, ROLLBACK, SAVEPOINT, etc. - */ -class PgSQL_TxnCmdParser { -public: - TxnCmd parse(std::string_view input, bool in_transaction_mode) noexcept; - -private: - std::vector tokens; - - TxnCmd parse_rollback(size_t& pos) noexcept; - TxnCmd parse_savepoint(size_t& pos) noexcept; - TxnCmd parse_release(size_t& pos) noexcept; - - // Helpers - static std::string to_lower(std::string_view s) noexcept { - std::string s_copy(s); - std::transform(s_copy.begin(), s_copy.end(), s_copy.begin(), ::tolower); - return s_copy; - } - - inline static bool contains(std::vector&& list, std::string_view value) noexcept { - for (const auto& item : list) if (item == value) return true; - return false; - } -}; - -/** - * @class PgSQL_ExplicitTxnStateMgr - * @brief Manages the state of explicit transactions in PostgreSQL. - * - * This class is responsible for handling explicit transaction commands such as - * BEGIN, COMMIT, ROLLBACK, SAVEPOINT, and managing the associated state. - */ -class PgSQL_ExplicitTxnStateMgr { -public: - PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess); - ~PgSQL_ExplicitTxnStateMgr(); - - bool handle_transaction(std::string_view input); - int get_savepoint_count() const { return savepoint.size(); } - void fill_internal_session(nlohmann::json& j); - -private: - PgSQL_Session* session; - std::vector transaction_state; - std::vector savepoint; - PgSQL_TxnCmdParser tx_parser; - - void start_transaction(); - void commit(); - void rollback(); - bool add_savepoint(std::string_view name); - bool rollback_to_savepoint(std::string_view name); - bool release_savepoint(std::string_view name); - - static void reset_variable_snapshot(PgSQL_Variable_Snapshot_t& var_snapshot) noexcept; -}; - -#endif // PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H +#ifndef PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H +#define PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H +#include +#include +#include +#include "proxysql.h" +#include "cpp.h" +#include "PgSQL_Connection.h" + +#ifndef PROXYJSON +#define PROXYJSON +#include "../deps/json/json_fwd.hpp" +#endif // PROXYJSON + +/** + * @struct PgSQL_Variable_Snapshot + * @brief Represents a snapshot of PostgreSQL variables during a transaction. + * + * This structure is used to store the state of PostgreSQL variables, including + * their values and hash representations, at a specific point in time during a transaction. + */ +struct PgSQL_Variable_Snapshot { + char* var_value[PGSQL_NAME_LAST_HIGH_WM] = {}; // Not using smart pointers because we need fine-grained control over hashing when values change + uint32_t var_hash[PGSQL_NAME_LAST_HIGH_WM] = {}; +}; + +/** + * @struct TxnCmd + * @brief Represents a transaction command type begin executed and its associated metadata. + * + */ +struct TxnCmd { + /** + * @enum Type + * @brief Enumerates the types of transaction commands. + */ + enum Type { + UNKNOWN = -1, + BEGIN, + COMMIT, + ROLLBACK, + SAVEPOINT, + RELEASE, + ROLLBACK_TO + } type = Type::UNKNOWN; + std::string savepoint; //< The name of the savepoint, if applicable. +}; + +/** + * @class PgSQL_TxnCmdParser + * @brief Parses transaction-related commands for PostgreSQL. + * + * This class is responsible for tokenizing and interpreting transaction-related + * commands such as BEGIN, COMMIT, ROLLBACK, SAVEPOINT, etc. + */ +class PgSQL_TxnCmdParser { +public: + TxnCmd parse(std::string_view input, bool in_transaction_mode) noexcept; + +private: + std::vector tokens; + + TxnCmd parse_rollback(size_t& pos) noexcept; + TxnCmd parse_savepoint(size_t& pos) noexcept; + TxnCmd parse_release(size_t& pos) noexcept; + + // Helpers + static std::string to_lower(std::string_view s) noexcept { + std::string s_copy(s); + std::transform(s_copy.begin(), s_copy.end(), s_copy.begin(), ::tolower); + return s_copy; + } + + inline static bool contains(std::vector&& list, std::string_view value) noexcept { + for (const auto& item : list) if (item == value) return true; + return false; + } +}; + +/** + * @class PgSQL_ExplicitTxnStateMgr + * @brief Manages the state of explicit transactions in PostgreSQL. + * + * This class is responsible for handling explicit transaction commands such as + * BEGIN, COMMIT, ROLLBACK, SAVEPOINT, and managing the associated state. + */ +class PgSQL_ExplicitTxnStateMgr { +public: + PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess); + ~PgSQL_ExplicitTxnStateMgr(); + + bool handle_transaction(std::string_view input); + int get_savepoint_count() const { return savepoint.size(); } + void fill_internal_session(nlohmann::json& j); + +private: + PgSQL_Session* session; + std::vector transaction_state; + std::vector savepoint; + PgSQL_TxnCmdParser tx_parser; + + void start_transaction(); + void commit(); + void rollback(); + bool add_savepoint(std::string_view name); + bool rollback_to_savepoint(std::string_view name); + bool release_savepoint(std::string_view name); + + static void reset_variable_snapshot(PgSQL_Variable_Snapshot& var_snapshot) noexcept; +}; + +#endif // PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H diff --git a/lib/PgSQL_ExplicitTxnStateMgr.cpp b/lib/PgSQL_ExplicitTxnStateMgr.cpp index f29cb1cc9..268a11101 100644 --- a/lib/PgSQL_ExplicitTxnStateMgr.cpp +++ b/lib/PgSQL_ExplicitTxnStateMgr.cpp @@ -1,4 +1,4 @@ -#include "PgSQL_ExplicitTxnStateMgr.h" +#include "PgSQL_ExplicitTxnStateMgr.h" #include "proxysql.h" #include "PgSQL_Session.h" #include "PgSQL_Data_Stream.h" @@ -18,7 +18,7 @@ PgSQL_ExplicitTxnStateMgr::~PgSQL_ExplicitTxnStateMgr() { savepoint.clear(); } -void PgSQL_ExplicitTxnStateMgr::reset_variable_snapshot(PgSQL_Variable_Snapshot_t& var_snapshot) noexcept { +void PgSQL_ExplicitTxnStateMgr::reset_variable_snapshot(PgSQL_Variable_Snapshot& var_snapshot) noexcept { for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { if (var_snapshot.var_value[idx]) { free(var_snapshot.var_value[idx]); @@ -50,7 +50,7 @@ void PgSQL_ExplicitTxnStateMgr::start_transaction() { assert(session->client_myds && session->client_myds->myconn); - PgSQL_Variable_Snapshot_t var_snapshot{}; + PgSQL_Variable_Snapshot var_snapshot{}; // check if already in transaction, if yes then do nothing for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { @@ -96,7 +96,7 @@ void PgSQL_ExplicitTxnStateMgr::rollback() { assert(session->client_myds && session->client_myds->myconn); - const PgSQL_Variable_Snapshot_t& var_snapshot = transaction_state.front(); + const PgSQL_Variable_Snapshot& var_snapshot = transaction_state.front(); for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { uint32_t hash = var_snapshot.var_hash[idx]; @@ -159,7 +159,7 @@ bool PgSQL_ExplicitTxnStateMgr::rollback_to_savepoint(std::string_view name) { assert(tran_state_idx + 1 < (int)transaction_state.size()); - PgSQL_Variable_Snapshot_t& var_snapshot = transaction_state[tran_state_idx+1]; + PgSQL_Variable_Snapshot& var_snapshot = transaction_state[tran_state_idx+1]; for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { uint32_t hash = var_snapshot.var_hash[idx]; if (hash != 0) { @@ -177,6 +177,14 @@ bool PgSQL_ExplicitTxnStateMgr::rollback_to_savepoint(std::string_view name) { pgsql_variables.server_reset_value(session, idx, false); } } + + session->client_myds->myconn->reorder_dynamic_variables_idx(); + if (session->mybe) { + session->mybe->server_myds->myconn->reorder_dynamic_variables_idx(); + + verify_server_variables(session); + } + for (size_t idx = tran_state_idx + 1; idx < transaction_state.size(); idx++) { reset_variable_snapshot(transaction_state[idx]); } @@ -236,7 +244,7 @@ bool PgSQL_ExplicitTxnStateMgr::add_savepoint(std::string_view name) { }); if (it != savepoint.end()) return false; - PgSQL_Variable_Snapshot_t var_snapshot{}; + PgSQL_Variable_Snapshot var_snapshot{}; for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) { uint32_t hash = pgsql_variables.client_get_hash(session, idx); @@ -344,6 +352,7 @@ TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mod if (first == "begin") cmd.type = TxnCmd::BEGIN; else if (first == "savepoint") cmd = parse_savepoint(pos); else if (first == "release") cmd = parse_release(pos); + else if (first == "rollback") cmd = parse_rollback(pos); } else { if (first == "commit") cmd.type = TxnCmd::COMMIT; else if (first == "rollback" || (first == "abort")) cmd = parse_rollback(pos); From 0adf56eff1f764bb88a5ffe9584575d70670ff70 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 6 May 2025 11:46:02 +0500 Subject: [PATCH 14/16] Improved set parser. SET LOCAL will now lock on hostgroup. --- lib/PgSQL_Set_Stmt_Parser.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/PgSQL_Set_Stmt_Parser.cpp b/lib/PgSQL_Set_Stmt_Parser.cpp index 3f92ef12a..159924dc2 100644 --- a/lib/PgSQL_Set_Stmt_Parser.cpp +++ b/lib/PgSQL_Set_Stmt_Parser.cpp @@ -74,7 +74,8 @@ void PgSQL_Set_Stmt_Parser::generateRE_parse1v2() { // Function Call: Check if Group 3 is populated. // Literal: Check if Group 4 is populated. //const std::string pattern = "(?:(SESSION|LOCAL)\\s+)?((?:\\S+(?:\\s+\\S+)*?))(?:\\s+(?:TO|=)\\s+|\\s+)(?:(\\w+\\s*\\([^)]*\\))|((?:'(?:''|[^'])*'|-?\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?|t|true|f|false|on|off|default|\\S+)))\\s*;?"; - const std::string pattern = "(?:(SESSION|LOCAL)\\s+)?((?:\\S+(?:\\s+\\S+)*?))(?:\\s*(?:TO|=)\\s*|\\s+)(?:(\\w+\\s*\\([^)]*\\))|((?:'(?:''|[^'])*'|-?\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?|true|t|1|yes|false|f|0|no|on|off|default|\\S+)))\\s*;?"; + //const std::string pattern = "(?:(SESSION|LOCAL)\\s+)?((?:\\S+(?:\\s+\\S+)*?))(?:\\s*(?:TO|=)\\s*|\\s+)(?:(\\w+\\s*\\([^)]*\\))|((?:'(?:''|[^'])*'|-?\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?|true|t|1|yes|false|f|0|no|on|off|default|\\S+)))\\s*;?"; + const std::string pattern = R"((?:(SESSION)\s+)?((?:\S+(?:\s+\S+)*?))(?:\s*(?:TO|=)\s*|\s+)(?:(\w+\s*\([^)]*\))|((?:'(?:''|[^'])*'|-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?|[^;]+)))\s*;?)"; #ifdef DEBUG VALGRIND_DISABLE_ERROR_REPORTING; From f87fc302612ae704124ce704b4102fae2b5cccea Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 6 May 2025 11:52:07 +0500 Subject: [PATCH 15/16] Added TAP test --- ...-transaction_variable_state_tracking-t.cpp | 290 ++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp diff --git a/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp new file mode 100644 index 000000000..730a5218f --- /dev/null +++ b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp @@ -0,0 +1,290 @@ +/** + * @file pgsql-transaction_variable_state_tracking-t.cpp + * @brief TAP test validating PostgreSQL session parameter behavior in transactions. + * Tests rollback/commit/savepoint behavior for session variables to ensure state consistency. + */ + +#include +#include +#include +#include +#include +#include "libpq-fe.h" +#include "command_line.h" +#include "tap.h" +#include "utils.h" + +CommandLine cl; + +using PGConnPtr = std::unique_ptr; +using PGResultPtr = std::unique_ptr; + +enum ConnType { + ADMIN, + BACKEND +}; + +PGConnPtr createNewConnection(ConnType conn_type, const std::string& options = "", bool with_ssl = false) { + + const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host; + int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port; + const char* username = (conn_type == BACKEND) ? cl.pgsql_root_username : cl.admin_username; + const char* password = (conn_type == BACKEND) ? cl.pgsql_root_password : cl.admin_password; + + std::stringstream ss; + + ss << "host=" << host << " port=" << port; + ss << " user=" << username << " password=" << password; + ss << (with_ssl ? " sslmode=require" : " sslmode=disable"); + + if (options.empty() == false) { + ss << " options='" << options << "'"; + } + + PGconn* conn = PQconnectdb(ss.str().c_str()); + if (PQstatus(conn) != CONNECTION_OK) { + fprintf(stderr, "Connection failed to '%s': %s", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn)); + PQfinish(conn); + return PGConnPtr(nullptr, &PQfinish); + } + return PGConnPtr(conn, &PQfinish); +} + +struct TestCase { + std::string name; + std::function test_fn; + bool should_fail; +}; + +struct TestVariable { + std::string name; + std::vector test_values; +}; + +std::vector tests; + +PGResultPtr executeQuery(PGconn* conn, const std::string& query) { + diag("Executing: %s", query.c_str()); + PGresult* res = PQexec(conn, query.c_str()); + if (PQresultStatus(res) != PGRES_COMMAND_OK && PQresultStatus(res) != PGRES_TUPLES_OK) { + diag("Query failed: %s", PQerrorMessage(conn)); + } + return PGResultPtr(res, &PQclear); +} + +std::string getVariable(PGconn* conn, const std::string& var) { + auto res = executeQuery(conn, ("SHOW " + var)); + return std::string(PQgetvalue(res.get(), 0, 0)); +} + +void reset_variable(PGconn* conn, const std::string& var, const std::string& original) { + executeQuery(conn, "SET " + var + " = " + original); +} + +void add_test(const std::string& name, std::function fn, bool should_fail = false) { + tests.push_back({ name, fn, should_fail }); +} + +void run_tests() { + + for (const auto& test : tests) { + bool result = false; + + try { + result = test.test_fn(); + if (test.should_fail) result = !result; + } + catch (const std::exception& e) { + result = false; + } + + ok(result, "Test:%s should %s", test.name.c_str(), test.should_fail ? "FAIL" : "PASS"); + } +} + +// Common test scenarios +bool test_transaction_rollback(const TestVariable& var) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original = getVariable(conn.get(), var.name); + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + executeQuery(conn.get(), "ROLLBACK"); + + const bool success = getVariable(conn.get(), var.name) == original; + return success; +} + +bool test_savepoint_rollback(const TestVariable& var) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original = getVariable(conn.get(), var.name); + diag(">>>>> Original value:'%s'", original.c_str()); + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SAVEPOINT sp1"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + executeQuery(conn.get(), "ROLLBACK TO sp1"); + executeQuery(conn.get(), "COMMIT"); + auto value = getVariable(conn.get(), var.name); + const bool success = value == original; + diag(">>>>> Rollback value:'%s'", value.c_str()); + return success; +} + +bool test_transaction_commit(const TestVariable& var, const std::map& original_values) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto test_value = var.test_values[0]; + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET " + var.name + " = " + test_value); + executeQuery(conn.get(), "COMMIT"); + + const bool success = getVariable(conn.get(), var.name) == test_value; + reset_variable(conn.get(), var.name, original_values.at(var.name)); + return success; +} + +bool test_savepoint_commit(const TestVariable& var, const std::map& original_values) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto test_value = var.test_values[0]; + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SAVEPOINT sp1"); + executeQuery(conn.get(), "SET " + var.name + " = " + test_value); + executeQuery(conn.get(), "RELEASE SAVEPOINT sp1"); + executeQuery(conn.get(), "COMMIT"); + + const bool success = getVariable(conn.get(), var.name) == test_value; + reset_variable(conn.get(), var.name, original_values.at(var.name)); + return success; +} + +bool test_savepoint_rollback_partial(const TestVariable& var, const std::map& original_values) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original = getVariable(conn.get(), var.name); + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + executeQuery(conn.get(), "SAVEPOINT sp1"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[1]); + executeQuery(conn.get(), "RELEASE SAVEPOINT sp1"); + executeQuery(conn.get(), "COMMIT"); + + const bool success = getVariable(conn.get(), var.name) == var.test_values[1]; + reset_variable(conn.get(), var.name, original_values.at(var.name)); + return success; +} + +int main(int argc, char** argv) { + + if (cl.getEnv()) + return exit_status(); + + std::map original_values; + std::map tracked_vars = { + {"client_encoding", {"client_encoding", {"LATIN1", "UTF8"}}}, + {"datestyle", {"datestyle", {"ISO, MDY", "SQL, DMY"}}}, + {"intervalstyle", {"intervalstyle", {"postgres", "iso_8601"}}}, + {"standard_conforming_strings", {"standard_conforming_strings", {"on", "off"}}}, + {"timezone", {"timezone", {"UTC", "PST8PDT"}}}, + {"bytea_output", {"bytea_output", {"hex", "escape"}}}, + {"allow_in_place_tablespaces", {"allow_in_place_tablespaces", {"on", "off"}}}, + {"enable_bitmapscan", {"enable_bitmapscan", {"on", "off"}}}, + {"enable_hashjoin", {"enable_hashjoin", {"on", "off"}}}, + {"enable_indexscan", {"enable_indexscan", {"on", "off"}}}, + {"enable_nestloop", {"enable_nestloop", {"on", "off"}}}, + {"enable_seqscan", {"enable_seqscan", {"on", "off"}}}, + {"enable_sort", {"enable_sort", {"on", "off"}}}, + {"escape_string_warning", {"escape_string_warning", {"on", "off"}}}, + {"synchronous_commit", {"synchronous_commit", {"on", "off"}}}, + {"extra_float_digits", {"extra_float_digits", {"0", "3"}}}, + {"client_min_messages", {"client_min_messages", {"notice", "warning"}}} + }; + + + PGConnPtr conn = createNewConnection(ConnType::BACKEND, "", false); + if (!conn || PQstatus(conn.get()) != CONNECTION_OK) { + BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + + // Store original values + for (const auto& [name, var] : tracked_vars) { + original_values[name] = getVariable(conn.get(), name); + } + + + // Add generic tests + add_test("Commit without transaction should fail", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + PGresult* res = PQexec(conn.get(), "COMMIT"); + const bool failed = PQresultStatus(res) != PGRES_COMMAND_OK; + PQclear(res); + return failed; + }, true); + + // Add variable-specific tests using containers + for (const auto& [name, var] : tracked_vars) { + add_test("Rollback reverts " + var.name, [var]() { + return test_transaction_rollback(var); + }); + + add_test("Commit persists " + var.name, [&]() { + return test_transaction_commit(var, original_values); + }); + + add_test("Savepoint rollback for " + var.name, [var]() { + return test_savepoint_rollback(var); + }); + + add_test("Savepoint commit for " + var.name, [&]() { + return test_savepoint_commit(var, original_values); + }); + + // Multi-value savepoint test + if (var.test_values.size() > 1) { + add_test("Multi-value savepoint for " + var.name, [&]() { + return test_savepoint_rollback_partial(var, original_values); + }); + } + } + + // Add complex scenario tests + add_test("Nested BEGIN with rollback", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original_tz = getVariable(conn.get(), "timezone"); + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET timezone = 'UTC'"); + executeQuery(conn.get(), "BEGIN"); // Second BEGIN + executeQuery(conn.get(), "SET timezone = 'PST8PDT'"); + executeQuery(conn.get(), "ROLLBACK"); + + const bool success = getVariable(conn.get(), "timezone") == original_tz; + return success; + }); + + add_test("Mixed variables in transaction", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + bool success = true; + + executeQuery(conn.get(), "BEGIN"); + for (const auto& [name, var] : tracked_vars) { + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + } + executeQuery(conn.get(), "ROLLBACK"); + + for (const auto& [name, var] : tracked_vars) { + success = (getVariable(conn.get(), var.name) == original_values.at(var.name)); + } + return success; + }); + + int total_tests = 0; + + total_tests = tests.size(); + plan(total_tests); + + run_tests(); + + return exit_status(); +} From 4d5ca5adbedeb14a4dd7df82bd450b6330fd1249 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 6 May 2025 12:38:46 +0500 Subject: [PATCH 16/16] Improved TAP test logging --- ...sql-transaction_variable_state_tracking-t.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp index 730a5218f..605ce1ac2 100644 --- a/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp +++ b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp @@ -74,7 +74,9 @@ PGResultPtr executeQuery(PGconn* conn, const std::string& query) { std::string getVariable(PGconn* conn, const std::string& var) { auto res = executeQuery(conn, ("SHOW " + var)); - return std::string(PQgetvalue(res.get(), 0, 0)); + const std::string& val = std::string(PQgetvalue(res.get(), 0, 0)); + diag(">> '%s' = '%s'", var.c_str(), val.c_str()); + return val; } void reset_variable(PGconn* conn, const std::string& var, const std::string& original) { @@ -118,15 +120,13 @@ bool test_transaction_rollback(const TestVariable& var) { bool test_savepoint_rollback(const TestVariable& var) { auto conn = createNewConnection(ConnType::BACKEND, "", false); const auto original = getVariable(conn.get(), var.name); - diag(">>>>> Original value:'%s'", original.c_str()); executeQuery(conn.get(), "BEGIN"); executeQuery(conn.get(), "SAVEPOINT sp1"); executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); executeQuery(conn.get(), "ROLLBACK TO sp1"); executeQuery(conn.get(), "COMMIT"); - auto value = getVariable(conn.get(), var.name); - const bool success = value == original; - diag(">>>>> Rollback value:'%s'", value.c_str()); + + const bool success = getVariable(conn.get(), var.name) == original; return success; } @@ -158,17 +158,15 @@ bool test_savepoint_commit(const TestVariable& var, const std::map& original_values) { +bool test_savepoint_release_commit(const TestVariable& var, const std::map& original_values) { auto conn = createNewConnection(ConnType::BACKEND, "", false); const auto original = getVariable(conn.get(), var.name); - executeQuery(conn.get(), "BEGIN"); executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); executeQuery(conn.get(), "SAVEPOINT sp1"); executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[1]); executeQuery(conn.get(), "RELEASE SAVEPOINT sp1"); executeQuery(conn.get(), "COMMIT"); - const bool success = getVariable(conn.get(), var.name) == var.test_values[1]; reset_variable(conn.get(), var.name, original_values.at(var.name)); return success; @@ -243,7 +241,7 @@ int main(int argc, char** argv) { // Multi-value savepoint test if (var.test_values.size() > 1) { add_test("Multi-value savepoint for " + var.name, [&]() { - return test_savepoint_rollback_partial(var, original_values); + return test_savepoint_release_commit(var, original_values); }); } }