diff --git a/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp new file mode 100644 index 000000000..730a5218f --- /dev/null +++ b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp @@ -0,0 +1,290 @@ +/** + * @file pgsql-transaction_variable_state_tracking-t.cpp + * @brief TAP test validating PostgreSQL session parameter behavior in transactions. + * Tests rollback/commit/savepoint behavior for session variables to ensure state consistency. + */ + +#include +#include +#include +#include +#include +#include "libpq-fe.h" +#include "command_line.h" +#include "tap.h" +#include "utils.h" + +CommandLine cl; + +using PGConnPtr = std::unique_ptr; +using PGResultPtr = std::unique_ptr; + +enum ConnType { + ADMIN, + BACKEND +}; + +PGConnPtr createNewConnection(ConnType conn_type, const std::string& options = "", bool with_ssl = false) { + + const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host; + int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port; + const char* username = (conn_type == BACKEND) ? cl.pgsql_root_username : cl.admin_username; + const char* password = (conn_type == BACKEND) ? cl.pgsql_root_password : cl.admin_password; + + std::stringstream ss; + + ss << "host=" << host << " port=" << port; + ss << " user=" << username << " password=" << password; + ss << (with_ssl ? " sslmode=require" : " sslmode=disable"); + + if (options.empty() == false) { + ss << " options='" << options << "'"; + } + + PGconn* conn = PQconnectdb(ss.str().c_str()); + if (PQstatus(conn) != CONNECTION_OK) { + fprintf(stderr, "Connection failed to '%s': %s", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn)); + PQfinish(conn); + return PGConnPtr(nullptr, &PQfinish); + } + return PGConnPtr(conn, &PQfinish); +} + +struct TestCase { + std::string name; + std::function test_fn; + bool should_fail; +}; + +struct TestVariable { + std::string name; + std::vector test_values; +}; + +std::vector tests; + +PGResultPtr executeQuery(PGconn* conn, const std::string& query) { + diag("Executing: %s", query.c_str()); + PGresult* res = PQexec(conn, query.c_str()); + if (PQresultStatus(res) != PGRES_COMMAND_OK && PQresultStatus(res) != PGRES_TUPLES_OK) { + diag("Query failed: %s", PQerrorMessage(conn)); + } + return PGResultPtr(res, &PQclear); +} + +std::string getVariable(PGconn* conn, const std::string& var) { + auto res = executeQuery(conn, ("SHOW " + var)); + return std::string(PQgetvalue(res.get(), 0, 0)); +} + +void reset_variable(PGconn* conn, const std::string& var, const std::string& original) { + executeQuery(conn, "SET " + var + " = " + original); +} + +void add_test(const std::string& name, std::function fn, bool should_fail = false) { + tests.push_back({ name, fn, should_fail }); +} + +void run_tests() { + + for (const auto& test : tests) { + bool result = false; + + try { + result = test.test_fn(); + if (test.should_fail) result = !result; + } + catch (const std::exception& e) { + result = false; + } + + ok(result, "Test:%s should %s", test.name.c_str(), test.should_fail ? "FAIL" : "PASS"); + } +} + +// Common test scenarios +bool test_transaction_rollback(const TestVariable& var) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original = getVariable(conn.get(), var.name); + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + executeQuery(conn.get(), "ROLLBACK"); + + const bool success = getVariable(conn.get(), var.name) == original; + return success; +} + +bool test_savepoint_rollback(const TestVariable& var) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original = getVariable(conn.get(), var.name); + diag(">>>>> Original value:'%s'", original.c_str()); + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SAVEPOINT sp1"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + executeQuery(conn.get(), "ROLLBACK TO sp1"); + executeQuery(conn.get(), "COMMIT"); + auto value = getVariable(conn.get(), var.name); + const bool success = value == original; + diag(">>>>> Rollback value:'%s'", value.c_str()); + return success; +} + +bool test_transaction_commit(const TestVariable& var, const std::map& original_values) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto test_value = var.test_values[0]; + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET " + var.name + " = " + test_value); + executeQuery(conn.get(), "COMMIT"); + + const bool success = getVariable(conn.get(), var.name) == test_value; + reset_variable(conn.get(), var.name, original_values.at(var.name)); + return success; +} + +bool test_savepoint_commit(const TestVariable& var, const std::map& original_values) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto test_value = var.test_values[0]; + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SAVEPOINT sp1"); + executeQuery(conn.get(), "SET " + var.name + " = " + test_value); + executeQuery(conn.get(), "RELEASE SAVEPOINT sp1"); + executeQuery(conn.get(), "COMMIT"); + + const bool success = getVariable(conn.get(), var.name) == test_value; + reset_variable(conn.get(), var.name, original_values.at(var.name)); + return success; +} + +bool test_savepoint_rollback_partial(const TestVariable& var, const std::map& original_values) { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original = getVariable(conn.get(), var.name); + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + executeQuery(conn.get(), "SAVEPOINT sp1"); + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[1]); + executeQuery(conn.get(), "RELEASE SAVEPOINT sp1"); + executeQuery(conn.get(), "COMMIT"); + + const bool success = getVariable(conn.get(), var.name) == var.test_values[1]; + reset_variable(conn.get(), var.name, original_values.at(var.name)); + return success; +} + +int main(int argc, char** argv) { + + if (cl.getEnv()) + return exit_status(); + + std::map original_values; + std::map tracked_vars = { + {"client_encoding", {"client_encoding", {"LATIN1", "UTF8"}}}, + {"datestyle", {"datestyle", {"ISO, MDY", "SQL, DMY"}}}, + {"intervalstyle", {"intervalstyle", {"postgres", "iso_8601"}}}, + {"standard_conforming_strings", {"standard_conforming_strings", {"on", "off"}}}, + {"timezone", {"timezone", {"UTC", "PST8PDT"}}}, + {"bytea_output", {"bytea_output", {"hex", "escape"}}}, + {"allow_in_place_tablespaces", {"allow_in_place_tablespaces", {"on", "off"}}}, + {"enable_bitmapscan", {"enable_bitmapscan", {"on", "off"}}}, + {"enable_hashjoin", {"enable_hashjoin", {"on", "off"}}}, + {"enable_indexscan", {"enable_indexscan", {"on", "off"}}}, + {"enable_nestloop", {"enable_nestloop", {"on", "off"}}}, + {"enable_seqscan", {"enable_seqscan", {"on", "off"}}}, + {"enable_sort", {"enable_sort", {"on", "off"}}}, + {"escape_string_warning", {"escape_string_warning", {"on", "off"}}}, + {"synchronous_commit", {"synchronous_commit", {"on", "off"}}}, + {"extra_float_digits", {"extra_float_digits", {"0", "3"}}}, + {"client_min_messages", {"client_min_messages", {"notice", "warning"}}} + }; + + + PGConnPtr conn = createNewConnection(ConnType::BACKEND, "", false); + if (!conn || PQstatus(conn.get()) != CONNECTION_OK) { + BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + + // Store original values + for (const auto& [name, var] : tracked_vars) { + original_values[name] = getVariable(conn.get(), name); + } + + + // Add generic tests + add_test("Commit without transaction should fail", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + PGresult* res = PQexec(conn.get(), "COMMIT"); + const bool failed = PQresultStatus(res) != PGRES_COMMAND_OK; + PQclear(res); + return failed; + }, true); + + // Add variable-specific tests using containers + for (const auto& [name, var] : tracked_vars) { + add_test("Rollback reverts " + var.name, [var]() { + return test_transaction_rollback(var); + }); + + add_test("Commit persists " + var.name, [&]() { + return test_transaction_commit(var, original_values); + }); + + add_test("Savepoint rollback for " + var.name, [var]() { + return test_savepoint_rollback(var); + }); + + add_test("Savepoint commit for " + var.name, [&]() { + return test_savepoint_commit(var, original_values); + }); + + // Multi-value savepoint test + if (var.test_values.size() > 1) { + add_test("Multi-value savepoint for " + var.name, [&]() { + return test_savepoint_rollback_partial(var, original_values); + }); + } + } + + // Add complex scenario tests + add_test("Nested BEGIN with rollback", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + const auto original_tz = getVariable(conn.get(), "timezone"); + + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET timezone = 'UTC'"); + executeQuery(conn.get(), "BEGIN"); // Second BEGIN + executeQuery(conn.get(), "SET timezone = 'PST8PDT'"); + executeQuery(conn.get(), "ROLLBACK"); + + const bool success = getVariable(conn.get(), "timezone") == original_tz; + return success; + }); + + add_test("Mixed variables in transaction", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + bool success = true; + + executeQuery(conn.get(), "BEGIN"); + for (const auto& [name, var] : tracked_vars) { + executeQuery(conn.get(), "SET " + var.name + " = " + var.test_values[0]); + } + executeQuery(conn.get(), "ROLLBACK"); + + for (const auto& [name, var] : tracked_vars) { + success = (getVariable(conn.get(), var.name) == original_values.at(var.name)); + } + return success; + }); + + int total_tests = 0; + + total_tests = tests.size(); + plan(total_tests); + + run_tests(); + + return exit_status(); +}