You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
proxysql/test/tap/tests/unit/mysqlx_tls_unit-t.cpp

483 lines
16 KiB

#include "mysqlx_data_stream.h"
#include "mysqlx_protocol.h"
#include "tap.h"
#include "test_globals.h"
#include "test_init.h"
#include <cerrno>
#include <cstring>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
// Write all `len` bytes of `buf` to `fd`, retrying on short writes and
// EINTR. Returns 0 on success, -1 on any other write error or if write()
// makes no progress.
static int write_all_fd(int fd, const void* buf, size_t len) {
	const uint8_t* p = (const uint8_t*)buf;
	while (len > 0) {
		ssize_t n = write(fd, p, len);
		if (n < 0) {
			if (errno == EINTR) continue;
			return -1;
		}
		if (n == 0) return -1; // no progress; avoid spinning forever
		p += n;
		len -= (size_t)n;
	}
	return 0;
}

// Write one X Protocol frame to `fd`: a 4-byte little-endian length equal
// to payload_len + 1 (the length covers the type byte), the 1-byte message
// type, then `payload_len` payload bytes (`payload` may be NULL when
// payload_len is 0).
//
// Test helper: write failures are not reported here. They surface as a
// missing or incomplete frame in the reader's assertions.
static void write_x_frame(int fd, uint8_t msg_type, const uint8_t* payload, size_t payload_len) {
	// The length field is 32-bit and includes the type byte.
	const size_t max_payload = 0xFFFFFFFEu;
	if (payload_len > max_payload) return;
	uint32_t size = (uint32_t)payload_len + 1;
	uint8_t header[5];
	header[0] = size & 0xFF;
	header[1] = (size >> 8) & 0xFF;
	header[2] = (size >> 16) & 0xFF;
	header[3] = (size >> 24) & 0xFF;
	header[4] = msg_type;
	if (write_all_fd(fd, header, sizeof(header)) != 0) return;
	if (payload_len > 0) {
		(void)write_all_fd(fd, payload, payload_len);
	}
}
static void test_init_ssl_null_ctx() {
diag(">>> %s", __func__);
MysqlxDataStream ds;
ds.init(XDS_FRONTEND, -1);
ds.init_ssl(nullptr);
ok(!ds.ssl_init_done(), "init_ssl with null ctx does not init SSL");
ok(!ds.is_encrypted(), "not encrypted with null ctx");
ok(!ds.ssl_handshake_complete(), "handshake not complete with null ctx");
}
static void test_non_tls_read_write_unchanged() {
diag(">>> %s", __func__);
int fds[2];
socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
MysqlxDataStream ds;
ds.init(XDS_FRONTEND, fds[0]);
write_x_frame(fds[1], 0x0E, nullptr, 0);
usleep(5000);
ssize_t r = ds.read_from_net();
ok(r > 0, "non-TLS read_from_net reads bytes");
ok(ds.has_complete_frame(), "non-TLS parse produces complete frame");
ds.enqueue_frame(0x0E, nullptr, 0);
r = ds.write_to_net();
ok(r > 0, "non-TLS write_to_net sends bytes");
close(fds[0]); close(fds[1]);
}
// Build a server-side SSL_CTX (min TLS 1.2) holding a freshly generated
// RSA-2048 key and a self-signed certificate for it (CN=test, serial 1,
// valid from now for 31536000 s, signed with SHA-256).
//
// Returns nullptr if any OpenSSL step fails, so callers never receive a
// half-configured context. The caller owns the result (SSL_CTX_free).
static SSL_CTX* create_test_ssl_ctx() {
	// All locals are declared up front so the goto-cleanup jumps never
	// cross an initialization (required in C++).
	SSL_CTX* ctx = SSL_CTX_new(TLS_method());
	EVP_PKEY_CTX* kctx = nullptr;
	EVP_PKEY* pkey = nullptr;
	X509* x509 = nullptr;
	X509_NAME* name = nullptr;
	bool success = false;

	if (!ctx) goto cleanup;
	if (SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION) != 1) goto cleanup;

	// EVP keygen replaces the RSA_* low-level API deprecated in OpenSSL 3.0.
	// The default public exponent is 65537 (RSA_F4), as before.
	kctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr);
	if (!kctx
	    || EVP_PKEY_keygen_init(kctx) <= 0
	    || EVP_PKEY_CTX_set_rsa_keygen_bits(kctx, 2048) <= 0
	    || EVP_PKEY_keygen(kctx, &pkey) <= 0) {
		goto cleanup;
	}

	x509 = X509_new();
	name = X509_NAME_new();
	if (!x509 || !name) goto cleanup;
	if (ASN1_INTEGER_set(X509_get_serialNumber(x509), 1) != 1
	    || !X509_gmtime_adj(X509_get_notBefore(x509), 0)
	    || !X509_gmtime_adj(X509_get_notAfter(x509), 31536000L)
	    || X509_set_pubkey(x509, pkey) != 1
	    || X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC,
	                                  (const unsigned char*)"test", -1, -1, 0) != 1
	    // Self-signed: subject and issuer are the same name.
	    || X509_set_subject_name(x509, name) != 1
	    || X509_set_issuer_name(x509, name) != 1
	    || X509_sign(x509, pkey, EVP_sha256()) <= 0) {
		goto cleanup;
	}

	if (SSL_CTX_use_certificate(ctx, x509) != 1
	    || SSL_CTX_use_PrivateKey(ctx, pkey) != 1
	    || SSL_CTX_check_private_key(ctx) != 1) {
		goto cleanup;
	}
	success = true;

