Merge pull request #4929 from sysown/v3.0_track_transaction_param_state_4907

Track Variable State Across Transactions and Savepoints - v3.0
pull/4941/head
René Cannaò 1 year ago committed by GitHub
commit 9f4b139e6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -363,7 +363,7 @@ public:
bool set_single_row_mode();
void update_bytes_recv(uint64_t bytes_recv);
void update_bytes_sent(uint64_t bytes_sent);
void ProcessQueryAndSetStatusFlags(char* query_digest_text);
void ProcessQueryAndSetStatusFlags(char* query_digest_text, int savepoint_count);
inline const PGconn* get_pg_connection() const { return pgsql_conn; }
inline int get_pg_server_version() { return PQserverVersion(pgsql_conn); }
@ -388,6 +388,7 @@ public:
inline int get_pg_is_threadsafe() { return PQisthreadsafe(); }
inline const char* get_pg_error_message() { return PQerrorMessage(pgsql_conn); }
inline SSL* get_pg_ssl_object() { return (SSL*)PQsslStruct(pgsql_conn, "OpenSSL"); }
inline const char* get_pg_parameter_status(const char* param) { return PQparameterStatus(pgsql_conn, param); }
const char* get_pg_server_version_str(char* buff, int buff_size);
const char* get_pg_connection_status_str();
const char* get_pg_transaction_status_str();

@ -0,0 +1,112 @@
#ifndef PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H
#define PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H
#include <string>
#include <vector>
#include <algorithm>
#include "proxysql.h"
#include "cpp.h"
#include "PgSQL_Connection.h"
#ifndef PROXYJSON
#define PROXYJSON
#include "../deps/json/json_fwd.hpp"
#endif // PROXYJSON
/**
* @struct PgSQL_Variable_Snapshot
* @brief Represents a snapshot of PostgreSQL variables during a transaction.
*
* This structure is used to store the state of PostgreSQL variables, including
* their values and hash representations, at a specific point in time during a transaction.
*/
struct PgSQL_Variable_Snapshot {
char* var_value[PGSQL_NAME_LAST_HIGH_WM] = {}; // Not using smart pointers because we need fine-grained control over hashing when values change
uint32_t var_hash[PGSQL_NAME_LAST_HIGH_WM] = {};
};
/**
* @struct TxnCmd
* @brief Represents a transaction command type begin executed and its associated metadata.
*
*/
struct TxnCmd {
/**
* @enum Type
* @brief Enumerates the types of transaction commands.
*/
enum Type {
UNKNOWN = -1,
BEGIN,
COMMIT,
ROLLBACK,
SAVEPOINT,
RELEASE,
ROLLBACK_TO
} type = Type::UNKNOWN;
std::string savepoint; //< The name of the savepoint, if applicable.
};
/**
* @class PgSQL_TxnCmdParser
* @brief Parses transaction-related commands for PostgreSQL.
*
* This class is responsible for tokenizing and interpreting transaction-related
* commands such as BEGIN, COMMIT, ROLLBACK, SAVEPOINT, etc.
*/
class PgSQL_TxnCmdParser {
public:
TxnCmd parse(std::string_view input, bool in_transaction_mode) noexcept;
private:
std::vector<std::string_view> tokens;
TxnCmd parse_rollback(size_t& pos) noexcept;
TxnCmd parse_savepoint(size_t& pos) noexcept;
TxnCmd parse_release(size_t& pos) noexcept;
// Helpers
static std::string to_lower(std::string_view s) noexcept {
std::string s_copy(s);
std::transform(s_copy.begin(), s_copy.end(), s_copy.begin(), ::tolower);
return s_copy;
}
inline static bool contains(std::vector<std::string_view>&& list, std::string_view value) noexcept {
for (const auto& item : list) if (item == value) return true;
return false;
}
};
/**
* @class PgSQL_ExplicitTxnStateMgr
* @brief Manages the state of explicit transactions in PostgreSQL.
*
* This class is responsible for handling explicit transaction commands such as
* BEGIN, COMMIT, ROLLBACK, SAVEPOINT, and managing the associated state.
*/
class PgSQL_ExplicitTxnStateMgr {
public:
PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess);
~PgSQL_ExplicitTxnStateMgr();
bool handle_transaction(std::string_view input);
int get_savepoint_count() const { return savepoint.size(); }
void fill_internal_session(nlohmann::json& j);
private:
PgSQL_Session* session;
std::vector<PgSQL_Variable_Snapshot> transaction_state;
std::vector<std::string> savepoint;
PgSQL_TxnCmdParser tx_parser;
void start_transaction();
void commit();
void rollback();
bool add_savepoint(std::string_view name);
bool rollback_to_savepoint(std::string_view name);
bool release_savepoint(std::string_view name);
static void reset_variable_snapshot(PgSQL_Variable_Snapshot& var_snapshot) noexcept;
};
#endif // PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H

@ -279,7 +279,7 @@ private:
std::vector<unsigned int> pkt_offset;
bool multiple_pkt_mode = false;
bool ownership = true;
friend void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type);
friend void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type, char txn_state);
};
class PgSQL_Protocol;
@ -1052,6 +1052,6 @@ private:
friend void admin_session_handler(S* sess, void* _pa, PtrSize_t* pkt);
};
void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type);
void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type, char txn_state = 'I');
#endif // __POSTGRES_PROTOCOL_H

