#include "PgSQL_ExplicitTxnStateMgr.h"
#include "proxysql.h"
#include "PgSQL_Session.h"
#include "PgSQL_Data_Stream.h"
#include "PgSQL_Connection.h"

#include <algorithm>
#include <string_view>

extern class PgSQL_Variables pgsql_variables;

PgSQL_ExplicitTxnStateMgr::PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess) : session(sess) {
}

PgSQL_ExplicitTxnStateMgr::~PgSQL_ExplicitTxnStateMgr() {
	for (auto& tran_state : transaction_state) {
		reset_variable_snapshot(tran_state);
	}
	transaction_state.clear();
	savepoint.clear();
}

/**
 * Frees every strdup'ed value held by a snapshot and zeroes its hashes,
 * leaving the snapshot empty.
 */
void PgSQL_ExplicitTxnStateMgr::reset_variable_snapshot(PgSQL_Variable_Snapshot& var_snapshot) noexcept {
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		if (var_snapshot.var_value[idx]) {
			free(var_snapshot.var_value[idx]);
			var_snapshot.var_value[idx] = nullptr;
		}
		var_snapshot.var_hash[idx] = 0;
	}
}

/**
 * DEBUG-only consistency check: compares the backend connection's
 * ParameterStatus value with the locally tracked value for each critical
 * variable and logs a warning on mismatch. No-op in non-DEBUG builds.
 */
void verify_server_variables(PgSQL_Session* session) {
#ifdef DEBUG
	// commit() calls this without checking for an attached backend, so guard here.
	if (session == nullptr || session->mybe == nullptr ||
		session->mybe->server_myds == nullptr || session->mybe->server_myds->myconn == nullptr) {
		return;
	}
	auto* const server_conn = session->mybe->server_myds->myconn;
	for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) {
		const char* conn_param_status = server_conn->get_pg_parameter_status(pgsql_tracked_variables[idx].set_variable_name);
		const char* param_value = server_conn->variables[idx].value;
		if (conn_param_status && param_value) {
			//assert(strcmp(conn_param_status, param_value) == 0);
			if (strcmp(conn_param_status, param_value) != 0) {
				// This isn't actually a bug, but it can occur in an edge case, for example when a COPY FROM STDIN fails.
				// In that situation the ParameterStatus message sent by the server is forwarded to the client
				// via fast-forwarding, so libpq's internal ParameterStatus is not updated.
				proxy_warning("Server variable '%s' mismatch. Parameter status value: '%s', Expected value: '%s'\n",
					pgsql_tracked_variables[idx].set_variable_name, conn_param_status, param_value);
			}
		}
	}
#endif
}

namespace {

/**
 * Returns the index of the most recently created savepoint whose name matches
 * @p name (case-insensitive), or -1 if there is none.
 *
 * PostgreSQL keeps an older savepoint when a newer one reuses its name, and
 * ROLLBACK TO / RELEASE always act on the most recent one, hence the reverse
 * search.
 */
template <typename SavepointList>
int find_last_savepoint(const SavepointList& savepoints, std::string_view name) {
	for (size_t idx = savepoints.size(); idx-- > 0;) {
		const std::string_view sp = savepoints[idx];
		if (sp.size() == name.size() && strncasecmp(sp.data(), name.data(), name.size()) == 0) {
			return static_cast<int>(idx);
		}
	}
	return -1;
}

/**
 * Sets the client- and server-side tracked variables of @p session back to
 * the values captured in @p snapshot. Variables absent from the snapshot
 * (hash 0) are reset. Afterwards the dynamic-variable order is rebuilt and,
 * when a backend is attached, the DEBUG consistency check runs.
 */
void restore_session_variables_from_snapshot(PgSQL_Session* session, const PgSQL_Variable_Snapshot& snapshot) {
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		const uint32_t hash = snapshot.var_hash[idx];
		if (hash != 0) {
			const uint32_t client_hash = pgsql_variables.client_get_hash(session, idx);
			const uint32_t server_hash = pgsql_variables.server_get_hash(session, idx);
			assert(client_hash == server_hash);
			(void)server_hash; // only used by the assert
			if (hash == client_hash) continue; // already at the snapshot value
			pgsql_variables.client_set_hash_and_value(session, idx, snapshot.var_value[idx], hash);
			pgsql_variables.server_set_hash_and_value(session, idx, snapshot.var_value[idx], hash);
		} else {
			assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
			pgsql_variables.client_reset_value(session, idx, false);
			pgsql_variables.server_reset_value(session, idx, false);
		}
	}
	// reuse of connection that has extra param set in connection
	session->client_myds->myconn->reorder_dynamic_variables_idx();
	if (session->mybe) {
		session->mybe->server_myds->myconn->reorder_dynamic_variables_idx();
		verify_server_variables(session);
	}
}

} // namespace

