From abb66cc19fe1e67db97d1add09db17fb2fa27ee9 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Sun, 8 Mar 2026 02:54:40 +0500 Subject: [PATCH] Addressed AI review feeback --- lib/PgSQL_Data_Stream.cpp | 11 +- lib/PgSQL_ExplicitTxnStateMgr.cpp | 13 +- lib/PgSQL_Session.cpp | 11 +- test/tap/tests/pgsql-set_statement_test-t.cpp | 33 ++- ...-transaction_variable_state_tracking-t.cpp | 232 ++++++++++++++++++ 5 files changed, 281 insertions(+), 19 deletions(-) diff --git a/lib/PgSQL_Data_Stream.cpp b/lib/PgSQL_Data_Stream.cpp index ad06a1cd1..7392de226 100644 --- a/lib/PgSQL_Data_Stream.cpp +++ b/lib/PgSQL_Data_Stream.cpp @@ -1311,10 +1311,15 @@ int PgSQL_Data_Stream::buffer2array() { d = header[read_pos++]; pkgsize += (a << 24) | (b << 16) | (c << 8) | d; + if (pkgsize < sizeof(header)) { + proxy_error("Malformed packet (size=%u) received from received from client %s:%d\n", pkgsize, addr.addr ? addr.addr : "", addr.port); + shut_soft(); + return 0; + } + // PostgreSQL packets should always be >= 5 bytes. - const size_t alloc_size = (pkgsize < sizeof(header)) ? sizeof(header) : pkgsize; - queueIN.pkt.size = alloc_size; - queueIN.pkt.ptr = l_alloc(alloc_size); + queueIN.pkt.size = pkgsize; + queueIN.pkt.ptr = l_alloc(pkgsize); memcpy(queueIN.pkt.ptr, header, sizeof(header)); // immediately copy the header into the packet queueIN.partial = sizeof(header); diff --git a/lib/PgSQL_ExplicitTxnStateMgr.cpp b/lib/PgSQL_ExplicitTxnStateMgr.cpp index 34a0bb1e9..ec4a167de 100644 --- a/lib/PgSQL_ExplicitTxnStateMgr.cpp +++ b/lib/PgSQL_ExplicitTxnStateMgr.cpp @@ -129,9 +129,16 @@ void PgSQL_ExplicitTxnStateMgr::rollback(bool rollback_and_chain) { verify_server_variables(session); } - // Keep the transaction state intact when executing ROLLBACK AND CHAIN - if (rollback_and_chain == false) { - // Clear savepoints and reset the initial snapshot + // Handle transaction state cleanup based on rollback type + if (rollback_and_chain) { + // For ROLLBACK AND CHAIN: keep only initial snapshot, remove savepoint snapshots + while (transaction_state.size() > 1) { + reset_variable_snapshot(transaction_state.back()); + transaction_state.pop_back(); + } + } + else { + // For regular ROLLBACK: clear all snapshots for (auto& tran_state : transaction_state) { reset_variable_snapshot(tran_state); } diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index ecc135daf..95dc08975 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -2039,7 +2039,7 @@ __implicit_sync: l_free(pkt.size, pkt.ptr); continue; } else { - proxy_error("Not implemented yet. Message type:'%c'\n", c); + proxy_error("Not implemented yet. Message type:'0x%02X'\n", c); client_myds->setDSS_STATE_QUERY_SENT_NET(); client_myds->myprot.generate_error_packet(true, true, "Feature not supported", PGSQL_ERROR_CODES::ERRCODE_FEATURE_NOT_SUPPORTED, false, true); @@ -2304,7 +2304,7 @@ __implicit_sync: break; default: reset_extended_query_frame(); - proxy_error("Not implemented yet. Message type:'%c'\n", c); + proxy_error("Not implemented yet. Message type:'0x%02X'\n", c); client_myds->setDSS_STATE_QUERY_SENT_NET(); bool send_ready_packet = is_extended_query_ready_for_query() && c != 'H'; @@ -4461,9 +4461,9 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___handle_ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___handle_special_commands(const char* dig, bool* lock_hostgroup) { if (!dig) return false; - // When hostgroup is locked, check for RESET commands that could set wrong values + // In pipline mode, when hostgroup is locked, check for RESET commands that could set wrong values // due to pooled connection having different startup parameters than current client - if (locked_on_hostgroup >= 0 && strncasecmp(dig, "RESET ", 6) == 0) { + if (extended_query_phase != EXTQ_PHASE_IDLE && locked_on_hostgroup >= 0 && strncasecmp(dig, "RESET ", 6) == 0) { // Check if startup parameter values differ between client and backend if (mybe && mybe->server_myds && mybe->server_myds->myconn) { // Quick check: see if ANY critical variable has different startup hash @@ -4478,6 +4478,9 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___handle_ } if (startup_mismatch) { + // Discard pending pipeline messages + reset_extended_query_frame(); + // Only do expensive parsing if we're going to block the command std::string nq = std::string(dig); RE2::GlobalReplace(&nq, "(?U)/\\*.*\\*/", ""); diff --git a/test/tap/tests/pgsql-set_statement_test-t.cpp b/test/tap/tests/pgsql-set_statement_test-t.cpp index dd3d4fb20..1afc7bee0 100644 --- a/test/tap/tests/pgsql-set_statement_test-t.cpp +++ b/test/tap/tests/pgsql-set_statement_test-t.cpp @@ -198,6 +198,10 @@ bool test_set_simple_verify_pipeline() { int count = 0; int result_idx = 0; int sock = PQsocket(conn.get()); + if (sock < 0) { + diag("Invalid socket descriptor from PQsocket"); + return false; + } PGresult* res; time_t start_time = time(NULL); const int max_wait_seconds = 30; // Increased timeout @@ -1466,13 +1470,11 @@ bool test_set_different_values_from_original() { continue; } std::string final_val = it->second; - bool changed = (final_val.find(var.test_value) != std::string::npos) || - (var.test_value.find(final_val) != std::string::npos); + // Use exact equality to avoid false positives from substring matching + bool changed = (final_val == var.test_value); // Also verify it's NOT the original - bool is_original = (final_val == var.initial_value) || - (var.initial_value.find(final_val) != std::string::npos) || - (final_val.find(var.initial_value) != std::string::npos); + bool is_original = (final_val == var.initial_value); diag("%s: original='%s', test='%s', final='%s', changed=%s, is_original=%s", var.name.c_str(), var.initial_value.c_str(), var.test_value.c_str(), @@ -1618,10 +1620,23 @@ bool test_multiple_vars_out_of_sync_pipeline() { std::string bo1 = get_variable_simple(conn1.get(), "bytea_output"); // Close connection (returns to pool with these values) conn1.reset(); - diag("Connection 1 closed - returned to pool with DateStyle='Postgres, DMY', TimeZone='PST8PDT', bytea_output='escape'"); - - // Small delay to ensure connection is returned to pool - usleep(100000); + diag("Connection 1 closed - returned to pool with DateStyle='Postgres, MDY', TimeZone='PST8PDT', bytea_output='escape'"); + + // Wait for connection to be returned to pool with polling (max 5 seconds) + // Fixed delay replaced with polling to handle slow CI systems + bool conn_in_pool = false; + for (int retry = 0; retry < 50; retry++) { + usleep(100000); // 100ms * 50 = 5 seconds max + // Check if we can create a new connection (indicates pool has capacity) + PGConnPtr test_conn = createNewConnection(BACKEND); + if (test_conn) { + conn_in_pool = true; + break; + } + } + if (!conn_in_pool) { + diag("Warning: Connection may not have returned to pool yet, continuing anyway"); + } // Step 2: Create new connection with DIFFERENT variable values (simple query mode) PGConnPtr conn2 = createNewConnection(BACKEND); diff --git a/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp index ee1a5a714..84d9917ee 100644 --- a/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp +++ b/test/tap/tests/pgsql-transaction_variable_state_tracking-t.cpp @@ -394,6 +394,49 @@ int main(int argc, char** argv) { return success; }); + // Test: ROLLBACK AND CHAIN with savepoints - verifies savepoint snapshots are cleaned + add_test("ROLLBACK AND CHAIN with savepoints", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + + const auto original = getVariable(conn.get(), "DateStyle"); + + // Start transaction and set initial value + executeQuery(conn.get(), "BEGIN"); + executeQuery(conn.get(), "SET DateStyle = 'Postgres, DMY'"); + + // Create savepoints (these create snapshots in transaction_state) + executeQuery(conn.get(), "SAVEPOINT sp1"); + executeQuery(conn.get(), "SET DateStyle = 'SQL, DMY'"); + executeQuery(conn.get(), "SAVEPOINT sp2"); + executeQuery(conn.get(), "SET DateStyle = 'ISO, MDY'"); + + // ROLLBACK AND CHAIN should clear all savepoint snapshots + executeQuery(conn.get(), "ROLLBACK AND CHAIN"); + + // Verify we're still in a transaction + char tran_stat = PQtransactionStatus(conn.get()); + if (tran_stat != PQTRANS_INTRANS) { + diag("Expected INTRANS after ROLLBACK AND CHAIN, got %d", tran_stat); + executeQuery(conn.get(), "ROLLBACK"); + return false; + } + + // Verify DateStyle was reset to original (before BEGIN) + bool datestyle_ok = (getVariable(conn.get(), "DateStyle") == original); + + // Now test that we can create new savepoints (this would fail if stale snapshots remained) + executeQuery(conn.get(), "SAVEPOINT sp_after_chain"); + executeQuery(conn.get(), "SET DateStyle = 'Postgres, DMY'"); + + // Rollback to savepoint + executeQuery(conn.get(), "ROLLBACK TO SAVEPOINT sp_after_chain"); + + // Final cleanup + executeQuery(conn.get(), "ROLLBACK"); + + return datestyle_ok; + }); + add_test("Prepared ROLLBACK statement", [&]() { auto conn = createNewConnection(ConnType::BACKEND, "", false); @@ -2647,6 +2690,195 @@ int main(int argc, char** argv) { return simple_correct && pipeline_correct && values_match; }); + // ============================================================================ + // RESET Pipeline Mode Tests - Verify pipeline invariant fix + // ============================================================================ + + // Test: RESET ALL in pipeline mode should be rejected and pipeline reset + // This tests the fix for: When RESET is rejected, reset_extended_query_frame() + // must be called to prevent subsequent messages from being processed incorrectly + add_test("Pipeline: RESET ALL rejected with proper pipeline reset", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + if (!conn) return false; + + // Step 1: Set a variable to create marker state + executeQuery(conn.get(), "SET DateStyle = 'Postgres, DMY'"); + std::string marker_val = getVariable(conn.get(), "DateStyle"); + diag("Marker value set: '%s'", marker_val.c_str()); + + // Step 2: Enter pipeline mode + if (PQenterPipelineMode(conn.get()) != 1) { + diag("Failed to enter pipeline mode"); + return false; + } + + // Step 3: Send RESET ALL (may be rejected due to startup mismatch) + if (PQsendQueryParams(conn.get(), "RESET ALL", 0, NULL, NULL, NULL, NULL, 0) == 0) { + diag("Failed to send RESET ALL"); + PQexitPipelineMode(conn.get()); + return false; + } + + // Step 4: Send a subsequent query - this tests the pipeline reset fix + if (PQsendQueryParams(conn.get(), "SELECT 1 as test_col", 0, NULL, NULL, NULL, NULL, 0) == 0) { + diag("Failed to send SELECT 1"); + PQexitPipelineMode(conn.get()); + return false; + } + + // Step 5: Sync + if (PQpipelineSync(conn.get()) != 1) { + diag("PQpipelineSync failed"); + PQexitPipelineMode(conn.get()); + return false; + } + + // Step 6: Consume results with proper loop (like working tests) + int count = 0; + int errors = 0; + int select_ok = 0; + int sock = PQsocket(conn.get()); + PGresult* res; + + while (count < 3) { + if (PQconsumeInput(conn.get()) == 0) { + diag("PQconsumeInput failed: %s", PQerrorMessage(conn.get())); + PQexitPipelineMode(conn.get()); + return false; + } + while ((res = PQgetResult(conn.get())) != NULL) { + ExecStatusType status = PQresultStatus(res); + if (status == PGRES_TUPLES_OK) { + select_ok++; + diag("SELECT 1 returned: %s", PQgetvalue(res, 0, 0)); + } else if (status == PGRES_FATAL_ERROR) { + errors++; + diag("Command failed (expected for RESET ALL): %s", PQresultErrorMessage(res)); + } else if (status == PGRES_PIPELINE_SYNC) { + PQclear(res); + count++; + break; + } + PQclear(res); + count++; + } + if (count >= 3) break; + if (!PQisBusy(conn.get())) continue; + fd_set input_mask; + FD_ZERO(&input_mask); + FD_SET(sock, &input_mask); + struct timeval timeout = {5, 0}; + select(sock + 1, &input_mask, NULL, NULL, &timeout); + } + + PQexitPipelineMode(conn.get()); + + // Cleanup + executeQuery(conn.get(), "SET DateStyle = 'ISO, MDY'"); + + // Test passes if: + // 1. RESET failed (errors > 0) - the error was returned + // 2. SELECT was NOT executed (select_ok == 0) - frame was reset/discarded + diag("Results: errors=%d, select_ok=%d", errors, select_ok); + return (errors > 0 && select_ok == 0); + }); + + // Test: RESET single variable in pipeline mode + add_test("Pipeline: RESET single variable with pipeline reset", [&]() { + auto conn = createNewConnection(ConnType::BACKEND, "", false); + if (!conn) return false; + + // Step 1: Set a marker value + executeQuery(conn.get(), "SET DateStyle = 'SQL, DMY'"); + + // Step 2: Enter pipeline mode + if (PQenterPipelineMode(conn.get()) != 1) { + diag("Failed to enter pipeline mode"); + return false; + } + + // Step 3: Send RESET DateStyle + if (PQsendQueryParams(conn.get(), "RESET DateStyle", 0, NULL, NULL, NULL, NULL, 0) == 0) { + diag("Failed to send RESET DateStyle"); + PQexitPipelineMode(conn.get()); + return false; + } + + // Step 4: Send subsequent query + if (PQsendQueryParams(conn.get(), "SELECT 2 as test_col", 0, NULL, NULL, NULL, NULL, 0) == 0) { + diag("Failed to send SELECT 2"); + PQexitPipelineMode(conn.get()); + return false; + } + + // Step 5: Sync + if (PQpipelineSync(conn.get()) != 1) { + diag("PQpipelineSync failed"); + PQexitPipelineMode(conn.get()); + return false; + } + + // Step 6: Consume results with proper loop + int count = 0; + int errors = 0; + int select_ok = 0; + int reset_ok = 0; + int sock = PQsocket(conn.get()); + PGresult* res; + + while (count < 3) { + if (PQconsumeInput(conn.get()) == 0) { + diag("PQconsumeInput failed: %s", PQerrorMessage(conn.get())); + PQexitPipelineMode(conn.get()); + return false; + } + while ((res = PQgetResult(conn.get())) != NULL) { + ExecStatusType status = PQresultStatus(res); + if (status == PGRES_COMMAND_OK) { + reset_ok++; + diag("RESET DateStyle succeeded"); + } else if (status == PGRES_TUPLES_OK) { + select_ok++; + diag("SELECT 2 returned: %s", PQgetvalue(res, 0, 0)); + } else if (status == PGRES_FATAL_ERROR) { + errors++; + diag("Command failed: %s", PQresultErrorMessage(res)); + } else if (status == PGRES_PIPELINE_SYNC) { + PQclear(res); + count++; + break; + } + PQclear(res); + count++; + } + if (count >= 3) break; + if (!PQisBusy(conn.get())) continue; + fd_set input_mask; + FD_ZERO(&input_mask); + FD_SET(sock, &input_mask); + struct timeval timeout = {5, 0}; + select(sock + 1, &input_mask, NULL, NULL, &timeout); + } + + PQexitPipelineMode(conn.get()); + + // Cleanup + executeQuery(conn.get(), "SET DateStyle = 'ISO, MDY'"); + + // Test logic: + // If RESET failed (errors > 0), SELECT should be discarded (select_ok == 0) - frame was reset + // If RESET succeeded (reset_ok > 0), SELECT should also succeed (select_ok > 0) - normal operation + diag("Results: reset_ok=%d, errors=%d, select_ok=%d", reset_ok, errors, select_ok); + if (errors > 0) { + // RESET was rejected - frame should be reset, SELECT discarded + return (select_ok == 0); + } else if (reset_ok > 0) { + // RESET succeeded - pipeline should continue normally + return (select_ok > 0); + } + return false; // Neither success nor error - unexpected + }); + int total_tests = 0; total_tests = tests.size();