@ -15,6 +15,7 @@
class PgSQL_Query_Result;
class PgSQL_ExplicitTxnStateMgr;
//#include "../deps/json/json.hpp"
//using json = nlohmann::json;
@ -389,6 +390,7 @@ public:
PgSQL_Data_Stream* client_myds;
#endif // 0
PgSQL_Data_Stream* server_myds;
PgSQL_ExplicitTxnStateMgr* transaction_state_manager;
#if 0
/*
* @brief Store the hostgroups that hold connections that have been flagged as 'expired' by the

@ -31,15 +31,15 @@ public:
bool client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx);
bool client_set_hash_and_value(PgSQL_Session* session, int idx, const std::string& value, uint32_t hash);
void client_reset_value(PgSQL_Session* session, int idx);
void client_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx);
const char* client_get_value(PgSQL_Session* session, int idx) const;
uint32_t client_get_hash(PgSQL_Session* session, int idx) const;
void server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx);
void server_set_hash_and_value(PgSQL_Session* session, int idx, const char* value, uint32_t hash);
void server_reset_value(PgSQL_Session* session, int idx);
void server_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx);
const char* server_get_value(PgSQL_Session* session, int idx) const;
inline uint32_t server_get_hash(PgSQL_Session* session, int idx) const;
uint32_t server_get_hash(PgSQL_Session* session, int idx) const;
bool verify_variable(PgSQL_Session* session, int idx) const;
bool update_variable(PgSQL_Session* session, session_status status, int &_rc);

@ -627,6 +627,7 @@ enum PGSQL_QUERY_command {
PGSQL_QUERY_BEGIN,
PGSQL_QUERY_COMMIT,
PGSQL_QUERY_ROLLBACK,
PGSQL_QUERY_ABORT,
PGSQL_QUERY_DECLARE_CURSOR,
PGSQL_QUERY_CLOSE_CURSOR,
PGSQL_QUERY_DISCARD,

@ -307,7 +307,9 @@ struct free_deleter {
template <typename T>
using mf_unique_ptr = std::unique_ptr<T, free_deleter>;
static inline void set_thread_name(const char name[16], const bool en = true) {
template<std::size_t LEN>
static inline void set_thread_name(const char(&name)[LEN], const bool en = true) {
static_assert(LEN < 17, "Thread name must not exceed 16 characters");
if (en == false) {
return;
}

@ -330,7 +330,10 @@ void Base_Session<S, DS, B, T>::return_proxysql_internal(PtrSize_t* pkt) {
char* pta[1];
pta[0] = (char*)s.c_str();
resultset->add_row(pta);
SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset, nullptr, 0, (const char*)pkt->ptr + 5);
unsigned int nTxn = NumActiveTransactions();
char txn_state = (nTxn ? 'T' : 'I');
SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset, nullptr, 0, (const char*)pkt->ptr + 5, txn_state);
delete resultset;
l_free(pkt->size, pkt->ptr);
return;
@ -599,12 +602,14 @@ unsigned int Base_Session<S,DS,B,T>::NumActiveTransactions(bool check_savepoint)
if (_mybe->server_myds->myconn->IsActiveTransaction()) {
ret++;
} else {
// we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due
// to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to
// SAVEPOINT and autocommit=0
if (check_savepoint) {
if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) {
ret++;
if constexpr (std::is_same_v<S, MySQL_Session>) {
// we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due
// to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to
// SAVEPOINT and autocommit=0
if (check_savepoint) {
if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) {
ret++;
}
}
}
}
@ -663,17 +668,18 @@ int Base_Session<S,DS,B,T>::FindOneActiveTransaction(bool check_savepoint) {
if (_mybe->server_myds->myconn) {
if (_mybe->server_myds->myconn->IsKnownActiveTransaction()) {
return (int)_mybe->server_myds->myconn->parent->myhgc->hid;
}
else if (_mybe->server_myds->myconn->IsActiveTransaction()) {
} else if (_mybe->server_myds->myconn->IsActiveTransaction()) {
ret = (int)_mybe->server_myds->myconn->parent->myhgc->hid;
}
else {
// we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due
// to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to
// SAVEPOINT and autocommit=0
if (check_savepoint) {
if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) {
return (int)_mybe->server_myds->myconn->parent->myhgc->hid;
if constexpr (std::is_same_v<S, MySQL_Session>) {
// we use check_savepoint to check if we shouldn't ignore COMMIT or ROLLBACK due
// to MySQL bug https://bugs.mysql.com/bug.php?id=107875 related to
// SAVEPOINT and autocommit=0
if (check_savepoint) {
if (_mybe->server_myds->myconn->AutocommitFalse_AndSavepoint() == true) {
return (int)_mybe->server_myds->myconn->parent->myhgc->hid;
}
}
}
}

@ -164,7 +164,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo
PgSQL_Protocol.oo PgSQL_Thread.oo PgSQL_Data_Stream.oo PgSQL_Session.oo PgSQL_Variables.oo PgSQL_HostGroups_Manager.oo PgSQL_Connection.oo PgSQL_Backend.oo PgSQL_Logger.oo PgSQL_Authentication.oo PgSQL_Error_Helper.oo \
MySQL_Query_Cache.oo PgSQL_Query_Cache.oo PgSQL_Monitor.oo \
MySQL_Set_Stmt_Parser.oo PgSQL_Set_Stmt_Parser.oo \
PgSQL_Variables_Validator.oo
PgSQL_Variables_Validator.oo PgSQL_ExplicitTxnStateMgr.oo
OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX))
HEADERS := ../include/*.h ../include/*.hpp

@ -2956,7 +2956,6 @@ __exit_handler_again___status_CONNECTING_SERVER_with_err:
sprintf(sqlstate,"%s",mysql_sqlstate(myconn->mysql));
client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,mysql_errno(myconn->mysql),sqlstate, errmsg.c_str(), true);
} else {
char buf[256];
errmsg = "Max connect failure while reaching hostgroup " + to_string(current_hostgroup);
client_myds->myprot.generate_pkt_ERR(true,NULL,NULL,1,9002,(char *)"HY000", errmsg.c_str(), true);
if (thread) {

@ -812,7 +812,7 @@ void PgSQL_Connection::connect_start() {
const std::string& conninfo_str = conninfo.str();
pgsql_conn = PQconnectStart(conninfo_str.c_str());
//PQsetErrorVerbosity(pgsql_conn, PQERRORS_VERBOSE);
//PQsetErrorVerbosity(pgsql_conn, PQERRORS_SQLSTATE);
//PQsetErrorContextVisibility(pgsql_conn, PQSHOW_CONTEXT_ERRORS);
if (pgsql_conn == NULL || PQstatus(pgsql_conn) == CONNECTION_BAD) {
@ -1042,15 +1042,12 @@ int PgSQL_Connection::async_connect(short event) {
async_state_machine = ASYNC_IDLE;
myds->wait_until = 0;
return 0;
break;
case ASYNC_CONNECT_FAILED:
return -1;
break;
case ASYNC_CONNECT_TIMEOUT:
return -2;
break;
default:
return 1;
break;
}
return 1;
}
@ -1416,8 +1413,8 @@ bool PgSQL_Connection::IsServerOffline() {
}
bool PgSQL_Connection::is_connection_in_reusable_state() const {
const PGTransactionStatusType txn_status = PQtransactionStatus(pgsql_conn);
const bool conn_usable = !(txn_status == PQTRANS_UNKNOWN || txn_status == PQTRANS_ACTIVE);
PGTransactionStatusType txn_status = PQtransactionStatus(pgsql_conn);
bool conn_usable = !(txn_status == PQTRANS_UNKNOWN || txn_status == PQTRANS_ACTIVE);
assert(!(conn_usable == false && is_error_present() == false));
return conn_usable;
}
@ -1678,7 +1675,7 @@ void PgSQL_Connection::unhandled_notice_cb(void* arg, const PGresult* result) {
#endif
}
void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) {
void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text, int savepoint_count) {
if (query_digest_text == NULL) return;
// unknown what to do with multiplex
int mul = -1;
@ -1721,8 +1718,7 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) {
default:
break;
}
}
else {
} else {
if (mul != 2 && index(query_digest_text, '.')) { // mul = 2 has a special meaning : do not disable multiplex for variables in THIS QUERY ONLY
if (!IsKeepMultiplexEnabledVariables(query_digest_text)) {
set_status(true, STATUS_MYSQL_CONNECTION_USER_VARIABLE);
@ -1767,18 +1763,26 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) {
}
}*/
if (get_status(STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT) == false) {
if (IsKnownActiveTransaction()) {
if (!strncasecmp(query_digest_text, "SAVEPOINT ", strlen("SAVEPOINT "))) {
set_status(true, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT);
if (savepoint_count > 0) {
set_status(true, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT);
} else if (savepoint_count == -1) {
if (IsKnownActiveTransaction()) {
if (!strncasecmp(query_digest_text, "SAVEPOINT ", strlen("SAVEPOINT "))) {
set_status(true, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT);
}
}
}
}
} else {
if ((IsKnownActiveTransaction() == false) ||
(strcasecmp(query_digest_text, "COMMIT") == 0) ||
(strcasecmp(query_digest_text, "ROLLBACK") == 0)) {
if (savepoint_count == 0) {
set_status(false, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT);
}
} else if (savepoint_count == -1) {
if ((IsKnownActiveTransaction() == false) ||
(strncasecmp(query_digest_text, "COMMIT", strlen("COMMIT")) == 0) ||
(strncasecmp(query_digest_text, "ROLLBACK", strlen("ROLLBACK")) == 0) ||
(strncasecmp(query_digest_text, "ABORT", strlen("ABORT")) == 0)) {
set_status(false, STATUS_MYSQL_CONNECTION_HAS_SAVEPOINT);
}
}
}
}

