mirror of https://github.com/sysown/proxysql
Merge pull request #4929 from sysown/v3.0_track_transaction_param_state_4907
Track Variable State Across Transactions and Savepoints - v3.0pull/4941/head
commit
9f4b139e6c
@ -0,0 +1,112 @@
|
||||
#ifndef PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H
|
||||
#define PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "proxysql.h"
|
||||
#include "cpp.h"
|
||||
#include "PgSQL_Connection.h"
|
||||
|
||||
#ifndef PROXYJSON
|
||||
#define PROXYJSON
|
||||
#include "../deps/json/json_fwd.hpp"
|
||||
#endif // PROXYJSON
|
||||
|
||||
/**
|
||||
* @struct PgSQL_Variable_Snapshot
|
||||
* @brief Represents a snapshot of PostgreSQL variables during a transaction.
|
||||
*
|
||||
* This structure is used to store the state of PostgreSQL variables, including
|
||||
* their values and hash representations, at a specific point in time during a transaction.
|
||||
*/
|
||||
struct PgSQL_Variable_Snapshot {
|
||||
char* var_value[PGSQL_NAME_LAST_HIGH_WM] = {}; // Not using smart pointers because we need fine-grained control over hashing when values change
|
||||
uint32_t var_hash[PGSQL_NAME_LAST_HIGH_WM] = {};
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct TxnCmd
|
||||
* @brief Represents a transaction command type begin executed and its associated metadata.
|
||||
*
|
||||
*/
|
||||
struct TxnCmd {
|
||||
/**
|
||||
* @enum Type
|
||||
* @brief Enumerates the types of transaction commands.
|
||||
*/
|
||||
enum Type {
|
||||
UNKNOWN = -1,
|
||||
BEGIN,
|
||||
COMMIT,
|
||||
ROLLBACK,
|
||||
SAVEPOINT,
|
||||
RELEASE,
|
||||
ROLLBACK_TO
|
||||
} type = Type::UNKNOWN;
|
||||
std::string savepoint; //< The name of the savepoint, if applicable.
|
||||
};
|
||||
|
||||
/**
|
||||
* @class PgSQL_TxnCmdParser
|
||||
* @brief Parses transaction-related commands for PostgreSQL.
|
||||
*
|
||||
* This class is responsible for tokenizing and interpreting transaction-related
|
||||
* commands such as BEGIN, COMMIT, ROLLBACK, SAVEPOINT, etc.
|
||||
*/
|
||||
class PgSQL_TxnCmdParser {
|
||||
public:
|
||||
TxnCmd parse(std::string_view input, bool in_transaction_mode) noexcept;
|
||||
|
||||
private:
|
||||
std::vector<std::string_view> tokens;
|
||||
|
||||
TxnCmd parse_rollback(size_t& pos) noexcept;
|
||||
TxnCmd parse_savepoint(size_t& pos) noexcept;
|
||||
TxnCmd parse_release(size_t& pos) noexcept;
|
||||
|
||||
// Helpers
|
||||
static std::string to_lower(std::string_view s) noexcept {
|
||||
std::string s_copy(s);
|
||||
std::transform(s_copy.begin(), s_copy.end(), s_copy.begin(), ::tolower);
|
||||
return s_copy;
|
||||
}
|
||||
|
||||
inline static bool contains(std::vector<std::string_view>&& list, std::string_view value) noexcept {
|
||||
for (const auto& item : list) if (item == value) return true;
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class PgSQL_ExplicitTxnStateMgr
|
||||
* @brief Manages the state of explicit transactions in PostgreSQL.
|
||||
*
|
||||
* This class is responsible for handling explicit transaction commands such as
|
||||
* BEGIN, COMMIT, ROLLBACK, SAVEPOINT, and managing the associated state.
|
||||
*/
|
||||
class PgSQL_ExplicitTxnStateMgr {
|
||||
public:
|
||||
PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess);
|
||||
~PgSQL_ExplicitTxnStateMgr();
|
||||
|
||||
bool handle_transaction(std::string_view input);
|
||||
int get_savepoint_count() const { return savepoint.size(); }
|
||||
void fill_internal_session(nlohmann::json& j);
|
||||
|
||||
private:
|
||||
PgSQL_Session* session;
|
||||
std::vector<PgSQL_Variable_Snapshot> transaction_state;
|
||||
std::vector<std::string> savepoint;
|
||||
PgSQL_TxnCmdParser tx_parser;
|
||||
|
||||
void start_transaction();
|
||||
void commit();
|
||||
void rollback();
|
||||
bool add_savepoint(std::string_view name);
|
||||
bool rollback_to_savepoint(std::string_view name);
|
||||
bool release_savepoint(std::string_view name);
|
||||
|
||||
static void reset_variable_snapshot(PgSQL_Variable_Snapshot& var_snapshot) noexcept;
|
||||
};
|
||||
|
||||
#endif // PGSQL_EXPLICIT_TRANSACTION_STATE_MANAGER_H
|
||||
@ -0,0 +1,288 @@
|
||||
/**
|
||||
* @file pgsql-transaction_variable_state_tracking-t.cpp
|
||||
* @brief TAP test validating PostgreSQL session parameter behavior in transactions.
|
||||
* Tests rollback/commit/savepoint behavior for session variables to ensure state consistency.
|
||||
*/
|
||||
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include "libpq-fe.h"
|
||||
#include "command_line.h"
|
||||
#include "tap.h"
|
||||
#include "utils.h"
|
||||
|
||||
CommandLine cl;
|
||||
|
||||
using PGConnPtr = std::unique_ptr<PGconn, decltype(&PQfinish)>;
|
||||
using PGResultPtr = std::unique_ptr<PGresult, decltype(&PQclear)>;
|
||||
|
||||
enum ConnType {
|
||||
ADMIN,
|
||||
BACKEND
|
||||
};
|
||||
|
||||
PGConnPtr createNewConnection(ConnType conn_type, const std::string& options = "", bool with_ssl = false) {
|
||||
|
||||
const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host;
|
||||
int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port;
|
||||
const char* username = (conn_type == BACKEND) ? cl.pgsql_root_username : cl.admin_username;
|
||||
const char* password = (conn_type == BACKEND) ? cl.pgsql_root_password : cl.admin_password;
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "host=" << host << " port=" << port;
|
||||
ss << " user=" << username << " password=" << password;
|
||||
ss << (with_ssl ? " sslmode=require" : " sslmode=disable");
|
||||
|
||||
if (options.empty() == false) {
|
||||
ss << " options='" << options << "'";
|
||||
}
|
||||
|
||||
PGconn* conn = PQconnectdb(ss.str().c_str());
|
||||
if (PQstatus(conn) != CONNECTION_OK) {
|
||||
fprintf(stderr, "Connection failed to '%s': %s", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn));
|
||||
PQfinish(conn);
|
||||
return PGConnPtr(nullptr, &PQfinish);
|
||||
}
|
||||
return PGConnPtr(conn, &PQfinish);
|
||||
}
|
||||
|
||||
struct TestCase {
|
||||
std::string name;
|
||||
std::function<bool()> test_fn;
|
||||
bool should_fail;
|
||||
};
|
||||
|
||||
struct TestVariable {
|
||||
std::string name;
|
||||
std::vector<std::string> test_values;
|
||||
};
|
||||
|
||||
std::vector<TestCase> tests;
|
||||
|
||||
PGResultPtr executeQuery(PGconn* conn, const std::string& query) {
|
||||
diag("Executing: %s", query.c_str());
|
||||
PGresult* res = PQexec(conn, query.c_str());
|
||||
if (PQresultStatus(res) != PGRES_COMMAND_OK && PQresultStatus(res) != PGRES_TUPLES_OK) {
|
||||
diag("Query failed: %s", PQerrorMessage(conn));
|
||||
}
|
||||
return PGResultPtr(res, &PQclear);
|
||||
}
|
||||
|
||||
std::string getVariable(PGconn* conn, const std::string& var) {
|
||||
auto res = executeQuery(conn, ("SHOW " + var));
|
||||
const std::string& val = std::string(PQgetvalue(res.get(), 0, 0));
|
||||
diag(">> '%s' = '%s'", var.c_str(), val.c_str());
|
||||
return val;
|
||||
}
|
||||
|
||||
void reset_variable(PGconn* conn, const std::string& var, const std::string& original) {
|
||||
executeQuery(conn, "SET " + var + " = " + original);
|
||||
}
|
||||
|
||||
void add_test(const std::string& name, std::function<bool()> fn, bool should_fail = false) {
|
||||
tests.push_back({ name, fn, should_fail });
|
||||
}
|
||||
|
||||
void run_tests() {
|
||||
|
||||
for (const auto& test : tests) {
|
||||
bool result = false;
|
||||
|
||||
try {
|
||||
result = test.test_fn();
|
||||
if (test.should_fail) result = !result;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
result = false;
|
||||
}
|
||||
|
||||
ok(result, "Test:%s should %s", test.name.c_str(), test.should_fail ? "FAIL" : "PASS");
|
||||
}
|
||||
}
|
||||
|
||||
// Common test scenarios
|
||||
bool test_transaction_rollback(const TestVariable& var) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original = getVariable(conn.get(), var.name);
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
executeQuery(conn.get(), "ROLLBACK");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == original;
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_savepoint_rollback(const TestVariable& var) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original = getVariable(conn.get(), var.name);
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
executeQuery(conn.get(), "ROLLBACK TO sp1");
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == original;
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_transaction_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto test_value = var.test_values[0];
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + test_value);
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == test_value;
|
||||
reset_variable(conn.get(), var.name, original_values.at(var.name));
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_savepoint_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto test_value = var.test_values[0];
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + test_value);
|
||||
executeQuery(conn.get(), "RELEASE SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == test_value;
|
||||
reset_variable(conn.get(), var.name, original_values.at(var.name));
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_savepoint_release_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original = getVariable(conn.get(), var.name);
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
executeQuery(conn.get(), "SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[1]);
|
||||
executeQuery(conn.get(), "RELEASE SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
const bool success = getVariable(conn.get(), var.name) == var.test_values[1];
|
||||
reset_variable(conn.get(), var.name, original_values.at(var.name));
|
||||
return success;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
if (cl.getEnv())
|
||||
return exit_status();
|
||||
|
||||
std::map<std::string, std::string> original_values;
|
||||
std::map<std::string, TestVariable> tracked_vars = {
|
||||
{"client_encoding", {"client_encoding", {"LATIN1", "UTF8"}}},
|
||||
{"datestyle", {"datestyle", {"ISO, MDY", "SQL, DMY"}}},
|
||||
{"intervalstyle", {"intervalstyle", {"postgres", "iso_8601"}}},
|
||||
{"standard_conforming_strings", {"standard_conforming_strings", {"on", "off"}}},
|
||||
{"timezone", {"timezone", {"UTC", "PST8PDT"}}},
|
||||
{"bytea_output", {"bytea_output", {"hex", "escape"}}},
|
||||
{"allow_in_place_tablespaces", {"allow_in_place_tablespaces", {"on", "off"}}},
|
||||
{"enable_bitmapscan", {"enable_bitmapscan", {"on", "off"}}},
|
||||
{"enable_hashjoin", {"enable_hashjoin", {"on", "off"}}},
|
||||
{"enable_indexscan", {"enable_indexscan", {"on", "off"}}},
|
||||
{"enable_nestloop", {"enable_nestloop", {"on", "off"}}},
|
||||
{"enable_seqscan", {"enable_seqscan", {"on", "off"}}},
|
||||
{"enable_sort", {"enable_sort", {"on", "off"}}},
|
||||
{"escape_string_warning", {"escape_string_warning", {"on", "off"}}},
|
||||
{"synchronous_commit", {"synchronous_commit", {"on", "off"}}},
|
||||
{"extra_float_digits", {"extra_float_digits", {"0", "3"}}},
|
||||
{"client_min_messages", {"client_min_messages", {"notice", "warning"}}}
|
||||
};
|
||||
|
||||
|
||||
PGConnPtr conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
if (!conn || PQstatus(conn.get()) != CONNECTION_OK) {
|
||||
BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__);
|
||||
return exit_status();
|
||||
}
|
||||
|
||||
// Store original values
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
original_values[name] = getVariable(conn.get(), name);
|
||||
}
|
||||
|
||||
|
||||
// Add generic tests
|
||||
add_test("Commit without transaction should fail", [&]() {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
PGresult* res = PQexec(conn.get(), "COMMIT");
|
||||
const bool failed = PQresultStatus(res) != PGRES_COMMAND_OK;
|
||||
PQclear(res);
|
||||
return failed;
|
||||
}, true);
|
||||
|
||||
// Add variable-specific tests using containers
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
add_test("Rollback reverts " + var.name, [var]() {
|
||||
return test_transaction_rollback(var);
|
||||
});
|
||||
|
||||
add_test("Commit persists " + var.name, [&]() {
|
||||
return test_transaction_commit(var, original_values);
|
||||
});
|
||||
|
||||
add_test("Savepoint rollback for " + var.name, [var]() {
|
||||
return test_savepoint_rollback(var);
|
||||
});
|
||||
|
||||
add_test("Savepoint commit for " + var.name, [&]() {
|
||||
return test_savepoint_commit(var, original_values);
|
||||
});
|
||||
|
||||
// Multi-value savepoint test
|
||||
if (var.test_values.size() > 1) {
|
||||
add_test("Multi-value savepoint for " + var.name, [&]() {
|
||||
return test_savepoint_release_commit(var, original_values);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add complex scenario tests
|
||||
add_test("Nested BEGIN with rollback", [&]() {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original_tz = getVariable(conn.get(), "timezone");
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET timezone = 'UTC'");
|
||||
executeQuery(conn.get(), "BEGIN"); // Second BEGIN
|
||||
executeQuery(conn.get(), "SET timezone = 'PST8PDT'");
|
||||
executeQuery(conn.get(), "ROLLBACK");
|
||||
|
||||
const bool success = getVariable(conn.get(), "timezone") == original_tz;
|
||||
return success;
|
||||
});
|
||||
|
||||
add_test("Mixed variables in transaction", [&]() {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
bool success = true;
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
}
|
||||
executeQuery(conn.get(), "ROLLBACK");
|
||||
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
success = (getVariable(conn.get(), var.name) == original_values.at(var.name));
|
||||
}
|
||||
return success;
|
||||
});
|
||||
|
||||
int total_tests = 0;
|
||||
|
||||
total_tests = tests.size();
|
||||
plan(total_tests);
|
||||
|
||||
run_tests();
|
||||
|
||||
return exit_status();
|
||||
}
|
||||
Loading…
Reference in new issue