cleanup:
	// The SSL_CTX holds its own references to the cert and key, so the
	// local handles are released on every path.
	X509_NAME_free(name);
	X509_free(x509);
	EVP_PKEY_free(pkey);
	EVP_PKEY_CTX_free(kctx);
	if (!success) {
		SSL_CTX_free(ctx);
		ctx = nullptr;
	}
	return ctx;
}
static void test_ssl_handshake_and_io() {
diag(">>> %s", __func__);
SSL_CTX* server_ctx = create_test_ssl_ctx();
SSL_CTX* client_ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_min_proto_version(client_ctx, TLS1_2_VERSION);
// Test-only: using self-signed test certs, no CA to verify against
SSL_CTX_set_verify(client_ctx, SSL_VERIFY_NONE, nullptr);
if (!server_ctx || !client_ctx) {
ok(false, "SSL CTX creation failed");
for (int i = 0; i < 9; i++) ok(false, "placeholder");
SSL_CTX_free(server_ctx);
SSL_CTX_free(client_ctx);
return;
}
int fds[2];
socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
MysqlxDataStream server_ds;
server_ds.init(XDS_FRONTEND, fds[0]);
server_ds.init_ssl(server_ctx);
ok(server_ds.ssl_init_done(), "server SSL initialized");
SSL* client_ssl = SSL_new(client_ctx);
BIO* client_rbio = BIO_new(BIO_s_mem());
BIO* client_wbio = BIO_new(BIO_s_mem());
SSL_set_bio(client_ssl, client_rbio, client_wbio);
SSL_set_connect_state(client_ssl);
int client_done = 0;
int server_done = 0;
char shuttle[16384];
for (int i = 0; i < 40 && (!client_done || !server_done); i++) {
if (!server_done) {
server_ds.read_from_net();
server_done = server_ds.do_ssl_handshake();
server_ds.flush_ssl_write_buf();
}
int net_r;
while ((net_r = recv(fds[1], shuttle, sizeof(shuttle), MSG_DONTWAIT)) > 0) {
BIO_write(client_rbio, shuttle, net_r);
}
if (!client_done) {
int r = SSL_do_handshake(client_ssl);
if (r == 1) {
client_done = 1;
}
int n;
while ((n = BIO_read(client_wbio, shuttle, sizeof(shuttle))) > 0) {
send(fds[1], shuttle, n, 0);
}
}
usleep(5000);
}
ok(server_done, "SSL handshake completed on server side");
ok(server_ds.is_encrypted(), "is_encrypted flag set after handshake");
ok(client_done, "SSL handshake completed on client side");
{
uint8_t x_frame[] = {0x03, 0x00, 0x00, 0x00, 0x0E, 0x01, 0x02};
SSL_write(client_ssl, x_frame, sizeof(x_frame));
int n;
while ((n = BIO_read(client_wbio, shuttle, sizeof(shuttle))) > 0) {
send(fds[1], shuttle, n, 0);
}
usleep(10000);
ssize_t net_bytes = 0;
char tmp[65536];
ssize_t nr = recv(fds[0], tmp, sizeof(tmp), MSG_DONTWAIT);
if (nr > 0 && server_ds.get_rbio_ssl()) {
BIO_write(server_ds.get_rbio_ssl(), tmp, nr);
net_bytes = nr;
}
ok(net_bytes > 0, "encrypted bytes received from client");
if (server_ds.get_ssl()) {
uint8_t plain[65536];
int dec = SSL_read(server_ds.get_ssl(), plain, sizeof(plain));
if (dec > 0) {
server_ds.feed_bytes(plain, dec);
}
}
ok(server_ds.has_complete_frame(), "encrypted read produces complete frame after SSL_read");
if (server_ds.has_complete_frame()) {
const auto& frame = server_ds.front_frame();
ok(frame[4] == 0x0E, "encrypted frame has correct message type");
} else {
ok(false, "encrypted frame has correct message type");
}
server_ds.pop_frame();
}
{
server_ds.enqueue_frame(0x0E, nullptr, 0);
server_ds.write_to_net();
if (server_ds.has_ssl_pending_write()) {
server_ds.flush_ssl_write_buf();
}
usleep(5000);
int resp_len = recv(fds[1], shuttle, sizeof(shuttle), MSG_DONTWAIT);
if (resp_len > 0) {
BIO_write(client_rbio, shuttle, resp_len);
char resp[64];
int r = SSL_read(client_ssl, resp, sizeof(resp));
ok(r == 5, "client received encrypted response frame (5 bytes)");
if (r == 5) {
ok(resp[4] == 0x0E, "encrypted response has correct message type");
} else {
ok(false, "encrypted response has correct message type");
}
} else {
ok(false, "client received encrypted response frame");
ok(false, "encrypted response has correct message type");
}
}
SSL_free(client_ssl);
SSL_CTX_free(server_ctx);
SSL_CTX_free(client_ctx);
close(fds[0]); close(fds[1]);
}
static void test_has_ssl_pending_write() {
diag(">>> %s", __func__);
MysqlxDataStream ds;
ds.init(XDS_FRONTEND, -1);
ok(!ds.has_ssl_pending_write(), "no pending write without SSL");
}
// init_ssl_connect() on a backend stream must create the SSL object
// without completing a handshake. Emits 2 ok() checks on every path.
static void test_ssl_connect_init() {
	diag(">>> %s", __func__);
	SSL_CTX* ctx = create_test_ssl_ctx();
	if (!ctx) {
		// Fail both checks so the TAP output shape does not depend on setup.
		ok(false, "init_ssl_connect creates SSL object");
		ok(false, "handshake not complete initially");
		return;
	}
	MysqlxDataStream ds;
	ds.init(XDS_BACKEND, -1);
	ds.init_ssl_connect(ctx);
	ok(ds.ssl_init_done(), "init_ssl_connect creates SSL object");
	ok(!ds.ssl_handshake_complete(), "handshake not complete initially");
	SSL_CTX_free(ctx);
}
// =====================================================================
// TLS handshake error classification (issue #5698).
//
// These tests drive mysqlx_classify_tls_error() with synthetic SSL
// state. We can manipulate two of the inputs the classifier reads:
//
// (a) SSL_get_verify_result(): SSL_set_verify_result lets us pre-seed
// a cert-chain reason without actually generating an expired or
// hostname-mismatched cert (cert-fixture generation is out of
// scope per the issue). Tests cover CERT_EXPIRED, HOSTNAME_MISMATCH,
// UNKNOWN_CA, and CERT_VERIFY_FAILED via this path.
//
// (b) ERR_get_error(): a specific reason could be pushed onto the OpenSSL
// error queue with ERR_put_error / ERR_raise (varies by OpenSSL
// version). No test below drives PROTOCOL_MISMATCH through this path
// yet; its code/message mapping is only covered by the round-trip tests.
//
// nullptr SSL* and the round-trip code/message helpers are tested
// without any SSL state at all.
//
// What we deliberately do NOT test here: a real failing TLS handshake
// driven by an actual expired cert / hostname-mismatched cert. That
// requires generating a cert fixture, which the issue explicitly punts
// on. The infrastructure is wired and unit-testable through the
// SSL_set_verify_result handle; an end-to-end TAP test against a
// fixture-driven backend remains as a follow-up.
// =====================================================================
// The classifier must accept a null SSL* and report NO_SSL_CTX.
static void test_classify_null_ssl() {
	diag(">>> %s", __func__);
	const auto verdict =
		mysqlx_classify_tls_error(nullptr, /*peek_err_queue=*/true);
	const bool matched = (verdict == MysqlxTlsErrorClass::NO_SSL_CTX);
	ok(matched, "nullptr SSL* classifies as NO_SSL_CTX");
}
// Helper: fresh SSL_CTX + SSL pair with no handshake. Callers stage state
// via SSL_set_verify_result before passing the SSL* to the classifier.
// No handshake runs, so no protocol version is negotiated, and no minimum
// protocol version is configured here.
//
// TLS_method() is the OpenSSL 1.1+ recommended version-flexible factory.
// The NOSONAR on the SSL_CTX_new line suppresses the cpp:S4423 false
// positive that treats TLS_method() as a weak protocol. That rule targets
// the deprecated SSLv23_method, not TLS_method.
//
// On success, returns the SSL* and stores its owning SSL_CTX* in *out_ctx.
// The caller frees both: SSL_free, then SSL_CTX_free.
// On failure, or if out_ctx is null, returns nullptr; *out_ctx (when
// non-null) is set to nullptr.
static SSL* make_synthetic_ssl(SSL_CTX** out_ctx) {
	if (!out_ctx) return nullptr;
	*out_ctx = nullptr;
	SSL_CTX* ctx = SSL_CTX_new(TLS_method()); // NOSONAR(cpp:S4423)
	if (!ctx) return nullptr;
	SSL* ssl = SSL_new(ctx);
	if (!ssl) {
		SSL_CTX_free(ctx);
		return nullptr;
	}
	*out_ctx = ctx;
	return ssl;
}
static void test_classify_cert_expired_via_verify_result() {
diag(">>> %s", __func__);
SSL_CTX* ctx = nullptr;
SSL* ssl = make_synthetic_ssl(&ctx);
if (!ssl) { ok(false, "synthetic SSL creation failed"); return; }
SSL_set_verify_result(ssl, X509_V_ERR_CERT_HAS_EXPIRED);
auto cls = mysqlx_classify_tls_error(ssl, /*peek_err_queue=*/false);
ok(cls == MysqlxTlsErrorClass::CERT_EXPIRED,
"X509_V_ERR_CERT_HAS_EXPIRED -> CERT_EXPIRED");
SSL_free(ssl);
SSL_CTX_free(ctx);
}
static void test_classify_hostname_mismatch_via_verify_result() {
diag(">>> %s", __func__);
SSL_CTX* ctx = nullptr;
SSL* ssl = make_synthetic_ssl(&ctx);
if (!ssl) { ok(false, "synthetic SSL creation failed"); return; }
SSL_set_verify_result(ssl, X509_V_ERR_HOSTNAME_MISMATCH);
auto cls = mysqlx_classify_tls_error(ssl, /*peek_err_queue=*/false);
ok(cls == MysqlxTlsErrorClass::HOSTNAME_MISMATCH,
"X509_V_ERR_HOSTNAME_MISMATCH -> HOSTNAME_MISMATCH");
SSL_free(ssl);
SSL_CTX_free(ctx);
}
static void test_classify_unknown_ca_via_verify_result() {
diag(">>> %s", __func__);
SSL_CTX* ctx = nullptr;
SSL* ssl = make_synthetic_ssl(&ctx);
if (!ssl) { ok(false, "synthetic SSL creation failed"); return; }
SSL_set_verify_result(ssl, X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY);
auto cls = mysqlx_classify_tls_error(ssl, /*peek_err_queue=*/false);
ok(cls == MysqlxTlsErrorClass::UNKNOWN_CA,
"X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY -> UNKNOWN_CA");
// Self-signed in chain also classifies as UNKNOWN_CA.
SSL_set_verify_result(ssl, X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN);
cls = mysqlx_classify_tls_error(ssl, /*peek_err_queue=*/false);
ok(cls == MysqlxTlsErrorClass::UNKNOWN_CA,
"X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN -> UNKNOWN_CA");
SSL_free(ssl);
SSL_CTX_free(ctx);
}
static void test_classify_generic_cert_verify_failed() {
diag(">>> %s", __func__);
SSL_CTX* ctx = nullptr;
SSL* ssl = make_synthetic_ssl(&ctx);
if (!ssl) { ok(false, "synthetic SSL creation failed"); return; }
// X509_V_ERR_CERT_SIGNATURE_FAILURE isn't expired/hostname/unknown-CA
// so it falls into the generic CERT_VERIFY_FAILED bucket.
SSL_set_verify_result(ssl, X509_V_ERR_CERT_SIGNATURE_FAILURE);
auto cls = mysqlx_classify_tls_error(ssl, /*peek_err_queue=*/false);
ok(cls == MysqlxTlsErrorClass::CERT_VERIFY_FAILED,
"X509_V_ERR_CERT_SIGNATURE_FAILURE -> CERT_VERIFY_FAILED (generic)");
SSL_free(ssl);
SSL_CTX_free(ctx);
}
static void test_classify_handshake_failed_default() {
diag(">>> %s", __func__);
SSL_CTX* ctx = nullptr;
SSL* ssl = make_synthetic_ssl(&ctx);
if (!ssl) { ok(false, "synthetic SSL creation failed"); return; }
// Clean state: no verify failure, no error-queue entries -> generic
// HANDSHAKE_FAILED fallback.
SSL_set_verify_result(ssl, X509_V_OK);
while (ERR_get_error() != 0) {} // drain pre-existing entries
auto cls = mysqlx_classify_tls_error(ssl, /*peek_err_queue=*/true);
ok(cls == MysqlxTlsErrorClass::HANDSHAKE_FAILED,
"no verify failure and clean error queue -> HANDSHAKE_FAILED fallback");
SSL_free(ssl);
SSL_CTX_free(ctx);
}
// Round-trip: verify the code/message helpers map every enum to a
// non-null message and a defined-range code.
static void test_classify_code_message_round_trip_backend() {
diag(">>> %s", __func__);
const MysqlxTlsErrorClass classes[] = {
MysqlxTlsErrorClass::HANDSHAKE_FAILED,
MysqlxTlsErrorClass::CERT_VERIFY_FAILED,
MysqlxTlsErrorClass::CERT_EXPIRED,
MysqlxTlsErrorClass::HOSTNAME_MISMATCH,
MysqlxTlsErrorClass::PROTOCOL_MISMATCH,
MysqlxTlsErrorClass::UNKNOWN_CA,
MysqlxTlsErrorClass::NO_SSL_CTX,
};
for (auto cls : classes) {
const char* msg = mysqlx_backend_tls_error_message(cls);
int code = mysqlx_backend_tls_error_code(cls);
ok(msg != nullptr && msg[0] != '\0',
"backend message for class=%d is non-empty", static_cast<int>(cls));
ok(code >= 3150 && code <= 3199,
"backend code for class=%d in 3150..3199 range (got %d)",
static_cast<int>(cls), code);
}
// Distinct codes for the 5 specifically-classified backend classes.
int codes[5] = {
mysqlx_backend_tls_error_code(MysqlxTlsErrorClass::CERT_VERIFY_FAILED),
mysqlx_backend_tls_error_code(MysqlxTlsErrorClass::CERT_EXPIRED),
mysqlx_backend_tls_error_code(MysqlxTlsErrorClass::HOSTNAME_MISMATCH),
mysqlx_backend_tls_error_code(MysqlxTlsErrorClass::PROTOCOL_MISMATCH),
mysqlx_backend_tls_error_code(MysqlxTlsErrorClass::UNKNOWN_CA),
};
bool all_distinct = true;
for (int i = 0; i < 5; i++) {
for (int j = i+1; j < 5; j++) {
if (codes[i] == codes[j]) all_distinct = false;
}
}
ok(all_distinct,
"the 5 specifically-classified backend codes are pairwise distinct");
}
// Frontend codes must not leak attacker-supplied certificate detail.
// Every cert-chain class collapses onto HANDSHAKE_FAILED's code. Only
// PROTOCOL_MISMATCH and NO_SSL_CTX get distinct codes, and NO_SSL_CTX
// must stay exactly 3150.
static void test_classify_code_message_round_trip_frontend() {
	diag(">>> %s", __func__);
	const int handshake     = mysqlx_frontend_tls_error_code(MysqlxTlsErrorClass::HANDSHAKE_FAILED);
	const int cert_verify   = mysqlx_frontend_tls_error_code(MysqlxTlsErrorClass::CERT_VERIFY_FAILED);
	const int expired       = mysqlx_frontend_tls_error_code(MysqlxTlsErrorClass::CERT_EXPIRED);
	const int host_mismatch = mysqlx_frontend_tls_error_code(MysqlxTlsErrorClass::HOSTNAME_MISMATCH);
	const int proto         = mysqlx_frontend_tls_error_code(MysqlxTlsErrorClass::PROTOCOL_MISMATCH);
	const int unknown_ca    = mysqlx_frontend_tls_error_code(MysqlxTlsErrorClass::UNKNOWN_CA);
	const int no_ctx        = mysqlx_frontend_tls_error_code(MysqlxTlsErrorClass::NO_SSL_CTX);

	// A chain of equalities is enough: equality is transitive.
	const bool collapsed = handshake == cert_verify
	                    && cert_verify == expired
	                    && expired == host_mismatch
	                    && host_mismatch == unknown_ca;
	ok(collapsed,
	   "frontend collapses cert-chain classes onto HANDSHAKE_FAILED (no leak)");
	ok(proto != handshake, "frontend PROTOCOL_MISMATCH gets a distinct code");
	ok(no_ctx != handshake, "frontend NO_SSL_CTX gets a distinct code (3150)");
	ok(no_ctx == 3150, "frontend NO_SSL_CTX code is exactly 3150 (matches existing contract)");
}
// Test driver. plan(0) declares no up-front count, so new tests can be added
// without re-counting. Today the 5 stream/TLS tests emit 18 ok() checks,
// followed by the 8 TLS error classification tests (#5698). Tests run in
// table order.
int main() {
	setvbuf(stdout, nullptr, _IOLBF, 0);
	plan(0);
	diag("=== mysqlx_tls_unit-t starting ===");
	typedef void (*test_fn)();
	static const test_fn tests[] = {
		test_init_ssl_null_ctx,
		test_non_tls_read_write_unchanged,
		test_ssl_handshake_and_io,
		test_has_ssl_pending_write,
		test_ssl_connect_init,
		// TLS error classification (#5698).
		test_classify_null_ssl,
		test_classify_cert_expired_via_verify_result,
		test_classify_hostname_mismatch_via_verify_result,
		test_classify_unknown_ca_via_verify_result,
		test_classify_generic_cert_verify_failed,
		test_classify_handshake_failed_default,
		test_classify_code_message_round_trip_backend,
		test_classify_code_message_round_trip_frontend,
	};
	for (test_fn run : tests) {
		run();
	}
	return exit_status();
}