@ -0,0 +1,386 @@
#include "PgSQL_ExplicitTxnStateMgr.h"
#include "proxysql.h"
#include "PgSQL_Session.h"
#include "PgSQL_Data_Stream.h"
#include "PgSQL_Connection.h"
extern class PgSQL_Variables pgsql_variables;
PgSQL_ExplicitTxnStateMgr::PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess) : session(sess) {
}
PgSQL_ExplicitTxnStateMgr::~PgSQL_ExplicitTxnStateMgr() {
for (auto& tran_state : transaction_state) {
reset_variable_snapshot(tran_state);
}
transaction_state.clear();
savepoint.clear();
}
void PgSQL_ExplicitTxnStateMgr::reset_variable_snapshot(PgSQL_Variable_Snapshot& var_snapshot) noexcept {
for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
if (var_snapshot.var_value[idx]) {
free(var_snapshot.var_value[idx]);
var_snapshot.var_value[idx] = nullptr;
}
var_snapshot.var_hash[idx] = 0;
}
}
void verify_server_variables(PgSQL_Session* session) {
#ifdef DEBUG
for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) {
const char* conn_param_status = session->mybe->server_myds->myconn->get_pg_parameter_status(pgsql_tracked_variables[idx].set_variable_name);
const char* param_value = session->mybe->server_myds->myconn->variables[idx].value;
if (conn_param_status && param_value) {
assert(strcmp(conn_param_status, param_value) == 0);
}
}
#endif
}
void PgSQL_ExplicitTxnStateMgr::start_transaction() {
if (transaction_state.empty() == false) {
// Transaction already started, do nothing and return
proxy_warning("Received BEGIN command. There is already a transaction in progress\n");
assert(session->NumActiveTransactions() > 0);
return;
}
assert(session->client_myds && session->client_myds->myconn);
PgSQL_Variable_Snapshot var_snapshot{};
// check if already in transaction, if yes then do nothing
for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
uint32_t hash = pgsql_variables.client_get_hash(session, idx);
if (hash != 0) {
var_snapshot.var_hash[idx] = hash;
var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx));
} else {
assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
// no need to store the value
//var_snapshot.var_hash[idx] = 0;
//var_snapshot.var_value[idx] = NULL;
}
}
transaction_state.emplace_back(std::move(var_snapshot));
}
void PgSQL_ExplicitTxnStateMgr::commit() {
if (transaction_state.empty()) {
proxy_warning("Received COMMIT command. There is no transaction in progress\n");
assert(session->NumActiveTransactions() == 0);
return;
}
assert(session->client_myds && session->client_myds->myconn);
for (auto& tran_state : transaction_state) {
reset_variable_snapshot(tran_state);
}
transaction_state.clear();
savepoint.clear();
verify_server_variables(session);
}
void PgSQL_ExplicitTxnStateMgr::rollback() {
if (transaction_state.empty()) {
proxy_warning("Received ROLLBACK command. There is no transaction in progress\n");
assert(session->NumActiveTransactions() == 0);
return;
}
assert(session->client_myds && session->client_myds->myconn);
const PgSQL_Variable_Snapshot& var_snapshot = transaction_state.front();
for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
uint32_t hash = var_snapshot.var_hash[idx];
if (hash != 0) {
uint32_t client_hash = pgsql_variables.client_get_hash(session, idx);
uint32_t server_hash = pgsql_variables.server_get_hash(session, idx);
assert(client_hash == server_hash);
if (hash == client_hash)
continue;
pgsql_variables.client_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
pgsql_variables.server_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
} else {
assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
pgsql_variables.client_reset_value(session, idx, false);
pgsql_variables.server_reset_value(session, idx, false);
}
}
// reuse of connection that has extra param set in connection
session->client_myds->myconn->reorder_dynamic_variables_idx();
if (session->mybe) {
session->mybe->server_myds->myconn->reorder_dynamic_variables_idx();
verify_server_variables(session);
}
// Clear savepoints and reset the initial snapshot
for (auto& tran_state : transaction_state) {
reset_variable_snapshot(tran_state);
}
transaction_state.clear();
savepoint.clear();
}
bool PgSQL_ExplicitTxnStateMgr::rollback_to_savepoint(std::string_view name) {
if (transaction_state.empty()) {
proxy_warning("Received ROLLBACK TO SAVEPOINT '%s' command. There is no transaction in progress\n", name.data());
assert(session->NumActiveTransactions() == 0);
return false;
}
assert(session->client_myds && session->client_myds->myconn);
int tran_state_idx = -1;
for (size_t idx = 0; idx < savepoint.size(); idx++) {
if (savepoint[idx].size() == name.size() &&
strncasecmp(savepoint[idx].c_str(), name.data(), name.size()) == 0) {
tran_state_idx = idx;
break;
}
}
if (tran_state_idx == -1) {
proxy_warning("Savepoint '%s' not found.\n", name.data());
return false;
};
assert(tran_state_idx + 1 < (int)transaction_state.size());
PgSQL_Variable_Snapshot& var_snapshot = transaction_state[tran_state_idx+1];
for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
uint32_t hash = var_snapshot.var_hash[idx];
if (hash != 0) {
uint32_t client_hash = pgsql_variables.client_get_hash(session, idx);
uint32_t server_hash = pgsql_variables.server_get_hash(session, idx);
assert(client_hash == server_hash);
if (hash == client_hash)
continue;
pgsql_variables.client_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
pgsql_variables.server_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
}
else {
assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
pgsql_variables.client_reset_value(session, idx, false);
pgsql_variables.server_reset_value(session, idx, false);
}
}
session->client_myds->myconn->reorder_dynamic_variables_idx();
if (session->mybe) {
session->mybe->server_myds->myconn->reorder_dynamic_variables_idx();
verify_server_variables(session);
}
for (size_t idx = tran_state_idx + 1; idx < transaction_state.size(); idx++) {
reset_variable_snapshot(transaction_state[idx]);
}
transaction_state.resize(tran_state_idx + 1);
savepoint.resize(tran_state_idx);
return true;
}
bool PgSQL_ExplicitTxnStateMgr::release_savepoint(std::string_view name) {
if (transaction_state.empty()) {
proxy_warning("Received RELEASE SAVEPOINT '%s' command. There is no transaction in progress\n", name.data());
assert(session->NumActiveTransactions() == 0);
return false;
}
assert(session->client_myds && session->client_myds->myconn);
int tran_state_idx = -1;
for (size_t idx = 0; idx < savepoint.size(); idx++) {
if (savepoint[idx].size() == name.size() &&
strncasecmp(savepoint[idx].c_str(), name.data(), name.size()) == 0) {
tran_state_idx = idx;
break;
}
}
if (tran_state_idx == -1) {
proxy_warning("SAVEPOINT '%s' not found.\n", name.data());
return false;
};
for (size_t idx = tran_state_idx + 1; idx < transaction_state.size(); idx++) {
reset_variable_snapshot(transaction_state[idx]);
}
transaction_state.resize(tran_state_idx + 1);
savepoint.resize(tran_state_idx);
return true;
}
bool PgSQL_ExplicitTxnStateMgr::add_savepoint(std::string_view name) {
if (transaction_state.empty()) {
proxy_warning("Received SAVEPOINT '%s' command. There is no transaction in progress\n", name.data());
assert(session->NumActiveTransactions() == 0);
return false;
}
assert(session->client_myds && session->client_myds->myconn);
auto it = std::find_if(savepoint.begin(), savepoint.end(), [name](std::string_view sp) {
return sp.size() == name.size() &&
strncasecmp(sp.data(), name.data(), name.size()) == 0;
});
if (it != savepoint.end()) return false;
PgSQL_Variable_Snapshot var_snapshot{};
for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
uint32_t hash = pgsql_variables.client_get_hash(session, idx);
if (hash != 0) {
var_snapshot.var_hash[idx] = hash;
var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx));
}
}
transaction_state.emplace_back(std::move(var_snapshot));
savepoint.emplace_back(name);
assert((transaction_state.size() - 1) == savepoint.size());
return true;
}
void PgSQL_ExplicitTxnStateMgr::fill_internal_session(nlohmann::json& j) {
if (transaction_state.empty()) return;
auto& initial_state = j["initial_state"];
for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
uint32_t hash = transaction_state[0].var_hash[idx];
if (hash != 0) {
initial_state[pgsql_tracked_variables[idx].set_variable_name] = transaction_state[0].var_value[idx];
}
}
if (savepoint.empty()) return;
for (size_t idx = 0; idx < savepoint.size(); idx++) {
auto& savepoint_json = j["savepoints"][savepoint[idx]];
int tran_state_idx = idx + 1;
for (int idx2 = 0; idx2 < PGSQL_NAME_LAST_HIGH_WM; idx2++) {
uint32_t hash = transaction_state[tran_state_idx].var_hash[idx2];
if (hash != 0) {
savepoint_json[pgsql_tracked_variables[idx2].set_variable_name] = transaction_state[tran_state_idx].var_value[idx2];
}
}
}
}
bool PgSQL_ExplicitTxnStateMgr::handle_transaction(std::string_view input) {
TxnCmd cmd = tx_parser.parse(input, (session->active_transactions > 0));
switch (cmd.type) {
case TxnCmd::BEGIN:
start_transaction();
break;
case TxnCmd::COMMIT:
commit();
break;
case TxnCmd::ROLLBACK:
rollback();
break;
case TxnCmd::SAVEPOINT:
return add_savepoint(cmd.savepoint);
case TxnCmd::RELEASE:
return release_savepoint(cmd.savepoint);
case TxnCmd::ROLLBACK_TO:
return rollback_to_savepoint(cmd.savepoint);
default:
// Unknown command
return false;
}
return true;
}
TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mode) noexcept {
tokens.clear();
TxnCmd cmd;
bool in_quote = false;
size_t start = 0;
char quote_char = 0;
// Tokenize with quote handling
for (size_t i = 0; i <= input.size(); ++i) {
const bool at_end = i == input.size();
const char c = at_end ? 0 : input[i];
if (in_quote) {
if (c == quote_char || at_end) {
tokens.emplace_back(input.substr(start + 1, i - start - 1));
in_quote = false;
}
continue;
}
if (c == '"' || c == '\'') {
in_quote = true;
quote_char = c;
start = i;
}
else if (isspace(c) || c == ';' || at_end) {
if (start < i) tokens.emplace_back(input.substr(start, i - start));
start = i + 1;
}
}
if (tokens.empty()) return cmd;
size_t pos = 0;
const std::string first = to_lower(tokens[pos++]);
if (in_transaction_mode == true) {
if (first == "begin") cmd.type = TxnCmd::BEGIN;
else if (first == "savepoint") cmd = parse_savepoint(pos);
else if (first == "release") cmd = parse_release(pos);
else if (first == "rollback") cmd = parse_rollback(pos);
} else {
if (first == "commit") cmd.type = TxnCmd::COMMIT;
else if (first == "rollback" || (first == "abort")) cmd = parse_rollback(pos);
}
return cmd;
}
TxnCmd PgSQL_TxnCmdParser::parse_rollback(size_t& pos) noexcept {
TxnCmd cmd{ TxnCmd::ROLLBACK };
while (pos < tokens.size() && contains({ "work", "transaction" }, to_lower(tokens[pos]))) pos++;
if (pos < tokens.size() && to_lower(tokens[pos]) == "to") {
cmd.type = TxnCmd::ROLLBACK_TO;
if (++pos < tokens.size() && to_lower(tokens[pos]) == "savepoint") pos++;
if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
}
return cmd;
}
TxnCmd PgSQL_TxnCmdParser::parse_savepoint(size_t& pos) noexcept {
TxnCmd cmd{ TxnCmd::SAVEPOINT };
if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
return cmd;
}
TxnCmd PgSQL_TxnCmdParser::parse_release(size_t& pos) noexcept {
TxnCmd cmd{ TxnCmd::RELEASE };
if (pos < tokens.size() && to_lower(tokens[pos]) == "savepoint") pos++;
if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
return cmd;
}