/**
 * Handles BEGIN: captures the client's current tracked variables as the
 * initial snapshot (transaction_state[0]). Ignored with a warning if a
 * transaction is already being tracked.
 */
void PgSQL_ExplicitTxnStateMgr::start_transaction() {
	if (transaction_state.empty() == false) {
		// Transaction already started, do nothing and return
		proxy_warning("Received BEGIN command. There is already a transaction in progress\n");
		assert(session->NumActiveTransactions() > 0);
		return;
	}
	assert(session->client_myds && session->client_myds->myconn);
	PgSQL_Variable_Snapshot var_snapshot{};
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		uint32_t hash = pgsql_variables.client_get_hash(session, idx);
		if (hash != 0) {
			var_snapshot.var_hash[idx] = hash;
			var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx));
		} else {
			assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
			// no need to store the value
		}
	}
	transaction_state.emplace_back(std::move(var_snapshot));
}

/**
 * Handles COMMIT: the current variable values become permanent, so all
 * snapshots and savepoints are discarded.
 */
void PgSQL_ExplicitTxnStateMgr::commit() {
	if (transaction_state.empty()) {
		proxy_warning("Received COMMIT command. There is no transaction in progress\n");
		assert(session->NumActiveTransactions() == 0);
		return;
	}
	assert(session->client_myds && session->client_myds->myconn);
	for (auto& tran_state : transaction_state) {
		reset_variable_snapshot(tran_state);
	}
	transaction_state.clear();
	savepoint.clear();
	verify_server_variables(session);
}

/**
 * Handles ROLLBACK / ROLLBACK AND CHAIN: restores variables to the values
 * captured at BEGIN and drops all savepoints.
 *
 * @param rollback_and_chain true for ROLLBACK AND CHAIN. A new transaction
 *        starts immediately with the restored values, so the initial snapshot
 *        is kept while the per-savepoint snapshots are discarded.
 */
void PgSQL_ExplicitTxnStateMgr::rollback(bool rollback_and_chain) {
	if (transaction_state.empty()) {
		proxy_warning("Received ROLLBACK command. There is no transaction in progress\n");
		assert(session->NumActiveTransactions() == 0);
		return;
	}
	assert(session->client_myds && session->client_myds->myconn);

	restore_session_variables_from_snapshot(session, transaction_state.front());

	// Keep only the initial snapshot for AND CHAIN. Otherwise drop everything.
	// The per-savepoint snapshots must go in both cases to preserve the
	// invariant transaction_state.size() - 1 == savepoint.size().
	const size_t keep = rollback_and_chain ? 1 : 0;
	for (size_t idx = keep; idx < transaction_state.size(); idx++) {
		reset_variable_snapshot(transaction_state[idx]);
	}
	transaction_state.resize(keep);
	savepoint.clear();
}

/**
 * Handles ROLLBACK TO [SAVEPOINT] name: restores the variables captured when
 * the most recent savepoint called @p name was created.
 *
 * As in PostgreSQL, the named savepoint remains valid afterwards. Only the
 * savepoints created after it are destroyed.
 *
 * @return false if no transaction is tracked or the savepoint is unknown.
 */
