diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 967515fa5..396235a18 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -419,7 +419,7 @@ private: * * @return void. */ - void switch_normal_to_fast_forward_mode(PtrSize_t& pkt, std::string_view command, SESSION_FORWARD_TYPE session_type); + bool switch_normal_to_fast_forward_mode(PtrSize_t& pkt, std::string_view command, SESSION_FORWARD_TYPE session_type); /** * @brief Switches session from fast forward mode to normal mode. diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 086148c38..facea8850 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -2259,6 +2259,23 @@ __implicit_sync: } } break; + // Handle PostgreSQL COPY protocol frontend messages that may arrive + // after an error during COPY FROM STDIN caused a switch back to normal mode. + // + // Race condition scenario: + // 1. COPY FROM STDIN starts -> session switches to fast_forward mode + // 2. Backend returns error during COPY -> session switches back to normal mode + // 3. Client has already pipelined CopyData('d')/CopyDone('c')/CopyFail('f') messages + // 4. These messages are now in the queue but session is no longer in fast_forward + // + // These messages are meant for fast_forward mode. Simply ignore them. + case 'd': + case 'c': + case 'f': + proxy_debug(PROXY_DEBUG_NET, 5, "Ignoring late COPY protocol message '%c' from client - COPY operation already terminated, session no longer in fast_forward mode\n", c); + l_free(pkt.size, pkt.ptr); + pkt = { 0, nullptr }; + break; default: proxy_error("Not implemented yet. Message type:'%c'\n", c); client_myds->setDSS_STATE_QUERY_SENT_NET(); @@ -2934,7 +2951,15 @@ handler_again: goto __exit_DSS__STATE_NOT_INITIALIZED; } - switch_normal_to_fast_forward_mode(pkt, std::string(matched.data(), matched.size()), SESSION_FORWARD_TYPE_COPY_FROM_STDIN_STDOUT); + if (!switch_normal_to_fast_forward_mode(pkt, std::string(matched.data(), matched.size()), SESSION_FORWARD_TYPE_COPY_FROM_STDIN_STDOUT)) { + // Failed to switch to fast forward mode due to pending packets + client_myds->setDSS_STATE_QUERY_SENT_NET(); + client_myds->myprot.generate_error_packet(true, true, "Unexpected packet sequence during COPY command", + PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, false, true); + RequestEnd(myds, true); + finishQuery(myds, myconn, false); + goto __exit_DSS__STATE_NOT_INITIALIZED; + } break; } } @@ -5632,9 +5657,17 @@ void PgSQL_Session::set_previous_status_mode3(bool allow_execute) { } } -void PgSQL_Session::switch_normal_to_fast_forward_mode(PtrSize_t& pkt, std::string_view command, SESSION_FORWARD_TYPE session_type) { +bool PgSQL_Session::switch_normal_to_fast_forward_mode(PtrSize_t& pkt, std::string_view command, SESSION_FORWARD_TYPE session_type) { - if (session_fast_forward || session_type == SESSION_FORWARD_TYPE_PERMANENT) return; + if (session_fast_forward || session_type == SESSION_FORWARD_TYPE_PERMANENT) return true; + + // Check if there are pending packets in client_myds->PSarrayIN + // This is an error condition, we cannot switch to fast forward mode + if (client_myds->PSarrayIN->len) { + proxy_error("Cannot switch to fast forward mode: unexpected pending packets in client_myds->PSarrayIN (len=%d). Command: %s\n", + client_myds->PSarrayIN->len, command.data()); + return false; + } // we use a switch to write the command in the info message std::string client_info; @@ -5646,11 +5679,6 @@ void PgSQL_Session::switch_normal_to_fast_forward_mode(PtrSize_t& pkt, std::stri command.data(), client_info.c_str(), session_type); session_fast_forward = session_type; - if (client_myds->PSarrayIN->len) { - proxy_error("UNEXPECTED PACKET FROM CLIENT -- PLEASE REPORT A BUG\n"); - assert(0); - } - mybe->server_myds->reinit_queues(); // reinitialize the queues in the myds . By default, they are not active // We reinitialize the 'wait_until' since this session shouldn't wait for processing as // we are now transitioning to 'FAST_FORWARD'. @@ -5690,6 +5718,8 @@ void PgSQL_Session::switch_normal_to_fast_forward_mode(PtrSize_t& pkt, std::stri // need to reset mysql_real_query mybe->server_myds->pgsql_real_query.reset(); CurrentQuery.end(); + + return true; } void PgSQL_Session::switch_fast_forward_to_normal_mode() { diff --git a/test/tap/tests/pgsql-copy_freeze_error_recovery-t.cpp b/test/tap/tests/pgsql-copy_freeze_error_recovery-t.cpp new file mode 100644 index 000000000..dc7765992 --- /dev/null +++ b/test/tap/tests/pgsql-copy_freeze_error_recovery-t.cpp @@ -0,0 +1,410 @@ +/** + * @file pgsql-copy_freeze_error_recovery-t.cpp + * @brief Tests COPY FROM ... FREEZE error recovery in ProxySQL + * + * This test reproduces the scenario where: + * 1. COPY command enters fast_forward mode + * 2. Backend returns ERROR + ReadyForQuery immediately (before client sends data) + * because FREEZE requires table to be created or truncated in the current subtransaction + * 3. Session should correctly return to normal mode + * 4. Subsequent queries should work normally + * + * This is a regression test for proper session state recovery after a failed COPY + * command that entered fast_forward mode. + */ + +#include +#include +#include +#include +#include "libpq-fe.h" +#include "command_line.h" +#include "tap.h" +#include "utils.h" + +CommandLine cl; + +using PGConnPtr = std::unique_ptr; + +/** + * @brief Creates a new PostgreSQL connection + * @param with_ssl Whether to use SSL for the connection + * @return A unique pointer to the PGconn structure + */ +PGConnPtr createNewConnection(bool with_ssl) { + std::stringstream ss; + ss << "host=" << cl.pgsql_host << " port=" << cl.pgsql_port; + ss << " user=" << cl.pgsql_username << " password=" << cl.pgsql_password; + ss << " dbname=postgres"; + ss << (with_ssl ? " sslmode=require" : " sslmode=disable"); + + PGconn* conn = PQconnectdb(ss.str().c_str()); + if (PQstatus(conn) != CONNECTION_OK) { + fprintf(stderr, "Connection failed: %s", PQerrorMessage(conn)); + PQfinish(conn); + return PGConnPtr(nullptr, &PQfinish); + } + return PGConnPtr(conn, &PQfinish); +} + +/** + * @brief Executes a single query and checks the result status + * @param conn The PostgreSQL connection + * @param query The query to execute + * @param expected_status The expected result status + * @return true if the query succeeded with expected status, false otherwise + */ +bool executeQuery(PGconn* conn, const char* query, ExecStatusType expected_status = PGRES_COMMAND_OK) { + PGresult* res = PQexec(conn, query); + bool success = PQresultStatus(res) == expected_status; + if (!success) { + diag("Query '%s' failed: %s", query, PQerrorMessage(conn)); + } + PQclear(res); + return success; +} + +/** + * @brief Setup test table + * @param conn The PostgreSQL connection + * @return true if setup succeeded, false otherwise + */ +bool setupTestTable(PGconn* conn) { + PGresult* res = PQexec(conn, "DROP TABLE IF EXISTS copy_freeze_test"); + PQclear(res); + + res = PQexec(conn, "CREATE TABLE copy_freeze_test (id int, name text)"); + bool success = PQresultStatus(res) == PGRES_COMMAND_OK; + if (!success) { + diag("Failed to create table: %s", PQerrorMessage(conn)); + } + PQclear(res); + return success; +} + +/** + * @brief Cleanup test table + * @param conn The PostgreSQL connection + */ +void cleanupTestTable(PGconn* conn) { + PGresult* res = PQexec(conn, "DROP TABLE IF EXISTS copy_freeze_test"); + PQclear(res); +} + +/** + * @brief Test 1: COPY FREEZE fails immediately and session recovers + * + * This test verifies that when a COPY ... FREEZE command fails because the table + * was not created or truncated in the current subtransaction, the session properly + * returns to normal mode and subsequent queries work correctly. + * + * @param conn The PostgreSQL connection + */ +void testCopyFreezeFailsImmediately(PGconn* conn) { + diag("Test: COPY FREEZE fails immediately (table not truncated in current transaction)"); + + // Execute COPY FREEZE - this should fail because table was not truncated + // in the current subtransaction + PGresult* res = PQexec(conn, "COPY copy_freeze_test FROM stdin CSV FREEZE"); + + // The COPY may return PGRES_COPY_IN (if server sends CopyIn before error) + // or PGRES_FATAL_ERROR (if server sends error immediately) + ExecStatusType status = PQresultStatus(res); + + if (status == PGRES_COPY_IN) { + diag("COPY entered COPY_IN mode, sending data..."); + + // Send data - but backend will reject it + if (PQputCopyData(conn, "1,test1\n", 8) != 1) { + diag("PQputCopyData failed: %s", PQerrorMessage(conn)); + } + if (PQputCopyEnd(conn, NULL) != 1) { + diag("PQputCopyEnd failed: %s", PQerrorMessage(conn)); + } + + // Get the final result + PQclear(res); + res = PQgetResult(conn); + status = PQresultStatus(res); + } + + // The COPY should fail + ok(status == PGRES_FATAL_ERROR, + "COPY FREEZE should fail when table not truncated in current transaction: %s", + PQresultErrorMessage(res)); + PQclear(res); + + // Consume any remaining results + while ((res = PQgetResult(conn)) != NULL) { + PQclear(res); + } + + diag("Testing subsequent queries after COPY error..."); + + // Test: BEGIN should work + res = PQexec(conn, "BEGIN"); + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "BEGIN should work after COPY error: %s", PQerrorMessage(conn)); + PQclear(res); + + // Test: TRUNCATE should work + res = PQexec(conn, "TRUNCATE copy_freeze_test"); + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "TRUNCATE should work: %s", PQerrorMessage(conn)); + PQclear(res); + + // Test: SAVEPOINT should work + res = PQexec(conn, "SAVEPOINT s1"); + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "SAVEPOINT should work: %s", PQerrorMessage(conn)); + PQclear(res); + + // Test: COMMIT should work + res = PQexec(conn, "COMMIT"); + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "COMMIT should work: %s", PQerrorMessage(conn)); + PQclear(res); +} + +/** + * @brief Test 2: COPY FREEZE succeeds when properly set up + * + * This test verifies that COPY ... FREEZE works correctly when the table + * is properly truncated within the same transaction before the COPY command. + * + * IMPORTANT: COPY FREEZE requires that the table was created or truncated + * in the CURRENT subtransaction. Using SAVEPOINT between TRUNCATE and COPY + * FREEZE will cause failure because TRUNCATE is then in the parent subtransaction. + * + * @param conn The PostgreSQL connection + */ +void testCopyFreezeSucceedsWithProperSetup(PGconn* conn) { + diag("Test: COPY FREEZE succeeds with proper transaction setup (no savepoint between TRUNCATE and COPY)"); + + // Begin transaction + PGresult* res = PQexec(conn, "BEGIN"); + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "BEGIN should succeed: %s", PQerrorMessage(conn)); + PQclear(res); + + // Truncate table in same transaction + // NOTE: No SAVEPOINT here - COPY FREEZE requires TRUNCATE in current subtransaction + res = PQexec(conn, "TRUNCATE copy_freeze_test"); + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "TRUNCATE should succeed: %s", PQerrorMessage(conn)); + PQclear(res); + + // Now COPY FREEZE should work (TRUNCATE is in same subtransaction) + res = PQexec(conn, "COPY copy_freeze_test FROM stdin CSV FREEZE"); + ok(PQresultStatus(res) == PGRES_COPY_IN, + "COPY FREEZE should enter COPY_IN mode: %s", PQerrorMessage(conn)); + + // Send data + ok(PQputCopyData(conn, "1,test1\n", 8) == 1, + "PQputCopyData should succeed"); + ok(PQputCopyData(conn, "2,test2\n", 8) == 1, + "PQputCopyData should succeed"); + ok(PQputCopyEnd(conn, NULL) == 1, + "PQputCopyEnd should succeed"); + + PQclear(res); + res = PQgetResult(conn); + + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "COPY FREEZE should succeed after proper setup: %s", + PQresultErrorMessage(res)); + PQclear(res); + + // Consume any remaining results + while ((res = PQgetResult(conn)) != NULL) { + PQclear(res); + } + + // Commit transaction + res = PQexec(conn, "COMMIT"); + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "COMMIT should succeed: %s", PQerrorMessage(conn)); + PQclear(res); +} + +/** + * @brief Test 3: Verify data was inserted correctly + * + * @param conn The PostgreSQL connection + */ +void testDataVerification(PGconn* conn) { + diag("Test: Verify data was inserted correctly"); + + PGresult* res = PQexec(conn, "SELECT * FROM copy_freeze_test ORDER BY id"); + ok(PQresultStatus(res) == PGRES_TUPLES_OK, + "SELECT should succeed: %s", PQerrorMessage(conn)); + + int rows = PQntuples(res); + ok(rows == 2, "Should have 2 rows, got %d", rows); + + bool row1_ok = (rows >= 1) && (strcmp(PQgetvalue(res, 0, 0), "1") == 0) && + (strcmp(PQgetvalue(res, 0, 1), "test1") == 0); + ok(row1_ok, "Row 1 should be (1, test1)"); + + bool row2_ok = (rows >= 2) && (strcmp(PQgetvalue(res, 1, 0), "2") == 0) && + (strcmp(PQgetvalue(res, 1, 1), "test2") == 0); + ok(row2_ok, "Row 2 should be (2, test2)"); + PQclear(res); +} + +/** + * @brief Test 4: Multiple COPY errors in sequence + * + * This test verifies that the session can recover from multiple consecutive + * COPY errors. + * + * @param conn The PostgreSQL connection + */ +void testMultipleCopyErrors(PGconn* conn) { + diag("Test: Multiple consecutive COPY errors"); + + // First COPY error + PGresult* res = PQexec(conn, "COPY copy_freeze_test FROM stdin CSV FREEZE"); + ExecStatusType status = PQresultStatus(res); + + if (status == PGRES_COPY_IN) { + PQputCopyEnd(conn, NULL); + PQclear(res); + res = PQgetResult(conn); + } + ok(PQresultStatus(res) == PGRES_FATAL_ERROR, + "First COPY FREEZE should fail: %s", PQresultErrorMessage(res)); + PQclear(res); + while ((res = PQgetResult(conn)) != NULL) PQclear(res); + + // Second COPY error + res = PQexec(conn, "COPY copy_freeze_test FROM stdin CSV FREEZE"); + status = PQresultStatus(res); + + if (status == PGRES_COPY_IN) { + PQputCopyEnd(conn, NULL); + PQclear(res); + res = PQgetResult(conn); + } + ok(PQresultStatus(res) == PGRES_FATAL_ERROR, + "Second COPY FREEZE should fail: %s", PQresultErrorMessage(res)); + PQclear(res); + while ((res = PQgetResult(conn)) != NULL) PQclear(res); + + // Verify subsequent normal query works + res = PQexec(conn, "SELECT 1"); + ok(PQresultStatus(res) == PGRES_TUPLES_OK, + "SELECT should work after multiple COPY errors: %s", PQerrorMessage(conn)); + PQclear(res); +} + +/** + * @brief Test 5: COPY error followed by successful COPY + * + * This test verifies that after a COPY error, a properly executed COPY + * command can succeed. + * + * @param conn The PostgreSQL connection + */ +void testCopyErrorThenSuccess(PGconn* conn) { + diag("Test: COPY error followed by successful COPY"); + + // Truncate table first + PGresult* res = PQexec(conn, "TRUNCATE copy_freeze_test"); + PQclear(res); + + // First COPY - will fail (no transaction/truncate in same transaction) + res = PQexec(conn, "COPY copy_freeze_test FROM stdin CSV FREEZE"); + ExecStatusType status = PQresultStatus(res); + + if (status == PGRES_COPY_IN) { + PQputCopyEnd(conn, NULL); + PQclear(res); + res = PQgetResult(conn); + } + ok(PQresultStatus(res) == PGRES_FATAL_ERROR, + "First COPY FREEZE should fail: %s", PQresultErrorMessage(res)); + PQclear(res); + while ((res = PQgetResult(conn)) != NULL) PQclear(res); + + // Now do it properly + res = PQexec(conn, "BEGIN"); + PQclear(res); + res = PQexec(conn, "TRUNCATE copy_freeze_test"); + PQclear(res); + + res = PQexec(conn, "COPY copy_freeze_test FROM stdin CSV"); + if (PQresultStatus(res) == PGRES_COPY_IN) { + PQputCopyData(conn, "3,test3\n", 7); + PQputCopyEnd(conn, NULL); + PQclear(res); + res = PQgetResult(conn); + } + ok(PQresultStatus(res) == PGRES_COMMAND_OK, + "Regular COPY should succeed after COPY FREEZE error: %s", + PQresultErrorMessage(res)); + PQclear(res); + while ((res = PQgetResult(conn)) != NULL) PQclear(res); + + res = PQexec(conn, "COMMIT"); + PQclear(res); + + // Verify data + res = PQexec(conn, "SELECT COUNT(*) FROM copy_freeze_test"); + ok(PQresultStatus(res) == PGRES_TUPLES_OK && + PQntuples(res) > 0 && + atoi(PQgetvalue(res, 0, 0)) == 1, + "Should have 1 row after successful COPY, got %s", + PQgetvalue(res, 0, 0)); + PQclear(res); +} + +/** + * @brief Run all tests + */ +void runTests(PGconn* conn) { + // Setup + if (!setupTestTable(conn)) { + BAIL_OUT("Failed to setup test table"); + return; + } + + // Run test functions + testCopyFreezeFailsImmediately(conn); + testCopyFreezeSucceedsWithProperSetup(conn); + testDataVerification(conn); + testMultipleCopyErrors(conn); + testCopyErrorThenSuccess(conn); + + // Cleanup + cleanupTestTable(conn); +} + +int main(int argc, char** argv) { + // Total tests: + // testCopyFreezeFailsImmediately: 5 tests (COPY fail, BEGIN, TRUNCATE, SAVEPOINT, COMMIT) + // testCopyFreezeSucceedsWithProperSetup: 8 tests (BEGIN, TRUNCATE, COPY_IN, 3x data, result, COMMIT) + // testDataVerification: 4 tests (SELECT, row count, 2x row data) + // testMultipleCopyErrors: 3 tests (2x error, SELECT) + // testCopyErrorThenSuccess: 3 tests (error, success, count) + // Total: 23 tests + plan(23); + + if (cl.getEnv()) { + return exit_status(); + } + + // Create connection + PGConnPtr conn = createNewConnection(false); + if (!conn) { + BAIL_OUT("Failed to connect to ProxySQL"); + return exit_status(); + } + + diag("Connected to ProxySQL via port %d", cl.pgsql_port); + + // Run tests + runTests(conn.get()); + + return exit_status(); +}