diff --git a/lib/mysql_connection.cpp b/lib/mysql_connection.cpp index cd58f3275..19868584e 100644 --- a/lib/mysql_connection.cpp +++ b/lib/mysql_connection.cpp @@ -1123,15 +1123,21 @@ handler_again: if (myds) if (myds->sess != NULL) if (myds->sess->session_fast_forward == true) { - myds->encrypted = true; assert(myds->ssl==NULL); if (myds->ssl == NULL) { // check the definition of P_MARIADB_TLS P_MARIADB_TLS * matls = (P_MARIADB_TLS *)mysql->net.pvio->ctls; - myds->ssl = (SSL *)matls->ssl; - myds->rbio_ssl = BIO_new(BIO_s_mem()); - myds->wbio_ssl = BIO_new(BIO_s_mem()); - SSL_set_bio(myds->ssl, myds->rbio_ssl, myds->wbio_ssl); + if (matls != NULL) { + myds->encrypted = true; + myds->ssl = (SSL *)matls->ssl; + myds->rbio_ssl = BIO_new(BIO_s_mem()); + myds->wbio_ssl = BIO_new(BIO_s_mem()); + SSL_set_bio(myds->ssl, myds->rbio_ssl, myds->wbio_ssl); + } else { + // if mysql->options.use_ssl == 1 but matls == NULL + // it means that ProxySQL tried to use SSL to connect to the backend + // but the backend didn't support SSL + } } } } diff --git a/test/tap/tests/test_ssl_fast_forward-t.cpp b/test/tap/tests/test_ssl_fast_forward-t.cpp new file mode 100644 index 000000000..72fc28b63 --- /dev/null +++ b/test/tap/tests/test_ssl_fast_forward-t.cpp @@ -0,0 +1,193 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "tap.h" +#include "command_line.h" +#include "utils.h" + +char * username = (char *)"user1459"; +char * password = (char *)"pass1459"; + +int main(int argc, char** argv) { + CommandLine cl; + + if(cl.getEnv()) + return exit_status(); + + plan(3); + diag("Testing SSL and fast_forward"); + + MYSQL* mysqladmin = mysql_init(NULL); + if (!mysqladmin) + return exit_status(); + + if (!mysql_real_connect(mysqladmin, cl.host, cl.admin_username, cl.admin_password, NULL, cl.admin_port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", + __FILE__, __LINE__, mysql_error(mysqladmin)); + return exit_status(); + } + diag("We will reconfigure ProxySQL to use SQLite3 Server on hostgroup 1459, IP 127.0.0.1 and port 6030"); + { + std::vector queries = { + "SET mysql-have_ssl='false'", + "LOAD MYSQL VARIABLES TO RUNTIME", + "DELETE FROM mysql_servers WHERE hostgroup_id = 1459", + "INSERT INTO mysql_servers (hostgroup_id, hostname, port, use_ssl) VALUES (1459, '127.0.0.1', 6030, 0)", + "LOAD MYSQL SERVERS TO RUNTIME", + "DELETE FROM mysql_users WHERE username = 'user1459'", + "INSERT INTO mysql_users (username,password,fast_forward,default_hostgroup) VALUES ('" + std::string(username) + "','" + std::string(password) + "',1,1459)", + "LOAD MYSQL USERS TO RUNTIME", + }; + for (std::vector::iterator it = queries.begin(); it != queries.end(); it++) { + std::string q = *it; + diag("Running: %s", q.c_str()); + MYSQL_QUERY(mysqladmin, q.c_str()); + } + } + diag("We now create a connection not using SSL for either client or backend"); + MYSQL* mysql_1 = mysql_init(NULL); + if (!mysql_1) + return exit_status(); + + if (!mysql_real_connect(mysql_1, cl.host, username, password, NULL, cl.port, NULL, 0)) { + fprintf(stderr, "Failed to connect to database: Error: %s\n", + mysql_error(mysql_1)); + return exit_status(); + } + MYSQL_QUERY(mysql_1, "select 1"); + MYSQL_RES* result = mysql_store_result(mysql_1); + ok(mysql_num_rows(result) == 1, "Select statement should be executed on connection 1"); + mysql_free_result(result); + mysql_close(mysql_1); + + diag("We now create a connection using SSL for client connection only and no SSL for backend"); + { + std::vector queries = { + "SET mysql-have_ssl='true'", + "LOAD MYSQL VARIABLES TO RUNTIME", + }; + for (std::vector::iterator it = queries.begin(); it != queries.end(); it++) { + std::string q = *it; + diag("Running: %s", q.c_str()); + MYSQL_QUERY(mysqladmin, q.c_str()); + } + } + mysql_1 = mysql_init(NULL); + if (!mysql_1) + return exit_status(); + + if (!mysql_real_connect(mysql_1, cl.host, username, password, NULL, cl.port, NULL, CLIENT_SSL)) { + fprintf(stderr, "Failed to connect to database: Error: %s\n", + mysql_error(mysql_1)); + return exit_status(); + } + MYSQL_QUERY(mysql_1, "select 1"); + result = mysql_store_result(mysql_1); + ok(mysql_num_rows(result) == 1, "Select statement should be executed on connection 1"); + mysql_free_result(result); + mysql_close(mysql_1); + + + diag("We now create a connection trying to use SSL for backend connection (but SSL is disabled globally) and not SSL for frontend"); + { + std::vector queries = { + "SET mysql-have_ssl='false'", + "LOAD MYSQL VARIABLES TO RUNTIME", + "UPDATE mysql_servers SET use_ssl=1 WHERE hostgroup_id = 1459", + "LOAD MYSQL SERVERS TO RUNTIME", + }; + for (std::vector::iterator it = queries.begin(); it != queries.end(); it++) { + std::string q = *it; + diag("Running: %s", q.c_str()); + MYSQL_QUERY(mysqladmin, q.c_str()); + } + } + mysql_1 = mysql_init(NULL); + if (!mysql_1) + return exit_status(); + + if (!mysql_real_connect(mysql_1, cl.host, username, password, NULL, cl.port, NULL, 0)) { + fprintf(stderr, "Failed to connect to database: Error: %s\n", + mysql_error(mysql_1)); + return exit_status(); + } + MYSQL_QUERY(mysql_1, "select 1"); + result = mysql_store_result(mysql_1); + ok(mysql_num_rows(result) == 1, "Select statement should be executed on connection 1"); + mysql_free_result(result); + mysql_close(mysql_1); + + + diag("We now create a connection trying to use SSL for backend connection and not SSL for frontend"); + { + std::vector queries = { + "SET mysql-have_ssl='true'", + "LOAD MYSQL VARIABLES TO RUNTIME", + "UPDATE mysql_servers SET use_ssl=1 WHERE hostgroup_id = 1459", + "LOAD MYSQL SERVERS TO RUNTIME", + }; + for (std::vector::iterator it = queries.begin(); it != queries.end(); it++) { + std::string q = *it; + diag("Running: %s", q.c_str()); + MYSQL_QUERY(mysqladmin, q.c_str()); + } + } + mysql_1 = mysql_init(NULL); + if (!mysql_1) + return exit_status(); + + if (!mysql_real_connect(mysql_1, cl.host, username, password, NULL, cl.port, NULL, 0)) { + fprintf(stderr, "Failed to connect to database: Error: %s\n", + mysql_error(mysql_1)); + return exit_status(); + } + MYSQL_QUERY(mysql_1, "select 1"); + result = mysql_store_result(mysql_1); + ok(mysql_num_rows(result) == 1, "Select statement should be executed on connection 1"); + mysql_free_result(result); + mysql_close(mysql_1); + +/* + diag("We now create a connection using SSL for both client or backend"); + { + std::vector queries = { + "SET mysql-have_ssl='true'", + "LOAD MYSQL VARIABLES TO RUNTIME", + "UPDATE mysql_servers SET use_ssl=1 WHERE hostgroup_id = 1459", + "LOAD MYSQL SERVERS TO RUNTIME", + }; + for (std::vector::iterator it = queries.begin(); it != queries.end(); it++) { + std::string q = *it; + diag("Running: %s", q.c_str()); + MYSQL_QUERY(mysqladmin, q.c_str()); + } + } + mysql_1 = mysql_init(NULL); + if (!mysql_1) + return exit_status(); + + if (!mysql_real_connect(mysql_1, cl.host, username, password, NULL, cl.port, NULL, CLIENT_SSL)) { + fprintf(stderr, "Failed to connect to database: Error: %s\n", + mysql_error(mysql_1)); + return exit_status(); + } + MYSQL_QUERY(mysql_1, "select 1"); + result = mysql_store_result(mysql_1); + ok(mysql_num_rows(result) == 1, "Select statement should be executed on connection 1"); + mysql_free_result(result); + mysql_close(mysql_1); + +*/ + + mysql_close(mysqladmin); + + return exit_status(); +} +