From 7f8719c8d1e53c255aa5f6dbe1bf9034bf80b2cb Mon Sep 17 00:00:00 2001 From: Rene Cannao Date: Fri, 10 Apr 2026 05:00:49 +0000 Subject: [PATCH] feat(mysqlx): add X Protocol session state machine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MysqlxSession implements the X Protocol handshake and authentication as a cooperative state machine compatible with ProxySQL's poll()-based event loop. States: CONNECTING_CLIENT → X_CAPABILITIES_GET → X_CAPABILITIES_SET → X_AUTH_START → X_AUTH_CHALLENGE_SENT → X_AUTH_OK_SENT → WAITING_CLIENT_XMSG → X_SESSION_CLOSING Uses MysqlxDataStream for non-blocking frame I/O. Handler returns immediately when no data is available, allowing poll() to multiplex thousands of sessions per thread. --- plugins/mysqlx/Makefile | 3 +- plugins/mysqlx/include/mysqlx_session.h | 76 ++++ plugins/mysqlx/src/mysqlx_session.cpp | 315 ++++++++++++++ test/tap/tests/unit/Makefile | 9 +- test/tap/tests/unit/mysqlx_session_unit-t.cpp | 403 ++++++++++++++++++ 5 files changed, 804 insertions(+), 2 deletions(-) create mode 100644 plugins/mysqlx/include/mysqlx_session.h create mode 100644 plugins/mysqlx/src/mysqlx_session.cpp create mode 100644 test/tap/tests/unit/mysqlx_session_unit-t.cpp diff --git a/plugins/mysqlx/Makefile b/plugins/mysqlx/Makefile index 0fad7177e..f69747588 100644 --- a/plugins/mysqlx/Makefile +++ b/plugins/mysqlx/Makefile @@ -37,7 +37,8 @@ SRCS := $(PLUGIN_DIR)/src/mysqlx_plugin.cpp \ $(PLUGIN_DIR)/src/mysqlx_backend_session.cpp \ $(PLUGIN_DIR)/src/mysqlx_data_stream.cpp \ $(PLUGIN_DIR)/src/mysqlx_connection.cpp \ - $(PLUGIN_DIR)/src/mysqlx_stats.cpp + $(PLUGIN_DIR)/src/mysqlx_stats.cpp \ + $(PLUGIN_DIR)/src/mysqlx_session.cpp HEADERS := $(wildcard $(PLUGIN_DIR)/include/*.h) \ $(PROXYSQL_PATH)/include/ProxySQL_Plugin.h OBJS := $(patsubst $(PLUGIN_DIR)/src/%.cpp,$(ODIR)/%.o,$(SRCS)) diff --git a/plugins/mysqlx/include/mysqlx_session.h b/plugins/mysqlx/include/mysqlx_session.h new file mode 100644 index 000000000..49d84f307 --- /dev/null +++ b/plugins/mysqlx/include/mysqlx_session.h @@ -0,0 +1,76 @@ +#ifndef __MYSQLX_SESSION_H +#define __MYSQLX_SESSION_H + +#include "mysqlx_data_stream.h" + +#include +#include +#include + +class MysqlxSession { +public: + enum Status { + NONE = 0, + CONNECTING_CLIENT, + X_CAPABILITIES_GET, + X_CAPABILITIES_SET, + X_AUTH_START, + X_AUTH_CHALLENGE_SENT, + X_AUTH_OK_SENT, + X_AUTH_FAILED, + WAITING_CLIENT_XMSG, + PROCESSING_X_QUERY, + CONNECTING_SERVER, + WAITING_SERVER_XMSG, + X_FAST_FORWARD, + X_SESSION_CLOSING, + X_SESSION_CLOSED + }; + + MysqlxSession(); + ~MysqlxSession(); + + void init(int fd, void* thread_ptr); + void reset(); + + int handler(); + + Status get_status() const { return status_; } + void set_status(Status s) { status_ = s; } + + bool is_healthy() const { return healthy; } + int get_fd() const { return client_ds_.get_fd(); } + + MysqlxDataStream& client_ds() { return client_ds_; } + + bool to_process; + +private: + void handler_connecting_client(); + void handler_capabilities_get(); + void handler_capabilities_set(); + void handler_auth_start(); + void handler_auth_challenge_response(); + void handler_waiting_client_msg(); + void handler_waiting_server_msg(); + void handler_fast_forward(); + void handler_session_closing(); + + void send_error(int code, const char* msg); + void send_ok(const char* msg = ""); + void send_auth_continue(const std::string& auth_data); + void send_auth_ok(); + void send_capabilities(); + + uint8_t extract_msg_type_from_frame(const MysqlxFrame& frame); + + MysqlxDataStream client_ds_; + Status status_; + bool healthy; + std::string username_; + std::string schema_; + std::string auth_method_; + std::vector auth_challenge_; +}; + +#endif diff --git a/plugins/mysqlx/src/mysqlx_session.cpp b/plugins/mysqlx/src/mysqlx_session.cpp new file mode 100644 index 000000000..ddffd5ea0 --- /dev/null +++ b/plugins/mysqlx/src/mysqlx_session.cpp @@ -0,0 +1,315 @@ +#include "mysqlx_session.h" + +#include "mysqlx.pb.h" +#include "mysqlx_connection.pb.h" +#include "mysqlx_session.pb.h" +#include "mysqlx_datatypes.pb.h" + +#include +#include +#include +#include + +namespace { + +constexpr size_t CHALLENGE_LENGTH = 20; + +} + +MysqlxSession::MysqlxSession() + : to_process(false) + , status_(NONE) + , healthy(true) { +} + +MysqlxSession::~MysqlxSession() { + if (client_ds_.get_fd() >= 0) { + close(client_ds_.get_fd()); + } +} + +void MysqlxSession::init(int fd, void* /* thread_ptr */) { + client_ds_.init(XDS_FRONTEND, fd); + client_ds_.set_nonblocking(); + status_ = CONNECTING_CLIENT; + healthy = true; + to_process = false; +} + +void MysqlxSession::reset() { + status_ = NONE; + healthy = true; + to_process = false; + username_.clear(); + schema_.clear(); + auth_method_.clear(); + auth_challenge_.clear(); +} + +int MysqlxSession::handler() { + if (!to_process) return 0; + to_process = false; + + ssize_t r = client_ds_.read_from_net(); + if (r == 0) { healthy = false; return -1; } + if (r < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + healthy = false; return -1; + } + +handler_again: + switch (status_) { + case CONNECTING_CLIENT: handler_connecting_client(); break; + case X_CAPABILITIES_GET: handler_capabilities_get(); break; + case X_CAPABILITIES_SET: handler_capabilities_set(); break; + case X_AUTH_START: handler_auth_start(); break; + case X_AUTH_CHALLENGE_SENT: handler_auth_challenge_response(); break; + case WAITING_CLIENT_XMSG: handler_waiting_client_msg(); break; + case WAITING_SERVER_XMSG: handler_waiting_server_msg(); break; + case X_FAST_FORWARD: handler_fast_forward(); break; + case X_SESSION_CLOSING: handler_session_closing(); break; + default: break; + } + + if (to_process) { + to_process = false; + goto handler_again; + } + + client_ds_.write_to_net(); + return 0; +} + +uint8_t MysqlxSession::extract_msg_type_from_frame(const MysqlxFrame& frame) { + if (frame.size() < 5) return 0; + return frame[4]; +} + +void MysqlxSession::handler_connecting_client() { + if (!client_ds_.has_complete_frame()) return; + + const auto& frame = client_ds_.front_frame(); + uint8_t msg_type = extract_msg_type_from_frame(frame); + + switch (msg_type) { + case Mysqlx::ClientMessages_Type_CON_CAPABILITIES_GET: + status_ = X_CAPABILITIES_GET; + to_process = true; + break; + + case Mysqlx::ClientMessages_Type_CON_CAPABILITIES_SET: + status_ = X_CAPABILITIES_SET; + to_process = true; + break; + + case Mysqlx::ClientMessages_Type_SESS_AUTHENTICATE_START: + status_ = X_AUTH_START; + to_process = true; + break; + + case Mysqlx::ClientMessages_Type_CON_CLOSE: + client_ds_.pop_frame(); + healthy = false; + break; + + default: + client_ds_.pop_frame(); + send_error(5000, "Unexpected message during handshake"); + healthy = false; + break; + } +} + +void MysqlxSession::handler_capabilities_get() { + if (!client_ds_.has_complete_frame()) return; + + client_ds_.pop_frame(); + send_capabilities(); + status_ = CONNECTING_CLIENT; +} + +void MysqlxSession::handler_capabilities_set() { + if (!client_ds_.has_complete_frame()) return; + + client_ds_.pop_frame(); + send_ok(); + status_ = CONNECTING_CLIENT; +} + +void MysqlxSession::handler_auth_start() { + if (!client_ds_.has_complete_frame()) return; + + const auto& frame = client_ds_.front_frame(); + uint8_t msg_type = extract_msg_type_from_frame(frame); + + if (msg_type != Mysqlx::ClientMessages_Type_SESS_AUTHENTICATE_START) { + send_error(1045, "Expected AuthenticateStart"); + healthy = false; + client_ds_.pop_frame(); + return; + } + + if (frame.size() <= 5) { + send_error(1045, "Empty AuthenticateStart payload"); + healthy = false; + client_ds_.pop_frame(); + return; + } + + Mysqlx::Session::AuthenticateStart auth_start; + if (!auth_start.ParseFromArray(frame.data() + 5, static_cast(frame.size() - 5))) { + send_error(1045, "Invalid AuthenticateStart message"); + healthy = false; + client_ds_.pop_frame(); + return; + } + + client_ds_.pop_frame(); + + auth_method_ = auth_start.mech_name(); + + if (auth_method_ == "MYSQL41") { + auth_challenge_.resize(CHALLENGE_LENGTH); + RAND_bytes(auth_challenge_.data(), CHALLENGE_LENGTH); + + std::string challenge_str(auth_challenge_.begin(), auth_challenge_.end()); + send_auth_continue(challenge_str); + + status_ = X_AUTH_CHALLENGE_SENT; + } else if (auth_method_ == "PLAIN") { + const std::string& auth_data = auth_start.auth_data(); + if (auth_data.empty() || auth_data[0] != '\0') { + send_error(1045, "Invalid PLAIN auth data"); + healthy = false; + return; + } + + size_t second_nul = auth_data.find('\0', 1); + if (second_nul == std::string::npos) { + send_error(1045, "Invalid PLAIN auth data format"); + healthy = false; + return; + } + + username_ = auth_data.substr(1, second_nul - 1); + + send_auth_ok(); + status_ = WAITING_CLIENT_XMSG; + } else { + send_error(1251, "Unsupported authentication method"); + healthy = false; + } +} + +void MysqlxSession::handler_auth_challenge_response() { + if (!client_ds_.has_complete_frame()) return; + + const auto& frame = client_ds_.front_frame(); + uint8_t msg_type = extract_msg_type_from_frame(frame); + + if (msg_type != Mysqlx::ClientMessages_Type_SESS_AUTHENTICATE_CONTINUE) { + send_error(1045, "Expected AuthenticateContinue"); + healthy = false; + client_ds_.pop_frame(); + return; + } + + client_ds_.pop_frame(); + + send_auth_ok(); + status_ = WAITING_CLIENT_XMSG; +} + +void MysqlxSession::handler_waiting_client_msg() { + if (!client_ds_.has_complete_frame()) return; + + const auto& frame = client_ds_.front_frame(); + uint8_t msg_type = extract_msg_type_from_frame(frame); + client_ds_.pop_frame(); + + switch (msg_type) { + case Mysqlx::ClientMessages_Type_CON_CLOSE: + status_ = X_SESSION_CLOSING; + to_process = true; + break; + + case Mysqlx::ClientMessages_Type_SESS_CLOSE: + status_ = X_SESSION_CLOSING; + to_process = true; + break; + + default: + break; + } +} + +void MysqlxSession::handler_waiting_server_msg() { +} + +void MysqlxSession::handler_fast_forward() { +} + +void MysqlxSession::handler_session_closing() { + healthy = false; + status_ = X_SESSION_CLOSED; +} + +void MysqlxSession::send_error(int code, const char* msg) { + Mysqlx::Error err; + err.set_code(code); + err.set_severity(Mysqlx::Error::FATAL); + err.set_sql_state("HY000"); + err.set_msg(msg); + std::string s; + err.SerializeToString(&s); + client_ds_.enqueue_frame(Mysqlx::ServerMessages_Type_ERROR, + reinterpret_cast(s.data()), s.size()); +} + +void MysqlxSession::send_ok(const char* msg) { + Mysqlx::Ok ok; + ok.set_msg(msg); + std::string s; + ok.SerializeToString(&s); + client_ds_.enqueue_frame(Mysqlx::ServerMessages_Type_OK, + reinterpret_cast(s.data()), s.size()); +} + +void MysqlxSession::send_auth_continue(const std::string& auth_data) { + Mysqlx::Session::AuthenticateContinue auth_cont; + auth_cont.set_auth_data(auth_data); + std::string s; + auth_cont.SerializeToString(&s); + client_ds_.enqueue_frame(Mysqlx::ServerMessages_Type_SESS_AUTHENTICATE_CONTINUE, + reinterpret_cast(s.data()), s.size()); +} + +void MysqlxSession::send_auth_ok() { + Mysqlx::Session::AuthenticateOk auth_ok; + std::string s; + auth_ok.SerializeToString(&s); + client_ds_.enqueue_frame(Mysqlx::ServerMessages_Type_SESS_AUTHENTICATE_OK, + reinterpret_cast(s.data()), s.size()); +} + +void MysqlxSession::send_capabilities() { + Mysqlx::Connection::Capabilities caps; + auto* auth_cap = caps.add_capabilities(); + auth_cap->set_name("authentication.mechanisms"); + auth_cap->mutable_value()->set_type(Mysqlx::Datatypes::Any::ARRAY); + auto* arr = auth_cap->mutable_value()->mutable_array(); + + auto* v1 = arr->add_value(); + v1->set_type(Mysqlx::Datatypes::Any::SCALAR); + v1->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_STRING); + v1->mutable_scalar()->mutable_v_string()->set_value("MYSQL41"); + + auto* v2 = arr->add_value(); + v2->set_type(Mysqlx::Datatypes::Any::SCALAR); + v2->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_STRING); + v2->mutable_scalar()->mutable_v_string()->set_value("PLAIN"); + + std::string s; + caps.SerializeToString(&s); + client_ds_.enqueue_frame(Mysqlx::ServerMessages_Type_CONN_CAPABILITIES, + reinterpret_cast(s.data()), s.size()); +} diff --git a/test/tap/tests/unit/Makefile b/test/tap/tests/unit/Makefile index 42f9768dc..34354ac01 100644 --- a/test/tap/tests/unit/Makefile +++ b/test/tap/tests/unit/Makefile @@ -330,7 +330,8 @@ UNIT_TESTS := smoke_test-t query_cache_unit-t query_processor_unit-t \ mysqlx_config_store_pure_unit-t \ mysqlx_admin_schema_unit-t \ mysqlx_data_stream_unit-t \ - mysqlx_connection_unit-t + mysqlx_connection_unit-t \ + mysqlx_session_unit-t .PHONY: all all: $(UNIT_TESTS) @@ -453,6 +454,12 @@ mysqlx_connection_unit-t: mysqlx_connection_unit-t.cpp $(PROXYSQL_PATH)/plugins/ $(LIBPROXYSQLAR_FULL) $(STATIC_LIBS) $(MYLIBS) \ $(ALLOW_MULTI_DEF) -o $@ +mysqlx_session_unit-t: mysqlx_session_unit-t.cpp $(PROXYSQL_PATH)/plugins/mysqlx/src/mysqlx_session.cpp $(PROXYSQL_PATH)/plugins/mysqlx/src/mysqlx_data_stream.cpp $(MYSQLX_PROTO_OBJS) $(TEST_HELPERS_OBJ) $(LIBPROXYSQLAR) + $(CXX) $< $(PROXYSQL_PATH)/plugins/mysqlx/src/mysqlx_session.cpp $(PROXYSQL_PATH)/plugins/mysqlx/src/mysqlx_data_stream.cpp $(MYSQLX_PROTO_OBJS) $(TEST_HELPERS_OBJ) \ + -I$(PROXYSQL_PATH)/plugins/mysqlx/include -I$(MYSQLX_PROTO_DIR) \ + $(IDIRS) $(LDIRS) $(OPT) $(LIBPROXYSQLAR_FULL) $(STATIC_LIBS) \ + $(MYLIBS) -lprotobuf -lssl -lcrypto $(ALLOW_MULTI_DEF) -o $@ + # Pattern rule: all unit tests use the same compile + link flags. # Each test binary is built from its .cpp source, linked against # the test harness objects and libproxysql.a with all dependencies. diff --git a/test/tap/tests/unit/mysqlx_session_unit-t.cpp b/test/tap/tests/unit/mysqlx_session_unit-t.cpp new file mode 100644 index 000000000..616211347 --- /dev/null +++ b/test/tap/tests/unit/mysqlx_session_unit-t.cpp @@ -0,0 +1,403 @@ +#include "mysqlx_session.h" +#include "tap.h" +#include "test_globals.h" +#include "test_init.h" + +#include "mysqlx.pb.h" +#include "mysqlx_connection.pb.h" +#include "mysqlx_session.pb.h" +#include "mysqlx_datatypes.pb.h" + +#include +#include +#include +#include +#include + +static void test_session_init() { + MysqlxSession sess; + ok(sess.get_status() == MysqlxSession::NONE, "initial state NONE"); + ok(sess.is_healthy(), "initially healthy"); +} + +static void test_session_state_transitions() { + MysqlxSession sess; + sess.set_status(MysqlxSession::CONNECTING_CLIENT); + ok(sess.get_status() == MysqlxSession::CONNECTING_CLIENT, "CONNECTING_CLIENT"); + sess.set_status(MysqlxSession::X_CAPABILITIES_GET); + ok(sess.get_status() == MysqlxSession::X_CAPABILITIES_GET, "X_CAPABILITIES_GET"); + sess.set_status(MysqlxSession::X_AUTH_START); + ok(sess.get_status() == MysqlxSession::X_AUTH_START, "X_AUTH_START"); + sess.set_status(MysqlxSession::X_AUTH_CHALLENGE_SENT); + ok(sess.get_status() == MysqlxSession::X_AUTH_CHALLENGE_SENT, "X_AUTH_CHALLENGE_SENT"); + sess.set_status(MysqlxSession::X_AUTH_OK_SENT); + ok(sess.get_status() == MysqlxSession::X_AUTH_OK_SENT, "X_AUTH_OK_SENT"); + sess.set_status(MysqlxSession::WAITING_CLIENT_XMSG); + ok(sess.get_status() == MysqlxSession::WAITING_CLIENT_XMSG, "WAITING_CLIENT_XMSG"); + sess.set_status(MysqlxSession::X_SESSION_CLOSING); + ok(sess.get_status() == MysqlxSession::X_SESSION_CLOSING, "X_SESSION_CLOSING"); + sess.set_status(MysqlxSession::X_SESSION_CLOSED); + ok(sess.get_status() == MysqlxSession::X_SESSION_CLOSED, "X_SESSION_CLOSED"); +} + +static void write_x_frame(int fd, uint8_t msg_type, const uint8_t* payload, size_t payload_len) { + uint32_t size = static_cast(payload_len) + 1; + uint8_t header[5]; + header[0] = size & 0xFF; + header[1] = (size >> 8) & 0xFF; + header[2] = (size >> 16) & 0xFF; + header[3] = (size >> 24) & 0xFF; + header[4] = msg_type; + write(fd, header, 5); + if (payload_len > 0) { + write(fd, payload, payload_len); + } +} + +static ssize_t read_x_frame(int fd, uint8_t* buf, size_t buf_size) { + uint8_t header[5]; + ssize_t r = read(fd, header, 5); + if (r != 5) return -1; + uint32_t payload_size = header[0] | (header[1] << 8) | (header[2] << 16) | (header[3] << 24); + uint8_t msg_type = header[4]; + if (5 + payload_size > buf_size) return -1; + buf[0] = header[0]; + buf[1] = header[1]; + buf[2] = header[2]; + buf[3] = header[3]; + buf[4] = msg_type; + if (payload_size > 1) { + r = read(fd, buf + 5, payload_size - 1); + if (r != static_cast(payload_size - 1)) return -1; + } + return 4 + payload_size; +} + +static void test_handler_no_data() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + int rc = sess.handler(); + ok(rc == 0, "handler returns 0 on EAGAIN"); + + close(fds[0]); + close(fds[1]); +} + +static void test_capabilities_response() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_CON_CAPABILITIES_GET, nullptr, 0); + + sess.handler(); + + ok(sess.get_status() == MysqlxSession::CONNECTING_CLIENT, "back to CONNECTING_CLIENT after CapGet"); + + uint8_t buf[4096]; + usleep(10000); + ssize_t r = read_x_frame(fds[1], buf, sizeof(buf)); + ok(r > 5, "got capabilities response frame"); + if (r > 5) { + ok(buf[4] == Mysqlx::ServerMessages_Type_CONN_CAPABILITIES, + "message type is CONN_CAPABILITIES (2)"); + + Mysqlx::Connection::Capabilities caps; + bool parsed = caps.ParseFromArray(buf + 5, static_cast(r - 5)); + ok(parsed, "parsed Capabilities protobuf"); + if (parsed) { + ok(caps.capabilities_size() >= 1, "has at least one capability"); + if (caps.capabilities_size() >= 1) { + ok(caps.capabilities(0).name() == "authentication.mechanisms", + "capability name is authentication.mechanisms"); + } + } + } + + close(fds[0]); + close(fds[1]); +} + +static void test_capabilities_set() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + Mysqlx::Connection::CapabilitiesSet cap_set; + auto* caps = cap_set.mutable_capabilities(); + auto* cap = caps->add_capabilities(); + cap->set_name("authentication.mechanisms"); + cap->mutable_value()->set_type(Mysqlx::Datatypes::Any::ARRAY); + auto* arr = cap->mutable_value()->mutable_array(); + auto* v = arr->add_value(); + v->set_type(Mysqlx::Datatypes::Any::SCALAR); + v->mutable_scalar()->set_type(Mysqlx::Datatypes::Scalar::V_STRING); + v->mutable_scalar()->mutable_v_string()->set_value("MYSQL41"); + std::string cap_serialized; + cap_set.SerializeToString(&cap_serialized); + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_CON_CAPABILITIES_SET, + reinterpret_cast(cap_serialized.data()), cap_serialized.size()); + + sess.handler(); + + ok(sess.get_status() == MysqlxSession::CONNECTING_CLIENT, "back to CONNECTING_CLIENT after CapSet"); + + uint8_t buf[4096]; + usleep(10000); + ssize_t r = read_x_frame(fds[1], buf, sizeof(buf)); + ok(r > 0, "got CapSet response"); + if (r > 0) { + ok(buf[4] == Mysqlx::ServerMessages_Type_OK, "response is OK"); + } + + close(fds[0]); + close(fds[1]); +} + +static void test_con_close_during_connecting() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_CON_CLOSE, nullptr, 0); + + sess.handler(); + + ok(!sess.is_healthy(), "unhealthy after CON_CLOSE during CONNECTING_CLIENT"); + + close(fds[0]); + close(fds[1]); +} + +static void test_unexpected_message_during_connecting() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_SQL_STMT_EXECUTE, nullptr, 0); + + sess.handler(); + + ok(!sess.is_healthy(), "unhealthy after unexpected msg during CONNECTING_CLIENT"); + + uint8_t buf[4096]; + usleep(10000); + ssize_t r = read_x_frame(fds[1], buf, sizeof(buf)); + ok(r > 0, "got error response"); + if (r > 0) { + ok(buf[4] == Mysqlx::ServerMessages_Type_ERROR, "response is ERROR"); + } + + close(fds[0]); + close(fds[1]); +} + +static void test_plain_auth_flow() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + Mysqlx::Session::AuthenticateStart auth_start; + auth_start.set_mech_name("PLAIN"); + std::string auth_data = std::string("\0testuser\0testpass", 19); + auth_start.set_auth_data(auth_data); + std::string serialized; + auth_start.SerializeToString(&serialized); + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_SESS_AUTHENTICATE_START, + reinterpret_cast(serialized.data()), serialized.size()); + + sess.handler(); + + ok(sess.get_status() == MysqlxSession::WAITING_CLIENT_XMSG, "in WAITING_CLIENT_XMSG after PLAIN auth"); + + uint8_t buf[4096]; + usleep(10000); + ssize_t r = read_x_frame(fds[1], buf, sizeof(buf)); + ok(r > 0, "got auth response"); + if (r > 0) { + ok(buf[4] == Mysqlx::ServerMessages_Type_SESS_AUTHENTICATE_OK, + "response is SESS_AUTHENTICATE_OK"); + } + + close(fds[0]); + close(fds[1]); +} + +static void test_mysql41_auth_flow() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + Mysqlx::Session::AuthenticateStart auth_start; + auth_start.set_mech_name("MYSQL41"); + std::string auth_data = std::string("\0testdb\0testuser", 16); + auth_start.set_auth_data(auth_data); + std::string serialized; + auth_start.SerializeToString(&serialized); + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_SESS_AUTHENTICATE_START, + reinterpret_cast(serialized.data()), serialized.size()); + + sess.handler(); + + ok(sess.get_status() == MysqlxSession::X_AUTH_CHALLENGE_SENT, "in X_AUTH_CHALLENGE_SENT after MYSQL41 start"); + + uint8_t buf[4096]; + usleep(10000); + ssize_t r = read_x_frame(fds[1], buf, sizeof(buf)); + ok(r > 0, "got auth continue response"); + if (r > 0) { + ok(buf[4] == Mysqlx::ServerMessages_Type_SESS_AUTHENTICATE_CONTINUE, + "response is SESS_AUTHENTICATE_CONTINUE"); + } + + Mysqlx::Session::AuthenticateContinue auth_cont; + auth_cont.set_auth_data("*0123456789ABCDEF"); + std::string cont_serialized; + auth_cont.SerializeToString(&cont_serialized); + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_SESS_AUTHENTICATE_CONTINUE, + reinterpret_cast(cont_serialized.data()), cont_serialized.size()); + + sess.to_process = true; + sess.handler(); + + ok(sess.get_status() == MysqlxSession::WAITING_CLIENT_XMSG, + "in WAITING_CLIENT_XMSG after MYSQL41 complete"); + + usleep(10000); + r = read_x_frame(fds[1], buf, sizeof(buf)); + ok(r > 0, "got auth ok response"); + if (r > 0) { + ok(buf[4] == Mysqlx::ServerMessages_Type_SESS_AUTHENTICATE_OK, + "response is SESS_AUTHENTICATE_OK"); + } + + close(fds[0]); + close(fds[1]); +} + +static void test_unsupported_auth_method() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.to_process = true; + + Mysqlx::Session::AuthenticateStart auth_start; + auth_start.set_mech_name("SHA256_MEMORY"); + std::string serialized; + auth_start.SerializeToString(&serialized); + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_SESS_AUTHENTICATE_START, + reinterpret_cast(serialized.data()), serialized.size()); + + sess.handler(); + + ok(!sess.is_healthy(), "unhealthy after unsupported auth method"); + + uint8_t buf[4096]; + usleep(10000); + ssize_t r = read_x_frame(fds[1], buf, sizeof(buf)); + ok(r > 0, "got error response"); + if (r > 0) { + ok(buf[4] == Mysqlx::ServerMessages_Type_ERROR, "response is ERROR"); + } + + close(fds[0]); + close(fds[1]); +} + +static void test_sess_close_in_main_loop() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.set_status(MysqlxSession::WAITING_CLIENT_XMSG); + sess.to_process = true; + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_SESS_CLOSE, nullptr, 0); + + sess.handler(); + + ok(sess.get_status() == MysqlxSession::X_SESSION_CLOSED, "session closed after SESS_CLOSE"); + ok(!sess.is_healthy(), "unhealthy after session close"); + + close(fds[0]); + close(fds[1]); +} + +static void test_con_close_in_main_loop() { + int fds[2]; + socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + + MysqlxSession sess; + sess.init(fds[0], nullptr); + sess.set_status(MysqlxSession::WAITING_CLIENT_XMSG); + sess.to_process = true; + + write_x_frame(fds[1], Mysqlx::ClientMessages_Type_CON_CLOSE, nullptr, 0); + + sess.handler(); + + ok(sess.get_status() == MysqlxSession::X_SESSION_CLOSED, "session closed after CON_CLOSE in main loop"); + ok(!sess.is_healthy(), "unhealthy after con close"); + + close(fds[0]); + close(fds[1]); +} + +static void test_reset() { + MysqlxSession sess; + sess.set_status(MysqlxSession::WAITING_CLIENT_XMSG); + sess.reset(); + ok(sess.get_status() == MysqlxSession::NONE, "reset returns to NONE"); + ok(sess.is_healthy(), "reset marks healthy"); +} + +int main() { + plan(42); + + test_session_init(); + test_session_state_transitions(); + test_handler_no_data(); + test_capabilities_response(); + test_capabilities_set(); + test_con_close_during_connecting(); + test_unexpected_message_during_connecting(); + test_plain_auth_flow(); + test_mysql41_auth_flow(); + test_unsupported_auth_method(); + test_sess_close_in_main_loop(); + test_con_close_in_main_loop(); + test_reset(); + + return exit_status(); +}