* Use libpq to authenticate in Phase2

* Added support for hostnames
pull/5433/head
Rahim Kanji 1 month ago
parent 8e047398ce
commit 33a6682bb9

@ -34,6 +34,8 @@
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <cstring>
#include <vector>
#include <string>
@ -71,41 +73,68 @@ typedef enum {
/**
* @brief Create a raw TCP socket connection to the specified host and port
*
* Supports both IP addresses (IPv4 and IPv6) and domain names.
* Uses getaddrinfo() for name resolution.
*/
int create_raw_connection(const std::string& host, int port) {
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
fprintf(stderr, "Socket creation failed: %s\n", strerror(errno));
struct addrinfo hints{};
struct addrinfo* result = nullptr;
struct addrinfo* rp = nullptr;
int sock = -1;
// Prepare hints for getaddrinfo
hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6
hints.ai_socktype = SOCK_STREAM; // TCP socket
// Convert port to string for getaddrinfo
std::string port_str = std::to_string(port);
// Resolve the hostname
int gai_err = getaddrinfo(host.c_str(), port_str.c_str(), &hints, &result);
if (gai_err != 0) {
fprintf(stderr, "Failed to resolve host '%s': %s\n",
host.c_str(), gai_strerror(gai_err));
return -1;
}
// Set socket timeout
struct timeval timeout;
timeout.tv_sec = TIMEOUT_SEC;
timeout.tv_usec = 0;
if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
fprintf(stderr, "Failed to set socket timeout: %s\n", strerror(errno));
close(sock);
return -1;
}
// Try each address returned by getaddrinfo until we successfully connect
for (rp = result; rp != nullptr; rp = rp->ai_next) {
sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sock < 0) {
continue; // Try next address
}
// Set socket timeout
struct timeval timeout;
timeout.tv_sec = TIMEOUT_SEC;
timeout.tv_usec = 0;
if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
fprintf(stderr, "Failed to set socket timeout: %s\n", strerror(errno));
close(sock);
sock = -1;
continue;
}
// Set TCP_NODELAY for immediate packet sending
int flag = 1;
setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
// Set TCP_NODELAY for immediate packet sending
int flag = 1;
setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
sockaddr_in server_addr{};
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port);
// Attempt to connect
if (connect(sock, rp->ai_addr, rp->ai_addrlen) == 0) {
// Success!
break;
}
if (inet_pton(AF_INET, host.c_str(), &server_addr.sin_addr) <= 0) {
fprintf(stderr, "Invalid address: %s\n", host.c_str());
// Connection failed, try next address
close(sock);
return -1;
sock = -1;
}
if (connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
fprintf(stderr, "Connection failed: %s\n", strerror(errno));
close(sock);
freeaddrinfo(result);
if (sock < 0) {
fprintf(stderr, "Connection failed to %s:%d\n", host.c_str(), port);
return -1;
}
@ -550,96 +579,37 @@ void test_malformed_packet_phase2(const std::string& test_name,
const std::string& password,
const std::vector<uint8_t>& malformed_data) {
int sock = create_raw_connection(host, port);
if (sock < 0) {
ok(0, "%s: Failed to create connection", test_name.c_str());
return;
}
// Step 1: Send valid startup packet
auto startup = build_startup_packet({{"user", username}});
if (!send_exact(sock, startup.data(), startup.size())) {
close(sock);
ok(0, "%s: Failed to send startup", test_name.c_str());
return;
}
// Step 2: Wait for authentication request (usually MD5 or SASL)
char auth_type;
std::vector<uint8_t> auth_data;
if (!read_message(sock, auth_type, auth_data)) {
close(sock);
ok(0, "%s: Failed to read auth request", test_name.c_str());
return;
}
// Step 3: Send password response (cleartext for simplicity)
auto password_pkt = build_password_packet(password);
if (!send_exact(sock, password_pkt.data(), password_pkt.size())) {
close(sock);
ok(0, "%s: Failed to send password", test_name.c_str());
return;
}
// Step 4: Wait for authentication result
if (!read_message(sock, auth_type, auth_data)) {
close(sock);
ok(0, "%s: Failed to read auth result", test_name.c_str());
return;
}
// Use libpq to establish authenticated connection, then extract socket
std::stringstream conninfo;
conninfo << "host=" << host << " port=" << port
<< " user=" << username << " password=" << password
<< " sslmode=disable";
// Check if authentication succeeded
if (auth_type == 'E') {
// Authentication failed - fail the test
// Post-authentication tests require successful auth
close(sock);
ok(0, "%s: Authentication failed, cannot run post-auth test",
test_name.c_str());
PGconn* conn = PQconnectdb(conninfo.str().c_str());
if (PQstatus(conn) != CONNECTION_OK) {
diag("libpq connection failed: %s", PQerrorMessage(conn));
PQfinish(conn);
ok(0, "%s: Failed to create libpq connection", test_name.c_str());
return;
}
if (auth_type != 'R' || auth_data.size() < 4) {
// Unexpected response - fail the test
close(sock);
ok(0, "%s: Unexpected auth response, cannot run post-auth test",
test_name.c_str());
// Get the underlying socket from libpq connection
int sock = PQsocket(conn);
if (sock < 0) {
diag("PQsocket failed");
PQfinish(conn);
ok(0, "%s: Failed to get socket from libpq connection", test_name.c_str());
return;
}
// Check auth result code
int32_t auth_code = (auth_data[0] << 24) | (auth_data[1] << 16) |
(auth_data[2] << 8) | auth_data[3];
if (auth_code != 0) {
// Authentication didn't complete successfully - fail the test
close(sock);
ok(0, "%s: Authentication incomplete (code: %d), cannot run post-auth test",
test_name.c_str(), auth_code);
return;
// Set socket to blocking mode (libpq may have set it non-blocking)
int flags = fcntl(sock, F_GETFL, 0);
if (flags >= 0) {
fcntl(sock, F_SETFL, flags & ~O_NONBLOCK);
}
// Authentication successful - now we have an authenticated connection
// Step 5: Drain post-authentication messages until ReadyForQuery ('Z')
// PostgreSQL sends ParameterStatus ('S'), BackendKeyData ('K'), etc. after auth
// We need to consume these before sending the malformed packet to ensure
// the response we read is actually for the malformed packet, not startup noise
bool ready_for_query = false;
for (int i = 0; i < 16; ++i) {
char msg_type = 0;
std::vector<uint8_t> msg_data;
if (!read_message(sock, msg_type, msg_data)) break;
if (msg_type == 'Z') {
ready_for_query = true;
break;
}
if (msg_type == 'E') break; // Error during startup
}
if (!ready_for_query) {
close(sock);
ok(0, "%s: Did not reach ReadyForQuery before malformed packet", test_name.c_str());
return;
}
// Note: libpq has already completed authentication and reached ReadyForQuery
// We can now send malformed packets on this authenticated connection
// Step 6: Send the malformed packet on the authenticated connection
if (!send_exact(sock, malformed_data.data(), malformed_data.size())) {
@ -694,7 +664,8 @@ void test_malformed_packet_phase2(const std::string& test_name,
test_name.c_str(), (int)connection_closed, (int)timeout, (int)got_error_response,
first_byte);
close(sock);
// Clean up libpq connection (closes socket internally)
PQfinish(conn);
}
/**

Loading…
Cancel
Save