@ -197,7 +197,7 @@ void PG_pkt::write_RowDescription(const char *tupdesc, ...) {
}
void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error, int affected_rows, const char *query_type) {
void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error, int affected_rows, const char *query_type, char txn_state) {
assert(psa != NULL);
const char *fs = strchr(query_type, ' ');
int qtlen = strlen(query_type);
@ -257,7 +257,7 @@ void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error,
pkt.write_CommandComplete(buf);
}
pkt.to_PtrSizeArray(psa);
pkt.write_ReadyForQuery();
pkt.write_ReadyForQuery(txn_state);
pkt.to_PtrSizeArray(psa);
} else { // no resultset
PG_pkt pkt(64);
@ -289,7 +289,7 @@ void SQLite3_to_Postgres(PtrSizeArray *psa, SQLite3_result *result, char *error,
}
}
pkt.to_PtrSizeArray(psa);
pkt.write_ReadyForQuery();
pkt.write_ReadyForQuery(txn_state);
pkt.to_PtrSizeArray(psa);
}
}

@ -102,6 +102,7 @@ static char* commands_counters_desc[PGSQL_QUERY___NONE] = {
[PGSQL_QUERY_BEGIN] = (char*)"BEGIN",
[PGSQL_QUERY_COMMIT] = (char*)"COMMIT",
[PGSQL_QUERY_ROLLBACK] = (char*)"ROLLBACK",
[PGSQL_QUERY_ABORT] = (char*)"ABORT",
[PGSQL_QUERY_DECLARE_CURSOR] = (char*)"DECLARE_CURSOR",
[PGSQL_QUERY_CLOSE_CURSOR] = (char*)"CLOSE_CURSOR",
[PGSQL_QUERY_DISCARD] = (char*)"DISCARD",
@ -664,7 +665,7 @@ enum PGSQL_QUERY_command PgSQL_Query_Processor::query_parser_command_type(SQP_pa
char c1;
tokenizer_t tok;
tokenizer(&tok, text, " ", TOKENIZER_NO_EMPTIES);
tokenizer(&tok, text, " ;", TOKENIZER_NO_EMPTIES);
char* token = NULL;
__get_token:
token = (char*)tokenize(&tok);
@ -761,6 +762,10 @@ __remove_parenthesis:
ret = PGSQL_QUERY_ANALYZE;
break;
}
if (!strcasecmp("ABORT", token)) {
ret = PGSQL_QUERY_ROLLBACK;
break;
}
break;
case 'b':
case 'B':

@ -25,7 +25,7 @@ using json = nlohmann::json;
#include "ProxySQL_Cluster.hpp"
#include "PgSQL_Query_Cache.h"
#include "PgSQL_Variables_Validator.h"
#include "PgSQL_ExplicitTxnStateMgr.h"
#include "libinjection.h"
#include "libinjection_sqli.h"
@ -81,8 +81,9 @@ static inline char is_normal_char(char c) {
}
*/
static const std::array<std::string,6> pgsql_critical_variables = {
static const std::array<std::string,7> pgsql_critical_variables = {
"client_encoding",
"names",
"datestyle",
"intervalstyle",
"standard_conforming_strings",
@ -579,6 +580,7 @@ PgSQL_Session::PgSQL_Session() {
last_HG_affected_rows = -1; // #1421 : advanced support for LAST_INSERT_ID()
proxysql_node_address = NULL;
use_ldap_auth = false;
transaction_state_manager = new PgSQL_ExplicitTxnStateMgr(this);
}
void PgSQL_Session::reset() {
@ -678,6 +680,8 @@ PgSQL_Session::~PgSQL_Session() {
for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) {
reset_default_session_variable((enum pgsql_variable_name)i);
}
if (transaction_state_manager)
delete transaction_state_manager;
}
bool PgSQL_Session::handler_CommitRollback(PtrSize_t* pkt) {
@ -777,6 +781,9 @@ void PgSQL_Session::generate_proxysql_internal_session_json(json& j) {
j["qpo"]["max_lag_ms"] = qpo->max_lag_ms;
j["user_attributes"] = (user_attributes ? user_attributes : "");
j["transaction_persistent"] = transaction_persistent;
transaction_state_manager->fill_internal_session(j["transaction_state"]);
if (client_myds != NULL) { // only if client_myds is defined
j["client"]["stream"]["pkts_recv"] = client_myds->pkts_recv;
j["client"]["stream"]["pkts_sent"] = client_myds->pkts_sent;
@ -1820,7 +1827,7 @@ bool PgSQL_Session::handler_again___status_CONNECTING_SERVER(int* _rc) {
NEXT_IMMEDIATE_NEW(CONNECTING_SERVER);
}
else {
__exit_handler_again___status_CONNECTING_SERVER_with_err:
__exit_handler_again___status_CONNECTING_SERVER_with_err:
bool is_error_present = myconn->is_error_present();
if (is_error_present) {
client_myds->myprot.generate_error_packet(true, true, myconn->error_info.message.c_str(),
@ -5463,26 +5470,26 @@ void PgSQL_Session::RequestEnd(PgSQL_Data_Stream* myds, const unsigned int myerr
if (status != PROCESSING_STMT_EXECUTE) {
qdt = CurrentQuery.get_digest_text();
}
else {
} else {
qdt = CurrentQuery.stmt_info->digest_text;
}
// is savepoint currently present in transaction.
int savepoint_count = -1; // haven't checked yet
// we do not maintain the transaction variable state if the session is locked on a hostgroup
// or is a Fast Forward session.
if (locked_on_hostgroup == -1 && session_fast_forward == SESSION_FORWARD_TYPE_NONE) {
transaction_state_manager->handle_transaction(qdt);
savepoint_count = transaction_state_manager->get_savepoint_count();
}
if (qdt && myds && myds->myconn) {
myds->myconn->ProcessQueryAndSetStatusFlags(qdt);
myds->myconn->ProcessQueryAndSetStatusFlags(qdt, savepoint_count);
}
switch (status) {
/*case PROCESSING_STMT_EXECUTE:
case PROCESSING_STMT_PREPARE:
// if a prepared statement is executed, LogQuery was already called
break;
*/
default:
if (session_fast_forward == SESSION_FORWARD_TYPE_NONE) {
LogQuery(myds);
}
break;
if (session_fast_forward == SESSION_FORWARD_TYPE_NONE) {
LogQuery(myds);
}
GloPgQPro->delete_QP_out(qpo);

@ -74,7 +74,8 @@ void PgSQL_Set_Stmt_Parser::generateRE_parse1v2() {
// Function Call: Check if Group 3 is populated.
// Literal: Check if Group 4 is populated.
//const std::string pattern = "(?:(SESSION|LOCAL)\\s+)?((?:\\S+(?:\\s+\\S+)*?))(?:\\s+(?:TO|=)\\s+|\\s+)(?:(\\w+\\s*\\([^)]*\\))|((?:'(?:''|[^'])*'|-?\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?|t|true|f|false|on|off|default|\\S+)))\\s*;?";
const std::string pattern = "(?:(SESSION|LOCAL)\\s+)?((?:\\S+(?:\\s+\\S+)*?))(?:\\s*(?:TO|=)\\s*|\\s+)(?:(\\w+\\s*\\([^)]*\\))|((?:'(?:''|[^'])*'|-?\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?|true|t|1|yes|false|f|0|no|on|off|default|\\S+)))\\s*;?";
//const std::string pattern = "(?:(SESSION|LOCAL)\\s+)?((?:\\S+(?:\\s+\\S+)*?))(?:\\s*(?:TO|=)\\s*|\\s+)(?:(\\w+\\s*\\([^)]*\\))|((?:'(?:''|[^'])*'|-?\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?|true|t|1|yes|false|f|0|no|on|off|default|\\S+)))\\s*;?";
const std::string pattern = R"((?:(SESSION)\s+)?((?:\S+(?:\s+\S+)*?))(?:\s*(?:TO|=)\s*|\s+)(?:(\w+\s*\([^)]*\))|((?:'(?:''|[^'])*'|-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?|[^;]+)))\s*;?)";
#ifdef DEBUG
VALGRIND_DISABLE_ERROR_REPORTING;

@ -85,7 +85,7 @@ bool PgSQL_Variables::client_set_hash_and_value(PgSQL_Session* session, int idx,
return true;
}
void PgSQL_Variables::client_reset_value(PgSQL_Session* session, int idx) {
void PgSQL_Variables::client_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx) {
if (!session || !session->client_myds || !session->client_myds->myconn) {
proxy_warning("Session validation failed\n");
return;
@ -99,7 +99,7 @@ void PgSQL_Variables::client_reset_value(PgSQL_Session* session, int idx) {
free(client_conn->variables[idx].value);
client_conn->variables[idx].value = NULL;
}
if (idx > PGSQL_NAME_LAST_LOW_WM) {
if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) {
// we now regererate dynamic_variables_idx
client_conn->reorder_dynamic_variables_idx();
}
@ -170,7 +170,7 @@ void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const ch
}
}
void PgSQL_Variables::server_reset_value(PgSQL_Session* session, int idx) {
void PgSQL_Variables::server_reset_value(PgSQL_Session* session, int idx, bool reorder_dynamic_variables_idx) {
assert(session);
assert(session->mybe);
assert(session->mybe->server_myds);
@ -184,7 +184,7 @@ void PgSQL_Variables::server_reset_value(PgSQL_Session* session, int idx) {
free(backend_conn->variables[idx].value);
backend_conn->variables[idx].value = NULL;
}
if (idx > PGSQL_NAME_LAST_LOW_WM) {
if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) {
// we now regererate dynamic_variables_idx
backend_conn->reorder_dynamic_variables_idx();
}

@ -0,0 +1,288 @@
/**
* @file pgsql-transaction_variable_state_tracking-t.cpp
* @brief TAP test validating PostgreSQL session parameter behavior in transactions.
* Tests rollback/commit/savepoint behavior for session variables to ensure state consistency.
*/
#include <unistd.h>
#include <string>
#include <sstream>
#include <chrono>
#include <thread>
#include "libpq-fe.h"
#include "command_line.h"
#include "tap.h"
#include "utils.h"
CommandLine cl;
using PGConnPtr = std::unique_ptr<PGconn, decltype(&PQfinish)>;
using PGResultPtr = std::unique_ptr<PGresult, decltype(&PQclear)>;
enum ConnType {
ADMIN,
BACKEND
};
PGConnPtr createNewConnection(ConnType conn_type, const std::string& options = "", bool with_ssl = false) {
const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host;
int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port;
const char* username = (conn_type == BACKEND) ? cl.pgsql_root_username : cl.admin_username;
const char* password = (conn_type == BACKEND) ? cl.pgsql_root_password : cl.admin_password;
std::stringstream ss;
ss << "host=" << host << " port=" << port;
ss << " user=" << username << " password=" << password;
ss << (with_ssl ? " sslmode=require" : " sslmode=disable");
if (options.empty() == false) {
ss << " options='" << options << "'";
}
PGconn* conn = PQconnectdb(ss.str().c_str());
if (PQstatus(conn) != CONNECTION_OK) {
fprintf(stderr, "Connection failed to '%s': %s", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn));
PQfinish(conn);
return PGConnPtr(nullptr, &PQfinish);
}
return PGConnPtr(conn, &PQfinish);
}
struct TestCase {
std::string name;
std::function<bool()> test_fn;
bool should_fail;
};
struct TestVariable {
std::string name;
std::vector<std::string> test_values;
};
std::vector<TestCase> tests;
PGResultPtr executeQuery(PGconn* conn, const std::string& query) {
diag("Executing: %s", query.c_str());
PGresult* res = PQexec(conn, query.c_str());
if (PQresultStatus(res) != PGRES_COMMAND_OK && PQresultStatus(res) != PGRES_TUPLES_OK) {
diag("Query failed: %s", PQerrorMessage(conn));
}
return PGResultPtr(res, &PQclear);
}
std::string getVariable(PGconn* conn, const std::string& var) {
auto res = executeQuery(conn, ("SHOW " + var));
const std::string& val = std::string(PQgetvalue(res.get(), 0, 0));
diag(">> '%s' = '%s'", var.c_str(), val.c_str());
return val;
}
void reset_variable(PGconn* conn, const std::string& var, const std::string& original) {
executeQuery(conn, "SET " + var + " = " + original);
}
void add_test(const std::string& name, std::function<bool()> fn, bool should_fail = false) {
tests.push_back({ name, fn, should_fail });
}
void run_tests() {
for (const auto& test : tests) {
bool result = false;
try {
result = test.test_fn();
if (test.should_fail) result = !result;
}
catch (const std::exception& e) {
result = false;
}
ok(result, "Test:%s should %s", test.name.c_str(), test.should_fail ? "FAIL" : "PASS");
}
}
// Common test scenarios
bool test_transaction_rollback(const TestVariable& var) {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
const auto original = getVariable(conn.get(), var.name);
executeQuery(conn.get(), "BEGIN");
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
executeQuery(conn.get(), "ROLLBACK");
const bool success = getVariable(conn.get(), var.name) == original;
return success;
}
bool test_savepoint_rollback(const TestVariable& var) {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
const auto original = getVariable(conn.get(), var.name);
executeQuery(conn.get(), "BEGIN");
executeQuery(conn.get(), "SAVEPOINT sp1");
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
executeQuery(conn.get(), "ROLLBACK TO sp1");
executeQuery(conn.get(), "COMMIT");
const bool success = getVariable(conn.get(), var.name) == original;
return success;
}
bool test_transaction_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
const auto test_value = var.test_values[0];
executeQuery(conn.get(), "BEGIN");
executeQuery(conn.get(), "SET " + var.name + " = " + test_value);
executeQuery(conn.get(), "COMMIT");
const bool success = getVariable(conn.get(), var.name) == test_value;
reset_variable(conn.get(), var.name, original_values.at(var.name));
return success;
}
bool test_savepoint_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
const auto test_value = var.test_values[0];
executeQuery(conn.get(), "BEGIN");
executeQuery(conn.get(), "SAVEPOINT sp1");
executeQuery(conn.get(), "SET " + var.name + " = " + test_value);
executeQuery(conn.get(), "RELEASE SAVEPOINT sp1");
executeQuery(conn.get(), "COMMIT");
const bool success = getVariable(conn.get(), var.name) == test_value;
reset_variable(conn.get(), var.name, original_values.at(var.name));
return success;
}
bool test_savepoint_release_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
const auto original = getVariable(conn.get(), var.name);
executeQuery(conn.get(), "BEGIN");
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
executeQuery(conn.get(), "SAVEPOINT sp1");
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[1]);
executeQuery(conn.get(), "RELEASE SAVEPOINT sp1");
executeQuery(conn.get(), "COMMIT");
const bool success = getVariable(conn.get(), var.name) == var.test_values[1];
reset_variable(conn.get(), var.name, original_values.at(var.name));
return success;
}
int main(int argc, char** argv) {
if (cl.getEnv())
return exit_status();
std::map<std::string, std::string> original_values;
std::map<std::string, TestVariable> tracked_vars = {
{"client_encoding", {"client_encoding", {"LATIN1", "UTF8"}}},
{"datestyle", {"datestyle", {"ISO, MDY", "SQL, DMY"}}},
{"intervalstyle", {"intervalstyle", {"postgres", "iso_8601"}}},
{"standard_conforming_strings", {"standard_conforming_strings", {"on", "off"}}},
{"timezone", {"timezone", {"UTC", "PST8PDT"}}},
{"bytea_output", {"bytea_output", {"hex", "escape"}}},
{"allow_in_place_tablespaces", {"allow_in_place_tablespaces", {"on", "off"}}},
{"enable_bitmapscan", {"enable_bitmapscan", {"on", "off"}}},
{"enable_hashjoin", {"enable_hashjoin", {"on", "off"}}},
{"enable_indexscan", {"enable_indexscan", {"on", "off"}}},
{"enable_nestloop", {"enable_nestloop", {"on", "off"}}},
{"enable_seqscan", {"enable_seqscan", {"on", "off"}}},
{"enable_sort", {"enable_sort", {"on", "off"}}},
{"escape_string_warning", {"escape_string_warning", {"on", "off"}}},
{"synchronous_commit", {"synchronous_commit", {"on", "off"}}},
{"extra_float_digits", {"extra_float_digits", {"0", "3"}}},
{"client_min_messages", {"client_min_messages", {"notice", "warning"}}}
};
PGConnPtr conn = createNewConnection(ConnType::BACKEND, "", false);
if (!conn || PQstatus(conn.get()) != CONNECTION_OK) {
BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__);
return exit_status();
}
// Store original values
for (const auto& [name, var] : tracked_vars) {
original_values[name] = getVariable(conn.get(), name);
}
// Add generic tests
add_test("Commit without transaction should fail", [&]() {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
PGresult* res = PQexec(conn.get(), "COMMIT");
const bool failed = PQresultStatus(res) != PGRES_COMMAND_OK;
PQclear(res);
return failed;
}, true);
// Add variable-specific tests using containers
for (const auto& [name, var] : tracked_vars) {
add_test("Rollback reverts " + var.name, [var]() {
return test_transaction_rollback(var);
});
add_test("Commit persists " + var.name, [&]() {
return test_transaction_commit(var, original_values);
});
add_test("Savepoint rollback for " + var.name, [var]() {
return test_savepoint_rollback(var);
});
add_test("Savepoint commit for " + var.name, [&]() {
return test_savepoint_commit(var, original_values);
});
// Multi-value savepoint test
if (var.test_values.size() > 1) {
add_test("Multi-value savepoint for " + var.name, [&]() {
return test_savepoint_release_commit(var, original_values);
});
}
}
// Add complex scenario tests
add_test("Nested BEGIN with rollback", [&]() {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
const auto original_tz = getVariable(conn.get(), "timezone");
executeQuery(conn.get(), "BEGIN");
executeQuery(conn.get(), "SET timezone = 'UTC'");
executeQuery(conn.get(), "BEGIN"); // Second BEGIN
executeQuery(conn.get(), "SET timezone = 'PST8PDT'");
executeQuery(conn.get(), "ROLLBACK");
const bool success = getVariable(conn.get(), "timezone") == original_tz;
return success;
});
add_test("Mixed variables in transaction", [&]() {
auto conn = createNewConnection(ConnType::BACKEND, "", false);
bool success = true;
executeQuery(conn.get(), "BEGIN");
for (const auto& [name, var] : tracked_vars) {
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
}
executeQuery(conn.get(), "ROLLBACK");
for (const auto& [name, var] : tracked_vars) {
success = (getVariable(conn.get(), var.name) == original_values.at(var.name));
}
return success;
});
int total_tests = 0;
total_tests = tests.size();
plan(total_tests);
run_tests();
return exit_status();
}
Loading…
Cancel
Save