bool PgSQL_ExplicitTxnStateMgr::rollback_to_savepoint(std::string_view name) {
	if (transaction_state.empty()) {
		proxy_warning("Received ROLLBACK TO SAVEPOINT '%.*s' command. There is no transaction in progress\n",
			static_cast<int>(name.size()), name.data());
		assert(session->NumActiveTransactions() == 0);
		return false;
	}
	assert(session->client_myds && session->client_myds->myconn);

	const int sp_idx = find_last_savepoint(savepoint, name);
	if (sp_idx == -1) {
		proxy_warning("Savepoint '%.*s' not found.\n", static_cast<int>(name.size()), name.data());
		return false;
	}

	// transaction_state[0] is the BEGIN snapshot; transaction_state[i + 1] was
	// captured when savepoint[i] was created.
	const size_t snapshot_idx = static_cast<size_t>(sp_idx) + 1;
	assert(snapshot_idx < transaction_state.size());

	restore_session_variables_from_snapshot(session, transaction_state[snapshot_idx]);

	// Destroy only the savepoints established after the named one.
	for (size_t idx = snapshot_idx + 1; idx < transaction_state.size(); idx++) {
		reset_variable_snapshot(transaction_state[idx]);
	}
	transaction_state.resize(snapshot_idx + 1);
	savepoint.resize(snapshot_idx);
	return true;
}

/**
 * Handles RELEASE [SAVEPOINT] name: destroys the most recent savepoint called
 * @p name and every savepoint created after it. Variable values are left as
 * they are.
 *
 * @return false if no transaction is tracked or the savepoint is unknown.
 */
bool PgSQL_ExplicitTxnStateMgr::release_savepoint(std::string_view name) {
	if (transaction_state.empty()) {
		proxy_warning("Received RELEASE SAVEPOINT '%.*s' command. There is no transaction in progress\n",
			static_cast<int>(name.size()), name.data());
		assert(session->NumActiveTransactions() == 0);
		return false;
	}
	assert(session->client_myds && session->client_myds->myconn);

	const int sp_idx = find_last_savepoint(savepoint, name);
	if (sp_idx == -1) {
		proxy_warning("SAVEPOINT '%.*s' not found.\n", static_cast<int>(name.size()), name.data());
		return false;
	}

	const size_t first_dropped = static_cast<size_t>(sp_idx) + 1;
	for (size_t idx = first_dropped; idx < transaction_state.size(); idx++) {
		reset_variable_snapshot(transaction_state[idx]);
	}
	transaction_state.resize(first_dropped);
	savepoint.resize(static_cast<size_t>(sp_idx));
	return true;
}

/**
 * Handles SAVEPOINT name: captures the client's current tracked variables and
 * records the savepoint name.
 *
 * Duplicate names are accepted, because PostgreSQL keeps the older savepoint
 * and later commands act on the most recent one.
 *
 * @return false if no transaction is tracked.
 */
bool PgSQL_ExplicitTxnStateMgr::add_savepoint(std::string_view name) {
	if (transaction_state.empty()) {
		proxy_warning("Received SAVEPOINT '%.*s' command. There is no transaction in progress\n",
			static_cast<int>(name.size()), name.data());
		assert(session->NumActiveTransactions() == 0);
		return false;
	}
	assert(session->client_myds && session->client_myds->myconn);

	PgSQL_Variable_Snapshot var_snapshot{};
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		uint32_t hash = pgsql_variables.client_get_hash(session, idx);
		if (hash != 0) {
			var_snapshot.var_hash[idx] = hash;
			var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx));
		}
	}
	transaction_state.emplace_back(std::move(var_snapshot));
	savepoint.emplace_back(name);
	assert((transaction_state.size() - 1) == savepoint.size());
	return true;
}

/**
 * Dumps the tracked snapshots into @p j for internal-session introspection.
 * "initial_state" holds the BEGIN snapshot. "savepoints" maps each savepoint
 * name to its snapshot; with duplicate names, the most recent one overwrites
 * earlier entries.
 */
