From 33a6682bb91037d7bfd536980183101dba3c4dae Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Mon, 9 Mar 2026 18:41:06 +0500 Subject: [PATCH] * Use libpq to authenticate in Phase2 * Added support for hostnames --- .../tests/pgsql-test_malformed_packet-t.cpp | 183 ++++++++---------- 1 file changed, 77 insertions(+), 106 deletions(-) diff --git a/test/tap/tests/pgsql-test_malformed_packet-t.cpp b/test/tap/tests/pgsql-test_malformed_packet-t.cpp index 267fe3323..59cc5253a 100644 --- a/test/tap/tests/pgsql-test_malformed_packet-t.cpp +++ b/test/tap/tests/pgsql-test_malformed_packet-t.cpp @@ -34,6 +34,8 @@ #include #include #include +#include +#include #include #include #include @@ -71,41 +73,68 @@ typedef enum { /** * @brief Create a raw TCP socket connection to the specified host and port + * + * Supports both IP addresses (IPv4 and IPv6) and domain names. + * Uses getaddrinfo() for name resolution. */ int create_raw_connection(const std::string& host, int port) { - int sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock < 0) { - fprintf(stderr, "Socket creation failed: %s\n", strerror(errno)); + struct addrinfo hints{}; + struct addrinfo* result = nullptr; + struct addrinfo* rp = nullptr; + int sock = -1; + + // Prepare hints for getaddrinfo + hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6 + hints.ai_socktype = SOCK_STREAM; // TCP socket + + // Convert port to string for getaddrinfo + std::string port_str = std::to_string(port); + + // Resolve the hostname + int gai_err = getaddrinfo(host.c_str(), port_str.c_str(), &hints, &result); + if (gai_err != 0) { + fprintf(stderr, "Failed to resolve host '%s': %s\n", + host.c_str(), gai_strerror(gai_err)); return -1; } - // Set socket timeout - struct timeval timeout; - timeout.tv_sec = TIMEOUT_SEC; - timeout.tv_usec = 0; - if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { - fprintf(stderr, "Failed to set socket timeout: %s\n", strerror(errno)); - close(sock); - return -1; - } + // Try each address returned by getaddrinfo until we successfully connect + for (rp = result; rp != nullptr; rp = rp->ai_next) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (sock < 0) { + continue; // Try next address + } + + // Set socket timeout + struct timeval timeout; + timeout.tv_sec = TIMEOUT_SEC; + timeout.tv_usec = 0; + if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { + fprintf(stderr, "Failed to set socket timeout: %s\n", strerror(errno)); + close(sock); + sock = -1; + continue; + } - // Set TCP_NODELAY for immediate packet sending - int flag = 1; - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); + // Set TCP_NODELAY for immediate packet sending + int flag = 1; + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); - sockaddr_in server_addr{}; - server_addr.sin_family = AF_INET; - server_addr.sin_port = htons(port); + // Attempt to connect + if (connect(sock, rp->ai_addr, rp->ai_addrlen) == 0) { + // Success! + break; + } - if (inet_pton(AF_INET, host.c_str(), &server_addr.sin_addr) <= 0) { - fprintf(stderr, "Invalid address: %s\n", host.c_str()); + // Connection failed, try next address close(sock); - return -1; + sock = -1; } - if (connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) { - fprintf(stderr, "Connection failed: %s\n", strerror(errno)); - close(sock); + freeaddrinfo(result); + + if (sock < 0) { + fprintf(stderr, "Connection failed to %s:%d\n", host.c_str(), port); return -1; } @@ -550,96 +579,37 @@ void test_malformed_packet_phase2(const std::string& test_name, const std::string& password, const std::vector& malformed_data) { - int sock = create_raw_connection(host, port); - if (sock < 0) { - ok(0, "%s: Failed to create connection", test_name.c_str()); - return; - } - - // Step 1: Send valid startup packet - auto startup = build_startup_packet({{"user", username}}); - if (!send_exact(sock, startup.data(), startup.size())) { - close(sock); - ok(0, "%s: Failed to send startup", test_name.c_str()); - return; - } - - // Step 2: Wait for authentication request (usually MD5 or SASL) - char auth_type; - std::vector auth_data; - if (!read_message(sock, auth_type, auth_data)) { - close(sock); - ok(0, "%s: Failed to read auth request", test_name.c_str()); - return; - } - - // Step 3: Send password response (cleartext for simplicity) - auto password_pkt = build_password_packet(password); - if (!send_exact(sock, password_pkt.data(), password_pkt.size())) { - close(sock); - ok(0, "%s: Failed to send password", test_name.c_str()); - return; - } - - // Step 4: Wait for authentication result - if (!read_message(sock, auth_type, auth_data)) { - close(sock); - ok(0, "%s: Failed to read auth result", test_name.c_str()); - return; - } + // Use libpq to establish authenticated connection, then extract socket + std::stringstream conninfo; + conninfo << "host=" << host << " port=" << port + << " user=" << username << " password=" << password + << " sslmode=disable"; - // Check if authentication succeeded - if (auth_type == 'E') { - // Authentication failed - fail the test - // Post-authentication tests require successful auth - close(sock); - ok(0, "%s: Authentication failed, cannot run post-auth test", - test_name.c_str()); + PGconn* conn = PQconnectdb(conninfo.str().c_str()); + if (PQstatus(conn) != CONNECTION_OK) { + diag("libpq connection failed: %s", PQerrorMessage(conn)); + PQfinish(conn); + ok(0, "%s: Failed to create libpq connection", test_name.c_str()); return; } - if (auth_type != 'R' || auth_data.size() < 4) { - // Unexpected response - fail the test - close(sock); - ok(0, "%s: Unexpected auth response, cannot run post-auth test", - test_name.c_str()); + // Get the underlying socket from libpq connection + int sock = PQsocket(conn); + if (sock < 0) { + diag("PQsocket failed"); + PQfinish(conn); + ok(0, "%s: Failed to get socket from libpq connection", test_name.c_str()); return; } - // Check auth result code - int32_t auth_code = (auth_data[0] << 24) | (auth_data[1] << 16) | - (auth_data[2] << 8) | auth_data[3]; - - if (auth_code != 0) { - // Authentication didn't complete successfully - fail the test - close(sock); - ok(0, "%s: Authentication incomplete (code: %d), cannot run post-auth test", - test_name.c_str(), auth_code); - return; + // Set socket to blocking mode (libpq may have set it non-blocking) + int flags = fcntl(sock, F_GETFL, 0); + if (flags >= 0) { + fcntl(sock, F_SETFL, flags & ~O_NONBLOCK); } - // Authentication successful - now we have an authenticated connection - - // Step 5: Drain post-authentication messages until ReadyForQuery ('Z') - // PostgreSQL sends ParameterStatus ('S'), BackendKeyData ('K'), etc. after auth - // We need to consume these before sending the malformed packet to ensure - // the response we read is actually for the malformed packet, not startup noise - bool ready_for_query = false; - for (int i = 0; i < 16; ++i) { - char msg_type = 0; - std::vector msg_data; - if (!read_message(sock, msg_type, msg_data)) break; - if (msg_type == 'Z') { - ready_for_query = true; - break; - } - if (msg_type == 'E') break; // Error during startup - } - if (!ready_for_query) { - close(sock); - ok(0, "%s: Did not reach ReadyForQuery before malformed packet", test_name.c_str()); - return; - } + // Note: libpq has already completed authentication and reached ReadyForQuery + // We can now send malformed packets on this authenticated connection // Step 6: Send the malformed packet on the authenticated connection if (!send_exact(sock, malformed_data.data(), malformed_data.size())) { @@ -694,7 +664,8 @@ void test_malformed_packet_phase2(const std::string& test_name, test_name.c_str(), (int)connection_closed, (int)timeout, (int)got_error_response, first_byte); - close(sock); + // Clean up libpq connection (closes socket internally) + PQfinish(conn); } /**