mirror of https://github.com/sysown/proxysql
parent
0adf56eff1
commit
f87fc30261
@ -0,0 +1,290 @@
|
||||
/**
|
||||
* @file pgsql-transaction_variable_state_tracking-t.cpp
|
||||
* @brief TAP test validating PostgreSQL session parameter behavior in transactions.
|
||||
* Tests rollback/commit/savepoint behavior for session variables to ensure state consistency.
|
||||
*/
|
||||
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include "libpq-fe.h"
|
||||
#include "command_line.h"
|
||||
#include "tap.h"
|
||||
#include "utils.h"
|
||||
|
||||
CommandLine cl;
|
||||
|
||||
using PGConnPtr = std::unique_ptr<PGconn, decltype(&PQfinish)>;
|
||||
using PGResultPtr = std::unique_ptr<PGresult, decltype(&PQclear)>;
|
||||
|
||||
enum ConnType {
|
||||
ADMIN,
|
||||
BACKEND
|
||||
};
|
||||
|
||||
PGConnPtr createNewConnection(ConnType conn_type, const std::string& options = "", bool with_ssl = false) {
|
||||
|
||||
const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host;
|
||||
int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port;
|
||||
const char* username = (conn_type == BACKEND) ? cl.pgsql_root_username : cl.admin_username;
|
||||
const char* password = (conn_type == BACKEND) ? cl.pgsql_root_password : cl.admin_password;
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "host=" << host << " port=" << port;
|
||||
ss << " user=" << username << " password=" << password;
|
||||
ss << (with_ssl ? " sslmode=require" : " sslmode=disable");
|
||||
|
||||
if (options.empty() == false) {
|
||||
ss << " options='" << options << "'";
|
||||
}
|
||||
|
||||
PGconn* conn = PQconnectdb(ss.str().c_str());
|
||||
if (PQstatus(conn) != CONNECTION_OK) {
|
||||
fprintf(stderr, "Connection failed to '%s': %s", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn));
|
||||
PQfinish(conn);
|
||||
return PGConnPtr(nullptr, &PQfinish);
|
||||
}
|
||||
return PGConnPtr(conn, &PQfinish);
|
||||
}
|
||||
|
||||
struct TestCase {
|
||||
std::string name;
|
||||
std::function<bool()> test_fn;
|
||||
bool should_fail;
|
||||
};
|
||||
|
||||
struct TestVariable {
|
||||
std::string name;
|
||||
std::vector<std::string> test_values;
|
||||
};
|
||||
|
||||
std::vector<TestCase> tests;
|
||||
|
||||
PGResultPtr executeQuery(PGconn* conn, const std::string& query) {
|
||||
diag("Executing: %s", query.c_str());
|
||||
PGresult* res = PQexec(conn, query.c_str());
|
||||
if (PQresultStatus(res) != PGRES_COMMAND_OK && PQresultStatus(res) != PGRES_TUPLES_OK) {
|
||||
diag("Query failed: %s", PQerrorMessage(conn));
|
||||
}
|
||||
return PGResultPtr(res, &PQclear);
|
||||
}
|
||||
|
||||
std::string getVariable(PGconn* conn, const std::string& var) {
|
||||
auto res = executeQuery(conn, ("SHOW " + var));
|
||||
return std::string(PQgetvalue(res.get(), 0, 0));
|
||||
}
|
||||
|
||||
void reset_variable(PGconn* conn, const std::string& var, const std::string& original) {
|
||||
executeQuery(conn, "SET " + var + " = " + original);
|
||||
}
|
||||
|
||||
void add_test(const std::string& name, std::function<bool()> fn, bool should_fail = false) {
|
||||
tests.push_back({ name, fn, should_fail });
|
||||
}
|
||||
|
||||
void run_tests() {
|
||||
|
||||
for (const auto& test : tests) {
|
||||
bool result = false;
|
||||
|
||||
try {
|
||||
result = test.test_fn();
|
||||
if (test.should_fail) result = !result;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
result = false;
|
||||
}
|
||||
|
||||
ok(result, "Test:%s should %s", test.name.c_str(), test.should_fail ? "FAIL" : "PASS");
|
||||
}
|
||||
}
|
||||
|
||||
// Common test scenarios
|
||||
bool test_transaction_rollback(const TestVariable& var) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original = getVariable(conn.get(), var.name);
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
executeQuery(conn.get(), "ROLLBACK");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == original;
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_savepoint_rollback(const TestVariable& var) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original = getVariable(conn.get(), var.name);
|
||||
diag(">>>>> Original value:'%s'", original.c_str());
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
executeQuery(conn.get(), "ROLLBACK TO sp1");
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
auto value = getVariable(conn.get(), var.name);
|
||||
const bool success = value == original;
|
||||
diag(">>>>> Rollback value:'%s'", value.c_str());
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_transaction_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto test_value = var.test_values[0];
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + test_value);
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == test_value;
|
||||
reset_variable(conn.get(), var.name, original_values.at(var.name));
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_savepoint_commit(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto test_value = var.test_values[0];
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + test_value);
|
||||
executeQuery(conn.get(), "RELEASE SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == test_value;
|
||||
reset_variable(conn.get(), var.name, original_values.at(var.name));
|
||||
return success;
|
||||
}
|
||||
|
||||
bool test_savepoint_rollback_partial(const TestVariable& var, const std::map<std::string, std::string>& original_values) {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original = getVariable(conn.get(), var.name);
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
executeQuery(conn.get(), "SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[1]);
|
||||
executeQuery(conn.get(), "RELEASE SAVEPOINT sp1");
|
||||
executeQuery(conn.get(), "COMMIT");
|
||||
|
||||
const bool success = getVariable(conn.get(), var.name) == var.test_values[1];
|
||||
reset_variable(conn.get(), var.name, original_values.at(var.name));
|
||||
return success;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
if (cl.getEnv())
|
||||
return exit_status();
|
||||
|
||||
std::map<std::string, std::string> original_values;
|
||||
std::map<std::string, TestVariable> tracked_vars = {
|
||||
{"client_encoding", {"client_encoding", {"LATIN1", "UTF8"}}},
|
||||
{"datestyle", {"datestyle", {"ISO, MDY", "SQL, DMY"}}},
|
||||
{"intervalstyle", {"intervalstyle", {"postgres", "iso_8601"}}},
|
||||
{"standard_conforming_strings", {"standard_conforming_strings", {"on", "off"}}},
|
||||
{"timezone", {"timezone", {"UTC", "PST8PDT"}}},
|
||||
{"bytea_output", {"bytea_output", {"hex", "escape"}}},
|
||||
{"allow_in_place_tablespaces", {"allow_in_place_tablespaces", {"on", "off"}}},
|
||||
{"enable_bitmapscan", {"enable_bitmapscan", {"on", "off"}}},
|
||||
{"enable_hashjoin", {"enable_hashjoin", {"on", "off"}}},
|
||||
{"enable_indexscan", {"enable_indexscan", {"on", "off"}}},
|
||||
{"enable_nestloop", {"enable_nestloop", {"on", "off"}}},
|
||||
{"enable_seqscan", {"enable_seqscan", {"on", "off"}}},
|
||||
{"enable_sort", {"enable_sort", {"on", "off"}}},
|
||||
{"escape_string_warning", {"escape_string_warning", {"on", "off"}}},
|
||||
{"synchronous_commit", {"synchronous_commit", {"on", "off"}}},
|
||||
{"extra_float_digits", {"extra_float_digits", {"0", "3"}}},
|
||||
{"client_min_messages", {"client_min_messages", {"notice", "warning"}}}
|
||||
};
|
||||
|
||||
|
||||
PGConnPtr conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
if (!conn || PQstatus(conn.get()) != CONNECTION_OK) {
|
||||
BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__);
|
||||
return exit_status();
|
||||
}
|
||||
|
||||
// Store original values
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
original_values[name] = getVariable(conn.get(), name);
|
||||
}
|
||||
|
||||
|
||||
// Add generic tests
|
||||
add_test("Commit without transaction should fail", [&]() {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
PGresult* res = PQexec(conn.get(), "COMMIT");
|
||||
const bool failed = PQresultStatus(res) != PGRES_COMMAND_OK;
|
||||
PQclear(res);
|
||||
return failed;
|
||||
}, true);
|
||||
|
||||
// Add variable-specific tests using containers
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
add_test("Rollback reverts " + var.name, [var]() {
|
||||
return test_transaction_rollback(var);
|
||||
});
|
||||
|
||||
add_test("Commit persists " + var.name, [&]() {
|
||||
return test_transaction_commit(var, original_values);
|
||||
});
|
||||
|
||||
add_test("Savepoint rollback for " + var.name, [var]() {
|
||||
return test_savepoint_rollback(var);
|
||||
});
|
||||
|
||||
add_test("Savepoint commit for " + var.name, [&]() {
|
||||
return test_savepoint_commit(var, original_values);
|
||||
});
|
||||
|
||||
// Multi-value savepoint test
|
||||
if (var.test_values.size() > 1) {
|
||||
add_test("Multi-value savepoint for " + var.name, [&]() {
|
||||
return test_savepoint_rollback_partial(var, original_values);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add complex scenario tests
|
||||
add_test("Nested BEGIN with rollback", [&]() {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
const auto original_tz = getVariable(conn.get(), "timezone");
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
executeQuery(conn.get(), "SET timezone = 'UTC'");
|
||||
executeQuery(conn.get(), "BEGIN"); // Second BEGIN
|
||||
executeQuery(conn.get(), "SET timezone = 'PST8PDT'");
|
||||
executeQuery(conn.get(), "ROLLBACK");
|
||||
|
||||
const bool success = getVariable(conn.get(), "timezone") == original_tz;
|
||||
return success;
|
||||
});
|
||||
|
||||
add_test("Mixed variables in transaction", [&]() {
|
||||
auto conn = createNewConnection(ConnType::BACKEND, "", false);
|
||||
bool success = true;
|
||||
|
||||
executeQuery(conn.get(), "BEGIN");
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]);
|
||||
}
|
||||
executeQuery(conn.get(), "ROLLBACK");
|
||||
|
||||
for (const auto& [name, var] : tracked_vars) {
|
||||
success = (getVariable(conn.get(), var.name) == original_values.at(var.name));
|
||||
}
|
||||
return success;
|
||||
});
|
||||
|
||||
int total_tests = 0;
|
||||
|
||||
total_tests = tests.size();
|
||||
plan(total_tests);
|
||||
|
||||
run_tests();
|
||||
|
||||
return exit_status();
|
||||
}
|
||||
Loading…
Reference in new issue