void PgSQL_ExplicitTxnStateMgr::fill_internal_session(nlohmann::json& j) {
	if (transaction_state.empty()) return;

	auto& initial_state = j["initial_state"];
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		uint32_t hash = transaction_state[0].var_hash[idx];
		if (hash != 0) {
			initial_state[pgsql_tracked_variables[idx].set_variable_name] = transaction_state[0].var_value[idx];
		}
	}

	if (savepoint.empty()) return;

	for (size_t idx = 0; idx < savepoint.size(); idx++) {
		auto& savepoint_json = j["savepoints"][savepoint[idx]];
		const size_t tran_state_idx = idx + 1;
		for (int idx2 = 0; idx2 < PGSQL_NAME_LAST_HIGH_WM; idx2++) {
			uint32_t hash = transaction_state[tran_state_idx].var_hash[idx2];
			if (hash != 0) {
				savepoint_json[pgsql_tracked_variables[idx2].set_variable_name] = transaction_state[tran_state_idx].var_value[idx2];
			}
		}
	}
}

/**
 * Parses @p input as a transaction-control statement and updates the tracked
 * state accordingly.
 *
 * @return false for non-transaction statements and for savepoint operations
 *         that could not be applied; true otherwise.
 */
bool PgSQL_ExplicitTxnStateMgr::handle_transaction(std::string_view input) {
	TxnCmd cmd = tx_parser.parse(input, (session->active_transactions > 0));
	switch (cmd.type) {
	case TxnCmd::BEGIN:
		start_transaction();
		break;
	case TxnCmd::COMMIT:
		commit();
		break;
	case TxnCmd::ROLLBACK:
		rollback(false);
		break;
	case TxnCmd::ROLLBACK_AND_CHAIN:
		rollback(true);
		break;
	case TxnCmd::SAVEPOINT:
		return add_savepoint(cmd.savepoint);
	case TxnCmd::RELEASE:
		return release_savepoint(cmd.savepoint);
	case TxnCmd::ROLLBACK_TO:
		return rollback_to_savepoint(cmd.savepoint);
	default:
		// Unknown command
		return false;
	}
	return true;
}

/**
 * Classifies @p input as a transaction-control command.
 *
 * @param in_transaction_mode true when the session reports an open
 *        transaction (the caller passes active_transactions > 0). The
 *        recognized keywords depend on it: BEGIN/START/SAVEPOINT/RELEASE/
 *        ROLLBACK [TO | AND CHAIN] when true, COMMIT/END/ABORT/ROLLBACK
 *        when false.
 *        NOTE(review): COMMIT/END/ABORT ... AND CHAIN leave a transaction
 *        open but are not recognized in the true branch. Confirm whether
 *        they need handling.
 */
TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mode) noexcept {
	TxnCmd cmd;
	if (input.empty()) return cmd;

	// Extract the first word without full tokenization.
	size_t start = 0;
	while (start < input.size() && fast_isspace(input[start])) {
		start++;
	}
	if (start >= input.size()) return cmd;

	// The first word ends at whitespace or ';'. A quote inside it means this is
	// not a transaction command.
	size_t end = start;
	while (end < input.size()) {
		const char c = input[end];
		if (c == '"' || c == '\'') {
			return cmd;
		}
		if (fast_isspace(c) || c == ';') {
			break;
		}
		end++;
	}

	std::string_view first_word = input.substr(start, end - start);

	// Check if this is a transaction command we care about
	TxnCmd::Type cmd_type = TxnCmd::UNKNOWN;

	if (in_transaction_mode) {
		if (iequals(first_word, "begin")) {
			cmd.type = TxnCmd::BEGIN;
			return cmd;
		}
		if (iequals(first_word, "start")) {
			cmd_type = TxnCmd::BEGIN;
		} else if (iequals(first_word, "savepoint")) {
			cmd_type = TxnCmd::SAVEPOINT;
		} else if (iequals(first_word, "release")) {
			cmd_type = TxnCmd::RELEASE;
		} else if (iequals(first_word, "rollback")) {
			cmd_type = TxnCmd::ROLLBACK;
		}
	} else {
		if (iequals(first_word, "commit") || iequals(first_word, "end")) {
			cmd.type = TxnCmd::COMMIT;
			return cmd;
		}
		if (iequals(first_word, "abort")) {
			cmd.type = TxnCmd::ROLLBACK;
			return cmd;
		}
		if (iequals(first_word, "rollback")) {
			cmd_type = TxnCmd::ROLLBACK;
		}
	}

	// If not a transaction command, return early
	if (cmd_type == TxnCmd::UNKNOWN) {
		return cmd;
	}

	// Tokenize the remainder of the input. Quoted tokens are stored without
	// their quotes; an unterminated quote runs to the end of the input.
	tokens.clear();
	bool in_quote = false;
	char quote_char = 0;
	start = end; // Continue from after the first word
	while (start < input.size() && fast_isspace(input[start])) {
		start++;
	}

	for (size_t i = start; i <= input.size(); ++i) {
		const bool at_end = i == input.size();
		const char c = at_end ? 0 : input[i];

		if (in_quote) {
			if (c == quote_char || at_end) {
				tokens.emplace_back(input.substr(start + 1, i - start - 1));
				in_quote = false;
				start = i + 1;
			}
			continue;
		}

		if (c == '"' || c == '\'') {
			in_quote = true;
			quote_char = c;
			start = i;
		} else if (fast_isspace(c) || c == ';' || at_end) {
			if (start < i) tokens.emplace_back(input.substr(start, i - start));
			start = i + 1;
		}
	}

	size_t pos = 0;
	if (in_transaction_mode) {
		switch (cmd_type) {
		case TxnCmd::BEGIN:
			cmd = parse_start(pos);
			break;
		case TxnCmd::SAVEPOINT:
			cmd = parse_savepoint(pos);
			break;
		case TxnCmd::RELEASE:
			cmd = parse_release(pos);
			break;
		case TxnCmd::ROLLBACK:
			cmd = parse_rollback(pos);
			break;
		default:
			break;
		}
	} else {
		if (cmd_type == TxnCmd::ROLLBACK)
			cmd = parse_rollback(pos);
	}
	return cmd;
}

/**
 * Parses the remainder of ROLLBACK: [WORK | TRANSACTION]
 * [TO [SAVEPOINT] name | AND CHAIN].
 */
TxnCmd PgSQL_TxnCmdParser::parse_rollback(size_t& pos) noexcept {
	TxnCmd cmd{ TxnCmd::ROLLBACK };
	while (pos < tokens.size() && contains({ "work", "transaction" }, tokens[pos])) pos++;
	if (pos < tokens.size() && iequals(tokens[pos], "to")) {
		cmd.type = TxnCmd::ROLLBACK_TO;
		if (++pos < tokens.size() && iequals(tokens[pos], "savepoint")) pos++;
		if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
	} else if (pos < tokens.size() && iequals(tokens[pos], "and")) {
		// "AND NO CHAIN" falls through as a plain ROLLBACK.
		if (++pos < tokens.size() && iequals(tokens[pos], "chain")) {
			cmd.type = TxnCmd::ROLLBACK_AND_CHAIN;
			pos++;
		}
	}
	return cmd;
}

/** Parses the remainder of SAVEPOINT: name. */
TxnCmd PgSQL_TxnCmdParser::parse_savepoint(size_t& pos) noexcept {
	TxnCmd cmd{ TxnCmd::SAVEPOINT };
	if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
	return cmd;
}

/** Parses the remainder of RELEASE: [SAVEPOINT] name. */
TxnCmd PgSQL_TxnCmdParser::parse_release(size_t& pos) noexcept {
	TxnCmd cmd{ TxnCmd::RELEASE };
	if (pos < tokens.size() && iequals(tokens[pos], "savepoint")) pos++;
	if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
	return cmd;
}

/** Parses the remainder of START: only START TRANSACTION is recognized. */
TxnCmd PgSQL_TxnCmdParser::parse_start(size_t& pos) noexcept {
	TxnCmd cmd{ TxnCmd::UNKNOWN };
	if (pos < tokens.size() && iequals(tokens[pos], "transaction")) {
		cmd.type = TxnCmd::BEGIN;
		pos++;
	}
	return cmd;
}