From 962e8c03e9dc03a71ded99e439148379d9e46ae6 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Sat, 13 Apr 2024 16:33:25 +0500 Subject: [PATCH] Added SSL connection support for PostgreSQL connection. * Added pgsql-have_ssl variable Note: There is a retrying mechanism when authentication fails if sslmode is prefer which seems incorrect. Need to investigate it. --- include/PgSQL_Authentication.h | 2 +- include/PgSQL_Protocol.h | 2 +- include/proxysql_structs.h | 2 ++ lib/PgSQL_Protocol.cpp | 18 +++++++++-- lib/PgSQL_Session.cpp | 57 ++++++++++++++++++++-------------- lib/PgSQL_Thread.cpp | 4 ++- 6 files changed, 55 insertions(+), 30 deletions(-) diff --git a/include/PgSQL_Authentication.h b/include/PgSQL_Authentication.h index af9782c07..1a429c11d 100644 --- a/include/PgSQL_Authentication.h +++ b/include/PgSQL_Authentication.h @@ -9,7 +9,7 @@ #ifndef PGSQL_ACCOUNT_DETAILS_T #define PGSQL_ACCOUNT_DETAILS_T -typedef struct _scram_keys { +struct _scram_keys { uint8_t scram_ClientKey[32]; uint8_t scram_ServerKey[32]; }; diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index d67ab3188..a1fe4205e 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -168,7 +168,7 @@ public: } bool generate_pkt_initial_handshake(bool send, void** ptr, unsigned int* len, uint32_t* thread_id, bool deprecate_eof_active) override; - bool process_startup_packet(unsigned char* pkt, unsigned int len); + bool process_startup_packet(unsigned char* pkt, unsigned int len, bool& ssl_request); EXECUTION_STATE process_handshake_response_packet(unsigned char* pkt, unsigned int len); void welcome_client(); diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 72903d3cb..5d502d4bf 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -803,6 +803,7 @@ PgSQL_HostGroups_Manager* PgHGM; __thread int pgsql_thread___authentication_method; __thread int pgsql_thread___show_processlist_extended; __thread char *pgsql_thread___server_version; +__thread bool pgsql_thread___have_ssl; //--------------------------- __thread char *mysql_thread___default_schema; @@ -983,6 +984,7 @@ extern PgSQL_HostGroups_Manager *PgHGM; extern __thread int pgsql_thread___authentication_method; extern __thread int pgsql_thread___show_processlist_extended; extern __thread char *pgsql_thread___server_version; +extern __thread bool pgsql_thread___have_ssl; //--------------------------- extern __thread char *mysql_thread___default_schema; diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 038f39bc1..7ce419e71 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -623,15 +623,27 @@ void PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt, bool startup) } } -bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len) { - +bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len, bool& ssl_request) { + + ssl_request = false; pgsql_hdr hdr{}; if (!get_header(pkt, len, &hdr)) { return false; } - //PG_PKT_STARTUP_V2 not supported + if (hdr.type == PG_PKT_SSLREQ) { + const bool have_ssl = pgsql_thread___have_ssl; + char* ssl_supported = (char*)malloc(1); + *ssl_supported = have_ssl ? 'S' : 'N'; + (*myds)->PSarrayOUT->add((void*)ssl_supported, 1); + (*myds)->sess->writeout(); + (*myds)->encrypted = have_ssl; + ssl_request = true; + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 8, "Session=%p , DS=%p. SSL_REQUEST:'%c'\n", (*myds)->sess, (*myds), *ssl_supported); + return true; + } + //PG_PKT_STARTUP_V2 not supported if (hdr.type != PG_PKT_STARTUP) { return false; } diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 94e769e84..87bb16920 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -3930,10 +3930,8 @@ __get_pkts_from_client: case CONNECTING_CLIENT: switch (client_myds->DSS) { - case STATE_SERVER_HANDSHAKE: - handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE(&pkt, &wrong_pass); - break; case STATE_SSL_INIT: + case STATE_SERVER_HANDSHAKE: handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE(&pkt, &wrong_pass); break; default: @@ -5642,21 +5640,32 @@ void PgSQL_Session::handler___status_CHANGING_USER_CLIENT___STATE_CLIENT_HANDSHA void PgSQL_Session::handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE(PtrSize_t* pkt, bool* wrong_pass) { bool is_encrypted = client_myds->encrypted; - - //bool handshake_response_return = client_myds->myprot.process_pkt_handshake_response((unsigned char*)pkt->ptr, pkt->size); bool handshake_response_return = false; + bool ssl_request = false; if (client_myds->auth_received_startup == false) { - if (client_myds->myprot.process_startup_packet((unsigned char*)pkt->ptr, pkt->size) == true && - client_myds->myprot.generate_pkt_initial_handshake(true, NULL, NULL, &thread_session_id, true) == true) { - client_myds->auth_received_startup = true; - l_free(pkt->size, pkt->ptr); - return; + if (client_myds->myprot.process_startup_packet((unsigned char*)pkt->ptr, pkt->size, ssl_request) == true ) { + if (ssl_request) { + if (is_encrypted == false && client_myds->encrypted == true) { + // switch to SSL... + } else { + // if sslmode is prefer, same connection will be used for plain text + l_free(pkt->size, pkt->ptr); + return; + } + } else if (client_myds->myprot.generate_pkt_initial_handshake(true, NULL, NULL, &thread_session_id, true) == true) { + client_myds->auth_received_startup = true; + l_free(pkt->size, pkt->ptr); + return; + } else { + assert(0); // this should never happen + } } else { - //send error packet here + *wrong_pass = true; //to forcefully close the connection. Is there a better way to do it? + client_myds->setDSS_STATE_QUERY_SENT_NET(); l_free(pkt->size, pkt->ptr); return; - } + } } bool handshake_err = true; @@ -5687,18 +5696,18 @@ void PgSQL_Session::handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE( (handshake_response_return == false) && // the authentication didn't complete (client_myds->encrypted == true) // client is asking for encryption ) { - // use SSL - proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 8, "Session=%p , DS=%p . SSL_INIT\n", this, client_myds); - client_myds->DSS = STATE_SSL_INIT; - client_myds->rbio_ssl = BIO_new(BIO_s_mem()); - client_myds->wbio_ssl = BIO_new(BIO_s_mem()); - client_myds->ssl = GloVars.get_SSL_new(); - SSL_set_fd(client_myds->ssl, client_myds->fd); - SSL_set_accept_state(client_myds->ssl); - SSL_set_bio(client_myds->ssl, client_myds->rbio_ssl, client_myds->wbio_ssl); - l_free(pkt->size, pkt->ptr); - proxysql_keylog_attach_callback(GloVars.get_SSL_ctx()); - return; + // use SSL + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 8, "Session=%p , DS=%p . SSL_INIT\n", this, client_myds); + client_myds->DSS = STATE_SSL_INIT; + client_myds->rbio_ssl = BIO_new(BIO_s_mem()); + client_myds->wbio_ssl = BIO_new(BIO_s_mem()); + client_myds->ssl = GloVars.get_SSL_new(); + SSL_set_fd(client_myds->ssl, client_myds->fd); + SSL_set_accept_state(client_myds->ssl); + SSL_set_bio(client_myds->ssl, client_myds->rbio_ssl, client_myds->wbio_ssl); + l_free(pkt->size, pkt->ptr); + proxysql_keylog_attach_callback(GloVars.get_SSL_ctx()); + return; } if ( diff --git a/lib/PgSQL_Thread.cpp b/lib/PgSQL_Thread.cpp index 883d2123f..2a6870740 100644 --- a/lib/PgSQL_Thread.cpp +++ b/lib/PgSQL_Thread.cpp @@ -4042,6 +4042,7 @@ void PgSQL_Thread::refresh_variables() { if (pgsql_thread___server_version) free(pgsql_thread___server_version); pgsql_thread___server_version = GloPTH->get_variable_string((char*)"server_version"); + pgsql_thread___have_ssl = (bool)GloPTH->get_variable_int((char*)"have_ssl"); if (mysql_thread___eventslog_filename) free(mysql_thread___eventslog_filename); mysql_thread___eventslog_filesize = GloPTH->get_variable_int((char*)"eventslog_filesize"); @@ -4062,7 +4063,7 @@ void PgSQL_Thread::refresh_variables() { mysql_thread___poll_timeout = GloPTH->get_variable_int((char*)"poll_timeout"); mysql_thread___poll_timeout_on_failure = GloPTH->get_variable_int((char*)"poll_timeout_on_failure"); mysql_thread___have_compress = (bool)GloPTH->get_variable_int((char*)"have_compress"); - mysql_thread___have_ssl = (bool)GloPTH->get_variable_int((char*)"have_ssl"); + mysql_thread___multiplexing = (bool)GloPTH->get_variable_int((char*)"multiplexing"); mysql_thread___log_unhealthy_connections = (bool)GloPTH->get_variable_int((char*)"log_unhealthy_connections"); mysql_thread___connection_warming = (bool)GloPTH->get_variable_int((char*)"connection_warming"); @@ -4130,6 +4131,7 @@ PgSQL_Thread::PgSQL_Thread() { last_processing_idles = 0; __thread_PgSQL_Thread_Variables_version = 0; pgsql_thread___server_version = NULL; + pgsql_thread___have_ssl = true; mysql_thread___init_connect = NULL; mysql_thread___ldap_user_variable = NULL; mysql_thread___add_ldap_user_comment = NULL;