diff --git a/include/proxysql_admin.h b/include/proxysql_admin.h index 24a365294..12037861c 100644 --- a/include/proxysql_admin.h +++ b/include/proxysql_admin.h @@ -390,7 +390,7 @@ class ProxySQL_Admin { void admin_shutdown(); bool is_command(std::string); void send_MySQL_OK(MySQL_Protocol *myprot, char *msg, int rows=0); - void send_MySQL_ERR(MySQL_Protocol *myprot, char *msg); + void send_MySQL_ERR(MySQL_Protocol *myprot, char *msg, uint32_t code=1045); #ifdef DEBUG // these two following functions used to just call and return one function each // this approach was replaced when we introduced debug filters diff --git a/include/proxysql_utils.h b/include/proxysql_utils.h index eb510357a..77c21d3f1 100644 --- a/include/proxysql_utils.h +++ b/include/proxysql_utils.h @@ -210,4 +210,39 @@ std::string replace_str(const std::string& str, const std::string& match, const std::string generate_multi_rows_query(int rows, int params); void close_all_non_term_fd(std::vector excludeFDs); + +/** + * @brief Suggested implementation of 'mismatch_' from ['cppreference'](https://en.cppreference.com/w/cpp/algorithm/mismatch). + * + * @param first1 begin of the first range of the elements. + * @param last1 end of the first range of the elements. + * @param first2 begin of the second range of the elements. + * @param last2 end of the second range of the elements. + * @param p binary predicate which returns `true` if the elements should be treated as equal. + * + * @return std::pair with iterators to the first two non-equal elements. + */ +template +std::pair mismatch_( + InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2, BinaryPredicate p +) { + while (first1 != last1 && first2 != last2 && p(*first1, *first2)) { + ++first1, ++first2; + } + return std::make_pair(first1, first2); +} + +/** + * @brief Returns a sorted copy in ascending order of the supplied version numbers. + * @details Expected version numbers formats are: ['N', 'N.N', 'N.N.N', ...] + */ +std::vector sort_versions(std::vector versions); + +/** + * @brief Returns the expected error for query 'SELECT $$'. + * @param version The 'server_version' for which the error should match. + * @return A pair of the shape '{err_code,err_msg}'. + */ +std::pair get_dollar_quote_error(const char* version); + #endif diff --git a/lib/MySQL_Session.cpp b/lib/MySQL_Session.cpp index 215cff61c..94572d491 100644 --- a/lib/MySQL_Session.cpp +++ b/lib/MySQL_Session.cpp @@ -1365,6 +1365,22 @@ bool MySQL_Session::handler_special_queries(PtrSize_t *pkt) { l_free(pkt->size,pkt->ptr); return true; } + // MySQL client check command for dollars quote support, starting at version '8.1.0'. See #4300. + if ((pkt->size == strlen("SELECT $$") + 5) && strncasecmp("SELECT $$", (char*)pkt->ptr + 5, pkt->size - 5) == 0) { + pair err_info { get_dollar_quote_error(mysql_thread___server_version) }; + + client_myds->DSS=STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, err_info.first, (char *)"HY000", err_info.second, true); + client_myds->DSS=STATE_SLEEP; + status=WAITING_CLIENT_DATA; + + if (mirror==false) { + RequestEnd(NULL); + } + l_free(pkt->size,pkt->ptr); + + return true; + } if (locked_on_hostgroup >= 0 && (strncasecmp((char *)"SET ",(char *)pkt->ptr+5,4)==0)) { // this is a circuit breaker, we will send everything to the backend // diff --git a/lib/ProxySQL_Admin.cpp b/lib/ProxySQL_Admin.cpp index 12fcfc4d6..a420b3e1e 100644 --- a/lib/ProxySQL_Admin.cpp +++ b/lib/ProxySQL_Admin.cpp @@ -4561,6 +4561,14 @@ void admin_session_handler(MySQL_Session *sess, void *_pa, PtrSize_t *pkt) { goto __run_query; } + // MySQL client check command for dollars quote support, starting at version '8.1.0'. See #4300. + if (!strncasecmp("SELECT $$", query_no_space, strlen("SELECT $$"))) { + pair err_info { get_dollar_quote_error(mysql_thread___server_version) }; + SPA->send_MySQL_ERR(&sess->client_myds->myprot, const_cast(err_info.second), err_info.first); + run_query=false; + goto __run_query; + } + if (query_no_space_length==SELECT_VERSION_COMMENT_LEN) { if (!strncasecmp(SELECT_VERSION_COMMENT, query_no_space, query_no_space_length)) { l_free(query_length,query); @@ -11166,14 +11174,14 @@ void ProxySQL_Admin::send_MySQL_OK(MySQL_Protocol *myprot, char *msg, int rows) myds->DSS=STATE_SLEEP; } -void ProxySQL_Admin::send_MySQL_ERR(MySQL_Protocol *myprot, char *msg) { +void ProxySQL_Admin::send_MySQL_ERR(MySQL_Protocol *myprot, char *msg, uint32_t code) { assert(myprot); MySQL_Data_Stream *myds=myprot->get_myds(); myds->DSS=STATE_QUERY_SENT_DS; char *a = (char *)"ProxySQL Admin Error: "; char *new_msg = (char *)malloc(strlen(msg)+strlen(a)+1); sprintf(new_msg, "%s%s", a, msg); - myprot->generate_pkt_ERR(true,NULL,NULL,1,1045,(char *)"28000",new_msg); + myprot->generate_pkt_ERR(true,NULL,NULL,1,code,(char *)"28000",new_msg); free(new_msg); myds->DSS=STATE_SLEEP; } diff --git a/lib/proxysql_utils.cpp b/lib/proxysql_utils.cpp index 27d8051be..a4c3e25c3 100644 --- a/lib/proxysql_utils.cpp +++ b/lib/proxysql_utils.cpp @@ -1,4 +1,5 @@ #include "proxysql_utils.h" +#include "mysqld_error.h" #include #include @@ -431,3 +432,47 @@ void close_all_non_term_fd(std::vector excludeFDs) { } } } + +vector sort_versions(vector versions) { + std::sort( + versions.begin(), versions.end(), + [](const string& v1, const string& v2) { + const auto result = + mismatch_( + v1.cbegin(), v1.cend(), v2.cbegin(), v2.cend(), + [](const unsigned char lhs, const unsigned char rhs) { + return tolower(lhs) == tolower(rhs); + } + ); + + const bool not_equal = result.second != v2.cend(); + const bool fst_shorter = result.first == v1.cend(); + const bool fst_lesser = std::tolower(*result.first) < std::tolower(*result.second); + + return not_equal && (fst_shorter || fst_lesser); + } + ); + + return versions; +} + +std::pair get_dollar_quote_error(const char* version) { + const char* ER_PARSE_MSG { + "You have an error in your SQL syntax; check the manual that corresponds to your MySQL server" + " version for the right syntax to use near '$$' at line 1'" + }; + + if (strcasecmp(version,"8.1.0") == 0) { + return { ER_PARSE_ERROR, ER_PARSE_MSG }; + } else { + const vector sorted { sort_versions({"8.1.0", version}) }; + + if (sorted[0] == "8.1.0") { + // SQLSTATE: 42000 + return { ER_PARSE_ERROR, ER_PARSE_MSG }; + } else { + // SQLSTATE: 42S22 + return { ER_BAD_FIELD_ERROR, "Unknown column '$$' in 'field list'" }; + } + } +}