diff --git a/test/tap/tests/pgsql-test_malformed_packet-t.cpp b/test/tap/tests/pgsql-test_malformed_packet-t.cpp index 55cdeef4a..267fe3323 100644 --- a/test/tap/tests/pgsql-test_malformed_packet-t.cpp +++ b/test/tap/tests/pgsql-test_malformed_packet-t.cpp @@ -53,6 +53,10 @@ constexpr uint32_t PG_SSL_REQUEST_CODE = 80877103; // (1234 << 16) | 5679 constexpr uint32_t PG_CANCEL_REQUEST_CODE = 80877102; // (1234 << 16) | 5678 constexpr uint32_t PG_GSS_ENCRYPT_CODE = 80877104; // (1234 << 16) | 5680 +// Forward declarations for helper functions +bool send_exact(int sock, const void* buf, size_t len); +bool recv_exact(int sock, void* buf, size_t len); + #define REPORT_ERROR_AND_EXIT(fmt, ...) \ do { \ fprintf(stderr, "File %s, line %d: " fmt "\n", __FILE__, __LINE__, ##__VA_ARGS__); \ @@ -128,8 +132,7 @@ void test_malformed_packet(const std::string& test_name, } // Send the malformed data - ssize_t bytes_sent = send(sock, data.data(), data.size(), 0); - if (bytes_sent < 0) { + if (!send_exact(sock, data.data(), data.size())) { close(sock); ok(0, "%s: Failed to send data", test_name.c_str()); return; @@ -139,14 +142,17 @@ void test_malformed_packet(const std::string& test_name, std::vector buffer(BUFFER_SIZE); ssize_t bytes_received = recv(sock, buffer.data(), buffer.size(), 0); - // Valid outcomes: connection closed OR any response - // The key is that ProxySQL handles the packet without crashing - bool connection_closed = (bytes_received <= 0); - bool got_response = (bytes_received > 0); - bool handled_gracefully = connection_closed || got_response; + // Valid outcomes: + // - Connection closed (bytes_received == 0): ProxySQL rejected/closed connection + // - Timeout (bytes_received < 0 with EAGAIN/EWOULDBLOCK): No response within timeout + // - Error response (bytes_received > 0 && buffer[0] == 'E'): ProxySQL sent error + bool connection_closed = (bytes_received == 0); + bool timeout = (bytes_received < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)); + bool got_error_response = (bytes_received > 0 && buffer[0] == 'E'); + bool handled_gracefully = connection_closed || timeout || got_error_response; - ok(handled_gracefully, "%s: Malformed packet handled (received: %ld bytes)", - test_name.c_str(), (long)bytes_received); + ok(handled_gracefully, "%s: Malformed packet handled (closed=%d, timeout=%d, error=%d)", + test_name.c_str(), (int)connection_closed, (int)timeout, (int)got_error_response); close(sock); } @@ -488,6 +494,21 @@ bool recv_exact(int sock, void* buf, size_t len) { return true; } +/** + * @brief Send an exact number of bytes to socket + * @return true if all bytes sent, false on error/short write + */ +bool send_exact(int sock, const void* buf, size_t len) { + const char* ptr = (const char*)buf; + size_t sent = 0; + while (sent < len) { + ssize_t n = send(sock, ptr + sent, len - sent, 0); + if (n <= 0) return false; + sent += n; + } + return true; +} + /** * @brief Read a PostgreSQL message from socket */ @@ -537,7 +558,7 @@ void test_malformed_packet_phase2(const std::string& test_name, // Step 1: Send valid startup packet auto startup = build_startup_packet({{"user", username}}); - if (send(sock, startup.data(), startup.size(), 0) < 0) { + if (!send_exact(sock, startup.data(), startup.size())) { close(sock); ok(0, "%s: Failed to send startup", test_name.c_str()); return; @@ -554,7 +575,7 @@ void test_malformed_packet_phase2(const std::string& test_name, // Step 3: Send password response (cleartext for simplicity) auto password_pkt = build_password_packet(password); - if (send(sock, password_pkt.data(), password_pkt.size(), 0) < 0) { + if (!send_exact(sock, password_pkt.data(), password_pkt.size())) { close(sock); ok(0, "%s: Failed to send password", test_name.c_str()); return; @@ -599,36 +620,79 @@ void test_malformed_packet_phase2(const std::string& test_name, // Authentication successful - now we have an authenticated connection - // Step 5: Send the malformed packet on the authenticated connection - ssize_t sent = send(sock, malformed_data.data(), malformed_data.size(), 0); - if (sent < 0) { + // Step 5: Drain post-authentication messages until ReadyForQuery ('Z') + // PostgreSQL sends ParameterStatus ('S'), BackendKeyData ('K'), etc. after auth + // We need to consume these before sending the malformed packet to ensure + // the response we read is actually for the malformed packet, not startup noise + bool ready_for_query = false; + for (int i = 0; i < 16; ++i) { + char msg_type = 0; + std::vector msg_data; + if (!read_message(sock, msg_type, msg_data)) break; + if (msg_type == 'Z') { + ready_for_query = true; + break; + } + if (msg_type == 'E') break; // Error during startup + } + if (!ready_for_query) { close(sock); - ok(0, "%s: Failed to send malformed data", test_name.c_str()); + ok(0, "%s: Did not reach ReadyForQuery before malformed packet", test_name.c_str()); return; } - // Step 6: Try to receive response - // ProxySQL may send multiple responses (auth completion, parameter status, etc.) - // followed by an error response 'E' for the malformed packet, or close connection - std::vector buffer(BUFFER_SIZE); - ssize_t bytes_received = recv(sock, buffer.data(), buffer.size(), 0); + // Step 6: Send the malformed packet on the authenticated connection + if (!send_exact(sock, malformed_data.data(), malformed_data.size())) { + close(sock); + ok(0, "%s: Failed to send malformed data", test_name.c_str()); + return; + } - // Check if we got an error response 'E' or connection closed - // Note: There may be pending post-auth messages, so any of these outcomes is valid: - // 1. Connection closed (ProxySQL rejected the malformed packet) - // 2. 'E' error response (ProxySQL sent an error for the malformed packet) - // 3. Other messages (post-auth protocol messages) - bool connection_closed = (bytes_received <= 0); - bool got_error_response = (bytes_received > 0 && buffer[0] == 'E'); - bool got_other_response = (bytes_received > 0 && buffer[0] != 'E'); + // Step 7: Wait for response to the malformed packet + // Keep reading messages until we get an error response, connection close, or timeout + bool got_error_response = false; + bool connection_closed = false; + bool timeout = false; + char first_byte = 0; + + for (int i = 0; i < 16; ++i) { + char msg_type = 0; + std::vector msg_data; + + // Try to read a message with timeout + ssize_t bytes_received = recv(sock, &msg_type, 1, MSG_PEEK | MSG_DONTWAIT); + if (bytes_received == 0) { + connection_closed = true; + break; + } + if (bytes_received < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + timeout = true; + } + break; + } + + // Message available, read it fully + if (!read_message(sock, msg_type, msg_data)) { + connection_closed = true; + break; + } + + first_byte = msg_type; + + // Check if this is an error response to our malformed packet + if (msg_type == 'E') { + got_error_response = true; + break; + } + // Continue reading if it's other protocol traffic + } - // Any outcome is acceptable as long as ProxySQL doesn't crash - // The key test is the final operational check - bool handled_gracefully = connection_closed || got_error_response || got_other_response; + bool handled_gracefully = connection_closed || timeout || got_error_response; - ok(handled_gracefully, "%s: Phase 2 malformed packet sent (received: %ld bytes, first=0x%02X)", - test_name.c_str(), (long)bytes_received, - bytes_received > 0 ? (unsigned char)buffer[0] : 0); + ok(handled_gracefully, "%s: Phase 2 malformed packet handled (closed=%d, timeout=%d, error=%d, first=0x%02X)", + test_name.c_str(), (int)connection_closed, (int)timeout, (int)got_error_response, + first_byte); close(sock); }