From db442ba8cfc00977fe3a6f2eb195b366baef56c9 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 1 Apr 2025 02:25:06 +0500 Subject: [PATCH] Treat client_encoding as normal server parameter/variable. Optimized RESET ALL command handling. Few optimizations. Code cleanup --- include/PgSQL_Connection.h | 193 ++-------------------------- include/PgSQL_Session.h | 3 - include/PgSQL_Set_Stmt_Parser.h | 2 + include/PgSQL_Variables.h | 6 +- include/PgSQL_Variables_Validator.h | 5 +- include/proxysql_structs.h | 3 +- lib/PgSQL_Connection.cpp | 161 +++-------------------- lib/PgSQL_Protocol.cpp | 97 +++++++------- lib/PgSQL_Session.cpp | 147 +++++---------------- lib/PgSQL_Set_Stmt_Parser.cpp | 3 +- lib/PgSQL_Variables.cpp | 169 ++---------------------- lib/PgSQL_Variables_Validator.cpp | 31 +++++ 12 files changed, 157 insertions(+), 663 deletions(-) diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index 0d5e17bd5..e2271c4b3 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -38,7 +38,6 @@ enum PgSQL_Param_Name { PG_REQUIRE_AUTH, // Specifies the authentication method that the client requires from the server PG_CHANNEL_BINDING, // Controls the client's use of channel binding PG_CONNECT_TIMEOUT, // Maximum time to wait while connecting, in seconds - PG_CLIENT_ENCODING, // Sets the client_encoding configuration parameter for this connection PG_OPTIONS, // Specifies command-line options to send to the server at connection start PG_APPLICATION_NAME, // Specifies a value for the application_name configuration parameter PG_FALLBACK_APPLICATION_NAME, // Specifies a fallback value for the application_name configuration parameter @@ -90,7 +89,7 @@ static const Param_Name_Validation sslcertmode{ (const char* []) { "disable","al static const Param_Name_Validation target_session_attrs{ (const char* []) { "any","read-write","read-only","primary","standby","prefer-standby",nullptr },0 }; static const Param_Name_Validation load_balance_hosts{ (const char* []) { "disable","random",nullptr },-1 }; - +// Excluding client_encoding since it is managed as part of the session variables #define PARAMETER_LIST \ PARAM("host", nullptr) \ PARAM("hostaddr", nullptr) \ @@ -102,7 +101,6 @@ static const Param_Name_Validation load_balance_hosts{ (const char* []) { "disab PARAM("require_auth", &require_auth) \ PARAM("channel_binding", nullptr) \ PARAM("connect_timeout", nullptr) \ - PARAM("client_encoding", nullptr) \ PARAM("options", nullptr) \ PARAM("application_name", nullptr) \ PARAM("fallback_application_name", nullptr) \ @@ -141,200 +139,24 @@ constexpr const char* param_name[] = { #undef PARAM }; -static const std::unordered_map PgSQL_Param_Name_Str = { +// make sure all the keys are in lower case +static const std::unordered_map param_name_map = { #define PARAM(name, val) {name, val}, PARAMETER_LIST #undef PARAM }; -#if 0 -static const std::unordered_map PgSQL_Param_Name_Str = { - { "host", nullptr }, - { "hostaddr", nullptr}, - { "port", nullptr }, - { "database", nullptr}, - { "user", nullptr}, - { "password", nullptr}, - { "passfile", nullptr}, - { "require_auth", &require_auth}, - { "channel_binding", nullptr}, - { "connect_timeout", nullptr}, - { "client_encoding", nullptr}, - { "options", nullptr}, - { "application_name", nullptr}, - { "fallback_application_name", nullptr}, - { "keepalives", nullptr}, - { "keepalives_idle", nullptr}, - { "keepalives_interval", nullptr}, - { "keepalives_count", nullptr}, - { "tcp_user_timeout", nullptr}, - { "replication", &replication}, - { "gsseencmode", &gsseencmode}, - { "sslmode", &sslmode}, - { "requiressl", nullptr}, - { "sslcompression", nullptr}, - { "sslcert", nullptr}, - { "sslkey", nullptr}, - { "sslpassword", nullptr}, - { "sslcertmode", &sslcertmode}, - { "sslrootcert", nullptr}, - { "sslcrl", nullptr}, - { "sslcrldir", nullptr}, - { "sslsni", nullptr}, - { "requirepeer", nullptr}, - { "ssl_min_protocol_version", nullptr}, - { "ssl_max_protocol_version", nullptr}, - { "krbsrvname", nullptr}, - { "gsslib", nullptr}, - { "gssdelegation", nullptr}, - { "service", nullptr}, - { "target_session_attrs", &target_session_attrs}, - { "load_balance_hosts", &load_balance_hosts}, - // Environment Options - { "datestyle", nullptr}, - { "timezone", nullptr}, - { "geqo", nullptr} -}; - -static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SIZE] = { - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &require_auth, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &replication, - &gsseencmode, - &sslmode, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &sslcertmode, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &target_session_attrs, - &load_balance_hosts, - nullptr, - nullptr, - nullptr -}; -#endif - #define PG_EVENT_NONE 0x00 #define PG_EVENT_READ 0x01 #define PG_EVENT_WRITE 0x02 #define PG_EVENT_EXCEPT 0x04 #define PG_EVENT_TIMEOUT 0x08 -#if 0 -class PgSQL_Conn_Param { -private: - bool validate(PgSQL_Param_Name key, const char* val) { - assert(val); - const Param_Name_Validation* validation = PgSQL_Param_Name_Accepted_Values[key]; - - if (validation != nullptr && validation->accepted_values) { - const char** accepted_value = validation->accepted_values; - while (accepted_value != nullptr) { - if (strcmp(val, *accepted_value) == 0) { - return true; - } - } - } else { - return true; - } - - return false; - } - -public: - PgSQL_Conn_Param() {} - ~PgSQL_Conn_Param() { - for (int i = 0; i < PG_PARAM_SIZE; i++) { - if (param_value[i]) - free(param_value[i]); - } - } - - bool set_value(PgSQL_Param_Name key, const char* val) { - if (key == -1) return false; - if (validate(key, val)) { - if (param_value[key]) { - free(param_value[key]); - } - param_value[key] = strdup(val); - param_set.push_back(key); - return true; - } - return false; - } - - bool set_value(const char* key, const char* val) { - return set_value((PgSQL_Param_Name)get_param_name(key), val); - } - - void reset_value(PgSQL_Param_Name key) { - if (param_value[key]) { - free(param_value[key]); - } - param_value[key] = nullptr; - - // this has O(n) complexity. need to fix it.... - param_set.erase(param_set.begin() + static_cast(key)); - } - - const char* get_value(PgSQL_Param_Name key) const { - return param_value[key]; - } - - int get_param_name(const char* name) { - int key = -1; - - for (int i = 0; i < PG_PARAM_SIZE; i++) { - if (strcmp(name, PgSQL_Param_Name_Str[i]) == 0) { - key = i; - break; - } - } - if (key == -1) { - proxy_warning("Unrecognized connection option. Please report this as a bug for future enhancements:%s\n", name); - } - return key; - } - - std::vector param_set; - char* param_value[PG_PARAM_SIZE]{}; -}; -#endif - class PgSQL_Conn_Param { public: PgSQL_Conn_Param() {} ~PgSQL_Conn_Param() {} + bool set_value(const char* key, const char* val) { if (key == nullptr || val == nullptr) return false; connection_parameters[key] = val; @@ -353,6 +175,7 @@ public: } return nullptr; } + bool remove_value(const char* key) { auto it = connection_parameters.find(key); if (it != connection_parameters.end()) { @@ -361,9 +184,13 @@ public: } return false; } + + inline bool is_empty() const { return connection_parameters.empty(); } + + inline void clear() { connection_parameters.clear(); } @@ -524,8 +351,6 @@ public: void update_bytes_recv(uint64_t bytes_recv); void update_bytes_sent(uint64_t bytes_sent); void ProcessQueryAndSetStatusFlags(char* query_digest_text); - void set_charset(const char* charset); - void connect_start_SetCharset(const char* client_encoding, uint32_t hash = 0); inline const PGconn* get_pg_connection() const { return pgsql_conn; } inline int get_pg_server_version() { return PQserverVersion(pgsql_conn); } diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index aa1ec86b5..41c340705 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -283,9 +283,6 @@ private: bool handler_again___status_SETTING_LDAP_USER_VARIABLE(int*); bool handler_again___status_SETTING_SQL_MODE(int*); bool handler_again___status_SETTING_SESSION_TRACK_GTIDS(int*); -#endif // 0 - bool handler_again___status_CHANGING_CHARSET(int* _rc); -#if 0 bool handler_again___status_CHANGING_SCHEMA(int*); #endif // 0 bool handler_again___status_CONNECTING_SERVER(int*); diff --git a/include/PgSQL_Set_Stmt_Parser.h b/include/PgSQL_Set_Stmt_Parser.h index c052184b8..ee3732fa2 100644 --- a/include/PgSQL_Set_Stmt_Parser.h +++ b/include/PgSQL_Set_Stmt_Parser.h @@ -37,7 +37,9 @@ class PgSQL_Set_Stmt_Parser { void generateRE_parse1v2(); // First implemenation of the parser for TRANSACTION ISOLATION LEVEL and TRANSACTION READ/WRITE std::map> parse2(); +#if 0 std::string parse_character_set(); +#endif std::string remove_comments(const std::string& q); }; diff --git a/include/PgSQL_Variables.h b/include/PgSQL_Variables.h index 24c1cbc07..7e2daec8c 100644 --- a/include/PgSQL_Variables.h +++ b/include/PgSQL_Variables.h @@ -14,10 +14,8 @@ extern void print_backtrace(void); typedef bool (*pgsql_verify_var)(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash); typedef bool (*pgsql_update_var)(PgSQL_Session* session, int idx, int &_rc); -bool validate_charset(PgSQL_Session* session, int idx, int &_rc); bool update_server_variable(PgSQL_Session* session, int idx, int &_rc); bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash); -bool verify_set_names(PgSQL_Session* session); class PgSQL_Variables { static pgsql_verify_var verifiers[PGSQL_NAME_LAST_HIGH_WM]; @@ -31,13 +29,13 @@ public: PgSQL_Variables(); ~PgSQL_Variables(); - bool client_set_value(PgSQL_Session* session, int idx, const std::string& value); + bool client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx); bool client_set_hash_and_value(PgSQL_Session* session, int idx, const std::string& value, uint32_t hash); void client_reset_value(PgSQL_Session* session, int idx); const char* client_get_value(PgSQL_Session* session, int idx) const; uint32_t client_get_hash(PgSQL_Session* session, int idx) const; - void server_set_value(PgSQL_Session* session, int idx, const char* value); + void server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx); void server_set_hash_and_value(PgSQL_Session* session, int idx, const char* value, uint32_t hash); void server_reset_value(PgSQL_Session* session, int idx); const char* server_get_value(PgSQL_Session* session, int idx) const; diff --git a/include/PgSQL_Variables_Validator.h b/include/PgSQL_Variables_Validator.h index ae5469847..e6d1183e5 100644 --- a/include/PgSQL_Variables_Validator.h +++ b/include/PgSQL_Variables_Validator.h @@ -9,7 +9,8 @@ typedef enum { VARIABLE_TYPE_BOOL, /**< Boolean variable type. */ VARIABLE_TYPE_STRING, /**< String variable type. */ VARIABLE_TYPE_DATESTYLE, /**< DateStyle variable type. */ - VARIABLE_TYPE_MAINTENANCE_WORK_MEM + VARIABLE_TYPE_MAINTENANCE_WORK_MEM, /**< MaintenanceWorkMem variable type. */ + VARIABLE_TYPE_CLIENT_ENCODING /**< ClientEncoding variable type. */ } pgsql_variable_type_t; @@ -61,5 +62,5 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag extern const pgsql_variable_validator pgsql_variable_validator_bytea_output; extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits; extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem; - +extern const pgsql_variable_validator pgsql_variable_validator_client_encoding; #endif // PGSQL_VARIABLES_VALIDATOR_H diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 32f8e8a53..84a950768 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -1791,9 +1791,10 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag extern const pgsql_variable_validator pgsql_variable_validator_bytea_output; extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits; extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem; +extern const pgsql_variable_validator pgsql_variable_validator_client_encoding; pgsql_variable_st pgsql_tracked_variables[]{ - { PGSQL_CLIENT_ENCODING, SETTING_CHARSET, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_SET_TRANSACTION | PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), nullptr, { "names", nullptr } }, + { PGSQL_CLIENT_ENCODING, SETTING_VARIABLE, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_client_encoding, { "names", nullptr } }, { PGSQL_DATESTYLE, SETTING_VARIABLE, "datestyle", "datestyle", "ISO, MDY" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_datestyle, nullptr }, { PGSQL_INTERVALSTYLE, SETTING_VARIABLE, "intervalstyle", "intervalstyle", "postgres" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_intervalstyle, nullptr }, { PGSQL_STANDARD_CONFORMING_STRINGS, SETTING_VARIABLE, "standard_conforming_strings", "standard_conforming_strings", "on", (PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_bool, nullptr }, diff --git a/lib/PgSQL_Connection.cpp b/lib/PgSQL_Connection.cpp index 2cd818309..68baf89da 100644 --- a/lib/PgSQL_Connection.cpp +++ b/lib/PgSQL_Connection.cpp @@ -14,103 +14,6 @@ using json = nlohmann::json; #include "PgSQL_Query_Processor.h" #include "MySQL_Variables.h" - -#if 0 -// some of the code that follows is from mariadb client library memory allocator -typedef int myf; // Type of MyFlags in my_funcs -#define MYF(v) (myf) (v) -#define MY_KEEP_PREALLOC 1 -#define MY_ALIGN(A,L) (((A) + (L) - 1) & ~((L) - 1)) -#define ALIGN_SIZE(A) MY_ALIGN((A),sizeof(double)) -static void ma_free_root(MA_MEM_ROOT *root, myf MyFLAGS); -static void *ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size); -#define MAX(a,b) (((a) > (b)) ? (a) : (b)) - - -static void * ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size) -{ - size_t get_size; - void * point; - MA_USED_MEM *next= 0; - MA_USED_MEM **prev; - - Size= ALIGN_SIZE(Size); - - if ((*(prev= &mem_root->free))) - { - if ((*prev)->left < Size && - mem_root->first_block_usage++ >= 16 && - (*prev)->left < 4096) - { - next= *prev; - *prev= next->next; - next->next= mem_root->used; - mem_root->used= next; - mem_root->first_block_usage= 0; - } - for (next= *prev; next && next->left < Size; next= next->next) - prev= &next->next; - } - if (! next) - { /* Time to alloc new block */ - get_size= MAX(Size+ALIGN_SIZE(sizeof(MA_USED_MEM)), - (mem_root->block_size & ~1) * ( (mem_root->block_num >> 2) < 4 ? 4 : (mem_root->block_num >> 2) ) ); - - if (!(next = (MA_USED_MEM*) malloc(get_size))) - { - if (mem_root->error_handler) - (*mem_root->error_handler)(); - return((void *) 0); /* purecov: inspected */ - } - mem_root->block_num++; - next->next= *prev; - next->size= get_size; - next->left= get_size-ALIGN_SIZE(sizeof(MA_USED_MEM)); - *prev=next; - } - point= (void *) ((char*) next+ (next->size-next->left)); - if ((next->left-= Size) < mem_root->min_malloc) - { /* Full block */ - *prev=next->next; /* Remove block from list */ - next->next=mem_root->used; - mem_root->used=next; - mem_root->first_block_usage= 0; - } - return(point); -} - - -static void ma_free_root(MA_MEM_ROOT *root, myf MyFlags) -{ - MA_USED_MEM *next,*old; - - if (!root) - return; /* purecov: inspected */ - if (!(MyFlags & MY_KEEP_PREALLOC)) - root->pre_alloc=0; - - for ( next=root->used; next ;) - { - old=next; next= next->next ; - if (old != root->pre_alloc) - free(old); - } - for (next= root->free ; next ; ) - { - old=next; next= next->next ; - if (old != root->pre_alloc) - free(old); - } - root->used=root->free=0; - if (root->pre_alloc) - { - root->free=root->pre_alloc; - root->free->left=root->pre_alloc->size-ALIGN_SIZE(sizeof(MA_USED_MEM)); - root->free->next=0; - } -} -#endif // 0 - extern char * binary_sha1; #include "proxysql_find_charset.h" @@ -841,31 +744,22 @@ void PgSQL_Connection::connect_start() { conninfo << "sslmode='disable' "; // not supporting SSL } - if (myds && myds->sess) { - const char* charset = NULL; - uint32_t charset_hash = 0; - - // Take client character set and use it to connect to backend - charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING); - if (charset_hash != 0) - charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING); - - //if (!charset) { - // charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING]; - //} - + if (myds && myds->sess && myds->sess->client_myds) { // Client Encoding should be always set - assert(charset); - - connect_start_SetCharset(charset, charset_hash); - - escaped_str = escape_string_single_quotes_and_backslashes((char*)charset, false); + const char* client_charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING); + assert(client_charset); + uint32_t client_charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING); + assert(client_charset_hash); + const char* escaped_str = escape_string_backslash_spaces(client_charset); conninfo << "client_encoding='" << escaped_str << "' "; - if (escaped_str != charset) - free(escaped_str); + if (escaped_str != client_charset) + free((char*)escaped_str); + + // charset validation is already done + pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, client_charset, client_charset_hash); std::vector client_options; - client_options.reserve(PGSQL_NAME_LAST_LOW_WM + dynamic_variables_idx.size()); + client_options.reserve(PGSQL_NAME_LAST_LOW_WM + myds->sess->client_myds->myconn->dynamic_variables_idx.size()); // excluding PGSQL_CLIENT_ENCODING for (unsigned int idx = 1; idx < PGSQL_NAME_LAST_LOW_WM; idx++) { @@ -873,9 +767,9 @@ void PgSQL_Connection::connect_start() { client_options.push_back(idx); } - for (std::vector::const_iterator it_c = dynamic_variables_idx.begin(); it_c != dynamic_variables_idx.end(); it_c++) { - assert(pgsql_variables.client_get_hash(myds->sess, *it_c)); - client_options.push_back(*it_c); + for (uint32_t idx : myds->sess->client_myds->myconn->dynamic_variables_idx) { + assert(pgsql_variables.client_get_hash(myds->sess, idx)); + client_options.push_back(idx); } if (client_options.empty() == false || @@ -1882,31 +1776,6 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) { } } -void PgSQL_Connection::set_charset(const char* charset) { - proxy_debug(PROXY_DEBUG_MYSQL_CONNPOOL, 4, "Setting client encoding %s\n", charset); - pgsql_variables.client_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset); -} - -void PgSQL_Connection::connect_start_SetCharset(const char* charset, uint32_t hash) { - assert(charset); - - int charset_encoding = PgSQL_Connection::char_to_encoding(charset); - - if (charset_encoding == -1) { - proxy_error("Cannot find character set [%s]\n", charset); - assert(0); - } - - /* We are connecting to backend setting charset in connection parameters. - * Client already has sent us a character set and client connection variables have been already set. - * Now we store this charset in server connection variables to avoid updating this variables on backend. - */ - if (hash == 0) - pgsql_variables.server_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset); - else - pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, charset, hash); -} - // this function is identical to async_query() , with the only exception that MyRS should never be set int PgSQL_Connection::async_send_simple_command(short event, char* stmt, unsigned long length) { PROXY_TRACE(); diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 08e7d2c70..63aabf8de 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -783,6 +783,7 @@ std::vector> PgSQL_Protocol::parse_options(c // Add key-value pair to the list if (!key.empty()) { + std::transform(key.begin(), key.end(), key.begin(), ::tolower); options_list.emplace_back(std::move(key), std::move(value)); } } @@ -822,24 +823,7 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* return EXECUTION_STATE::FAILED; } - // set charset but first verify - const char* charset = (*myds)->myconn->conn_params.get_value(PG_CLIENT_ENCODING); - // if client does not provide client_encoding, PostgreSQL uses the default client encoding. - // We need to save the default client encoding to send it to the client in case client doesn't provide one. - if (charset == NULL) charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING]; - - assert(charset); - - int charset_encoding = (*myds)->myconn->char_to_encoding(charset); - - if (charset_encoding == -1) { - proxy_error("Cannot find charset [%s]\n", charset); - proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "DS=%p , Session=%p , charset='%s'. Client charset not supported.\n", (*myds), (*myds)->sess, charset); - generate_error_packet(true, false, "Client charset not supported", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); - return EXECUTION_STATE::FAILED; - } - user = (char*)(*myds)->myconn->conn_params.get_value(PG_USER); if (!user || *user == '\0') { @@ -1075,20 +1059,26 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* parameters.reserve((*myds)->myconn->conn_params.connection_parameters.size()); - // new implementation + /* Note: Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior. + PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds, + but the error is encountered when executing a query. + This is behavious is intentional, as newer PostgreSQL versions may introduce parameters that ProxySQL is not yet aware of. + */ + // New implementation for (const auto& [param_name, param_val] : (*myds)->myconn->conn_params.connection_parameters) { std::string param_name_lowercase(param_name); std::transform(param_name_lowercase.cbegin(), param_name_lowercase.cend(), param_name_lowercase.begin(), ::tolower); - auto it = PgSQL_Param_Name_Str.find(param_name_lowercase.c_str()); - if (it != PgSQL_Param_Name_Str.end()) { + // check if parameter is part of connection-level parameters + auto itr = param_name_map.find(param_name_lowercase.c_str()); + if (itr != param_name_map.end()) { if (param_name_lowercase.compare("user") == 0 || param_name_lowercase.compare("password") == 0) { continue; } bool is_validation_success = false; - const Param_Name_Validation* validation = it->second; + const Param_Name_Validation* validation = itr->second; if (validation != nullptr && validation->accepted_values) { const char** accepted_value = validation->accepted_values; @@ -1112,26 +1102,23 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); free(errmsg); ret = EXECUTION_STATE::FAILED; + + // freeing userinfo->username and userinfo->password to prevent invalid password error generation. free(userinfo->username); free(userinfo->password); userinfo->username = strdup(""); userinfo->password = strdup(""); + // goto __exit_process_pkt_handshake_response; } if (param_name_lowercase.compare("database") == 0) { userinfo->set_dbname(param_val.empty() ? user : param_val.c_str()); - } else if (param_name_lowercase.compare("client_encoding") == 0) { - assert(sess); - assert(sess->client_myds); - PgSQL_Connection* myconn = sess->client_myds->myconn; - assert(myconn); - myconn->set_charset(charset); - sess->set_default_session_variable(PGSQL_CLIENT_ENCODING, charset); } else if (param_name_lowercase.compare("options") == 0) { options_list = parse_options(param_val.c_str()); } } else { + // session parameters/variables? parameters.push_back(std::make_pair(param_name_lowercase, param_val)); } } @@ -1140,60 +1127,69 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* userinfo->set_dbname(user); } + // Merge options with parameters. + // Options are processed first, followed by connection parameters. + // If a parameter is specified in both, the connection parameter takes precedence + // and overwrites the previosly set value. if (options_list.empty() == false) { options_list.reserve(parameters.size() + options_list.size()); options_list.insert(options_list.end(), std::make_move_iterator(parameters.begin()), std::make_move_iterator(parameters.end())); parameters = std::move(options_list); } - // assign default datestyle to current datestyle + // assign default datestyle to current datestyle. + // This is needed by PgSQL_DateStyle_Util::parse_datestyle sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); - for (const auto&[key, val] : parameters) { + for (const auto&[param_key, param_val] : parameters) { int idx = PGSQL_NAME_LAST_HIGH_WM; for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { if (i == PGSQL_NAME_LAST_LOW_WM) continue; - if (strcmp(pgsql_tracked_variables[i].set_variable_name, key.c_str()) == 0) { + if (strncmp(param_key.c_str(), pgsql_tracked_variables[i].set_variable_name, + strlen(pgsql_tracked_variables[i].set_variable_name)) == 0) { idx = i; break; } } if (idx != PGSQL_NAME_LAST_HIGH_WM) { - std::string value1 = val; + std::string value_copy = param_val; char* transformed_value = nullptr; if (pgsql_tracked_variables[idx].validator && pgsql_tracked_variables[idx].validator->validate && ( *pgsql_tracked_variables[idx].validator->validate)( - value1.c_str(), &pgsql_tracked_variables[idx].validator->params, sess, &transformed_value) == false + value_copy.c_str(), &pgsql_tracked_variables[idx].validator->params, sess, &transformed_value) == false ) { char* m = NULL; char* errmsg = NULL; - proxy_error("invalid value for parameter \"%s\": \"%s\"\n", pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); + proxy_error("invalid value for parameter \"%s\": \"%s\"\n", pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str()); m = (char*)"invalid value for parameter \"%s\": \"%s\""; - errmsg = (char*)malloc(value1.length() + strlen(pgsql_tracked_variables[idx].set_variable_name) + strlen(m)); - sprintf(errmsg, m, pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); + errmsg = (char*)malloc(value_copy.length() + strlen(pgsql_tracked_variables[idx].set_variable_name) + strlen(m)); + sprintf(errmsg, m, pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str()); generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); free(errmsg); ret = EXECUTION_STATE::FAILED; + + // freeing userinfo->username and userinfo->password to prevent invalid password error generation. free(userinfo->username); free(userinfo->password); userinfo->username = strdup(""); userinfo->password = strdup(""); + // goto __exit_process_pkt_handshake_response; } if (transformed_value) { - value1 = transformed_value; + value_copy = transformed_value; free(transformed_value); } if (idx == PGSQL_DATESTYLE) { // get datestyle from connection parameters - std::string datestyle = val.empty() == false ? val : ""; + std::string datestyle = value_copy.empty() == false ? value_copy : ""; if (datestyle.empty()) { // No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable. @@ -1211,35 +1207,36 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* assert(datestyle.empty() == false); - if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str())) { + if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str(), false)) { // change current datestyle sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle); sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str()); } } else { - pgsql_variables.client_set_value(sess, idx, val.c_str()); - sess->set_default_session_variable((enum pgsql_variable_name)idx, val.c_str()); + pgsql_variables.client_set_value(sess, idx, value_copy.c_str(), false); + sess->set_default_session_variable((enum pgsql_variable_name)idx, value_copy.c_str()); } } else { - proxy_warning("Unrecognized connection parameter. Please report this as a bug for future enhancements:%s:%s\n", key.c_str(), val.c_str()); - const char* escaped_str = escape_string_backslash_spaces(val.c_str()); - sess->untracked_option_parameters = "-c " + key + "=" + escaped_str + " "; - if (escaped_str != val) + // parameter provided is not part of the tracked variables. Will lock on hostgroup on next query. + const char* val_cstr = param_val.c_str(); + proxy_warning("Unrecognized connection parameter. Please report this as a bug for future enhancements:%s:%s\n", param_key.c_str(), val_cstr); + const char* escaped_str = escape_string_backslash_spaces(val_cstr); + sess->untracked_option_parameters = "-c " + param_key + "=" + escaped_str + " "; + if (escaped_str != val_cstr) free((char*)escaped_str); } } - // set mandatory variables if not sent by client + // fill all crtical variables with default values, if not set by client for (int i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) { - //if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[i])) { - // continue; - //} if (pgsql_variables.client_get_hash(sess, i) != 0) continue; const char* val = pgsql_thread___default_variables[i]; - pgsql_variables.client_set_value(sess, i, val); + pgsql_variables.client_set_value(sess, i, val, false); sess->set_default_session_variable((pgsql_variable_name)i, val); } + + sess->client_myds->myconn->reorder_dynamic_variables_idx(); } else { // we always duplicate username and password, or crashes happen diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 712631fe9..42ef7896a 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -81,7 +81,8 @@ static inline char is_normal_char(char c) { } */ -static const std::array pgsql_critical_variables = { +static const std::array pgsql_critical_variables = { + "client_encoding", "datestyle", "intervalstyle", "standard_conforming_strings", @@ -1507,98 +1508,6 @@ bool PgSQL_Session::handler_again___status_SETTING_INIT_CONNECT(int* _rc) { return ret; } -bool PgSQL_Session::handler_again___status_CHANGING_CHARSET(int* _rc) { - assert(mybe->server_myds->myconn); - PgSQL_Data_Stream* myds = mybe->server_myds; - PgSQL_Connection* myconn = myds->myconn; - - /* Validate that server can support client's charset */ - if (!validate_charset(this, PGSQL_CLIENT_ENCODING, *_rc)) { - return false; - } - - myds->DSS = STATE_MARIADB_QUERY; - enum session_status st = status; - if (myds->mypolls == NULL) { - thread->mypolls.add(POLLIN | POLLOUT, mybe->server_myds->fd, mybe->server_myds, thread->curtime); - } - - std::string charset = pgsql_variables.client_get_value(this, PGSQL_CLIENT_ENCODING); - std::string query = "SET CLIENT_ENCODING TO '" + charset + "'"; - - int rc = myconn->async_send_simple_command(myds->revents, (char*)query.c_str(),query.length()); - - if (rc == 0) { - __sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1); - myds->DSS = STATE_MARIADB_GENERIC; - st = previous_status.top(); - previous_status.pop(); - NEXT_IMMEDIATE_NEW(st); - } else { - if (rc == -1) { - // the command failed - const bool error_present = myconn->is_error_present(); - PgHGM->p_update_pgsql_error_counter( - p_pgsql_error_type::pgsql, - myconn->parent->myhgc->hid, - myconn->parent->address, - myconn->parent->port, - (error_present ? 9999 : ER_PROXYSQL_OFFLINE_SRV) // TOFIX: 9999 is a placeholder for the actual error code - ); - if (error_present == false || (error_present == true && myconn->is_connection_in_reusable_state() == false)) { - bool retry_conn = false; - // client error, serious - proxy_error( - "Client trying to set a charset (%s) not supported by backend (%s:%d). Changing it to %s\n", - charset.c_str(), myconn->parent->address, myconn->parent->port, pgsql_tracked_variables[PGSQL_CLIENT_ENCODING].default_value - ); - detected_broken_connection(__FILE__, __LINE__, __func__, "during SET CLIENT_ENCODING", myconn); - if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) { - retry_conn = true; - } - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd = 0; - if (retry_conn) { - myds->DSS = STATE_NOT_INITIALIZED; - NEXT_IMMEDIATE_NEW(CONNECTING_SERVER); - } - *_rc = -1; - return false; - } else { - proxy_warning("Error during SET CLIENT_ENCODING: %s\n", myconn->get_error_code_with_message().c_str()); - // we won't go back to PROCESSING_QUERY - st = previous_status.top(); - previous_status.pop(); - client_myds->myprot.generate_error_packet(true, true, myconn->get_error_message().c_str(), myconn->get_error_code(), false); - myds->destroy_MySQL_Connection_From_Pool(true); - myds->fd = 0; - RequestEnd(myds); //fix bug #682 - } - } else { - if (rc == -2) { - bool retry_conn = false; - proxy_error("Timeout during SET CLIENT_ENCODING on %s , %d\n", myconn->parent->address, myconn->parent->port); - PgHGM->p_update_pgsql_error_counter(p_pgsql_error_type::pgsql, myconn->parent->myhgc->hid, myconn->parent->address, myconn->parent->port, ER_PROXYSQL_CHANGE_USER_TIMEOUT); - if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) { - retry_conn = true; - } - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd = 0; - if (retry_conn) { - myds->DSS = STATE_NOT_INITIALIZED; - NEXT_IMMEDIATE_NEW(CONNECTING_SERVER); - } - *_rc = -1; - return false; - } - else { - // rc==1 , nothing to do for now - } - } - } - return false; -} - bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, const char* var_name, const char* var_value, bool no_quote, bool set_transaction) { bool ret = false; @@ -1649,6 +1558,9 @@ bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, co query = NULL; } if (rc == 0) { + if (strncasecmp(var_name, "client_encoding", sizeof("client_encoding")-1) == 0) { + __sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1); + } myds->revents |= POLLOUT; // we also set again POLLOUT to send a query immediately! myds->DSS = STATE_MARIADB_GENERIC; st = previous_status.top(); @@ -3313,11 +3225,6 @@ handler_again: } if (locked_on_hostgroup == -1 || locked_on_hostgroup_and_all_variables_set == false) { - // verify charset - if (verify_set_names(this)) { - goto handler_again; - } - for (auto i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) { auto client_hash = client_myds->myconn->var_hash[i]; #ifdef DEBUG @@ -3597,9 +3504,6 @@ bool PgSQL_Session::handler_again___multiple_statuses(int* rc) { //case SETTING_INIT_CONNECT: // ret = handler_again___status_SETTING_INIT_CONNECT(rc); break; - case SETTING_CHARSET: - ret = handler_again___status_CHANGING_CHARSET(rc); - break; default: break; } @@ -4349,7 +4253,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", var.c_str(), value1.c_str()); uint32_t var_hash_int = SpookyHash::Hash32(value1.c_str(), value1.length(), 10); if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value1.c_str())) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value1.c_str(), true)) { return false; } if (idx == PGSQL_DATESTYLE) { @@ -4629,7 +4533,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C l_free(pkt->size, pkt->ptr); return true; - }*/ else if (match_regexes && match_regexes[3]->match(dig)) { + } else if (match_regexes && match_regexes[3]->match(dig)) { std::vector> param_status; PgSQL_Set_Stmt_Parser parser(nq); std::string charset = parser.parse_character_set(); @@ -4681,7 +4585,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C RequestEnd(NULL); l_free(pkt->size, pkt->ptr); return true; - } else { + }*/ else { unable_to_parse_set_statement(lock_hostgroup); return false; } @@ -4720,26 +4624,41 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C if (strncasecmp(nq.c_str(), "ALL", 3) == 0) { - for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { + for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) { - if (i == PGSQL_NAME_LAST_LOW_WM) - continue; + const char* name = pgsql_tracked_variables[idx].set_variable_name; + const char* value = get_default_session_variable((enum pgsql_variable_name)idx); - const char* name = pgsql_tracked_variables[i].set_variable_name; - const char* value = get_default_session_variable((enum pgsql_variable_name)i); - proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[i].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[i].idx, value)) { + if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, false)) { + return false; + } + if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { + param_status.push_back(std::make_pair(name, value)); + } + } + } + + for (int idx : client_myds->myconn->dynamic_variables_idx) { + assert(idx < PGSQL_NAME_LAST_HIGH_WM); + const char* name = pgsql_tracked_variables[idx].set_variable_name; + const char* value = get_default_session_variable((enum pgsql_variable_name)idx); + proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); + uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); + if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, false)) { return false; } - if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[i])) { + if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { param_status.push_back(std::make_pair(name, value)); } } } + client_myds->myconn->reorder_dynamic_variables_idx(); + } else if (std::find(pgsql_variables.ignore_vars.begin(), pgsql_variables.ignore_vars.end(), nq) != pgsql_variables.ignore_vars.end()) { // this is a variable we parse but ignore #ifdef DEBUG @@ -4765,7 +4684,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value)) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, true)) { return false; } if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { diff --git a/lib/PgSQL_Set_Stmt_Parser.cpp b/lib/PgSQL_Set_Stmt_Parser.cpp index 42218f8f2..3f92ef12a 100644 --- a/lib/PgSQL_Set_Stmt_Parser.cpp +++ b/lib/PgSQL_Set_Stmt_Parser.cpp @@ -202,6 +202,7 @@ std::map> PgSQL_Set_Stmt_Parser::parse2() { return result; } +#if 0 std::string PgSQL_Set_Stmt_Parser::parse_character_set() { #ifdef DEBUG proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 4, "Parsing query %s\n", query.c_str()); @@ -224,7 +225,7 @@ std::string PgSQL_Set_Stmt_Parser::parse_character_set() { delete opt2; return value3; } - +#endif std::string PgSQL_Set_Stmt_Parser::remove_comments(const std::string& q) { std::string result = ""; bool in_multiline_comment = false; diff --git a/lib/PgSQL_Variables.cpp b/lib/PgSQL_Variables.cpp index 83473651c..af1f5ba6a 100644 --- a/lib/PgSQL_Variables.cpp +++ b/lib/PgSQL_Variables.cpp @@ -10,27 +10,12 @@ #include - -/*static inline char is_digit(char c) { - if(c >= '0' && c <= '9') - return 1; - return 0; -}*/ - pgsql_verify_var PgSQL_Variables::verifiers[PGSQL_NAME_LAST_HIGH_WM]; pgsql_update_var PgSQL_Variables::updaters[PGSQL_NAME_LAST_HIGH_WM]; - PgSQL_Variables::PgSQL_Variables() { // add here all the variables we want proxysql to recognize, but ignore ignore_vars.push_back("application_name"); - //ignore_vars.push_back("interactive_timeout"); - //ignore_vars.push_back("wait_timeout"); - //ignore_vars.push_back("net_read_timeout"); - //ignore_vars.push_back("net_write_timeout"); - //ignore_vars.push_back("net_buffer_length"); - //ignore_vars.push_back("read_buffer_size"); - //ignore_vars.push_back("read_rnd_buffer_size"); // NOTE: This variable has been temporarily ignored. Check issues #3442 and #3441. //ignore_vars.push_back("session_track_schema"); variables_regexp = ""; @@ -55,14 +40,11 @@ PgSQL_Variables::PgSQL_Variables() { } } for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { - if (i == PGSQL_CLIENT_ENCODING) { - PgSQL_Variables::updaters[i] = NULL; - PgSQL_Variables::verifiers[i] = NULL; - } else { - PgSQL_Variables::verifiers[i] = verify_server_variable; - PgSQL_Variables::updaters[i] = update_server_variable; - } - if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) { + + PgSQL_Variables::verifiers[i] = verify_server_variable; + PgSQL_Variables::updaters[i] = update_server_variable; + + if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) { variables_regexp += pgsql_tracked_variables[i].set_variable_name; variables_regexp += "|"; @@ -135,7 +117,7 @@ void PgSQL_Variables::server_set_hash_and_value(PgSQL_Session* session, int idx, session->mybe->server_myds->myconn->variables[idx].value = strdup(value); } -bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value) { +bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx) { if (!session || !session->client_myds || !session->client_myds->myconn) { proxy_warning("Session validation failed\n"); return false; @@ -147,7 +129,7 @@ bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const st } session->client_myds->myconn->variables[idx].value = strdup(value.c_str()); - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx session->client_myds->myconn->reorder_dynamic_variables_idx(); } @@ -168,7 +150,7 @@ uint32_t PgSQL_Variables::client_get_hash(PgSQL_Session* session, int idx) const return session->client_myds->myconn->var_hash[idx]; } -void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value) { +void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx) { assert(session); assert(session->mybe); assert(session->mybe->server_myds); @@ -181,7 +163,7 @@ void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const ch } session->mybe->server_myds->myconn->variables[idx].value = strdup(value); - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx session->mybe->server_myds->myconn->reorder_dynamic_variables_idx(); } @@ -253,98 +235,6 @@ bool PgSQL_Variables::verify_variable(PgSQL_Session* session, int idx) const { return ret; } -bool validate_charset(PgSQL_Session* session, int idx, int &_rc) { - /*if (idx == PGSQL_CLIENT_ENCODING || idx == PGSQL_SET_NAMES) { - PgSQL_Data_Stream *myds = session->mybe->server_myds; - PgSQL_Connection *myconn = myds->myconn; - char msg[128]; - const MARIADB_CHARSET_INFO *ci = NULL; - const char* replace_collation = NULL; - const char* not_supported_collation = NULL; - unsigned int replace_collation_nr = 0; - std::stringstream ss; - int charset = atoi(pgsql_variables.client_get_value(session, idx)); - if (charset >= 255 && myconn->pgsql->server_version[0] != '8') { - switch(pgsql_thread___handle_unknown_charset) { - case HANDLE_UNKNOWN_CHARSET__DISCONNECT_CLIENT: - snprintf(msg,sizeof(msg),"Can't initialize character set %s", pgsql_variables.client_get_value(session, idx)); - proxy_error("Can't initialize character set on %s, %d: Error %d (%s). Closing client connection %s:%d.\n", - myconn->parent->address, myconn->parent->port, 2019, msg, session->client_myds->addr.addr, session->client_myds->addr.port); - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd=0; - _rc=-1; - return false; - case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT_VERBOSE: - ci = proxysql_find_charset_nr(charset); - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot find character set [%s]\n", pgsql_variables.client_get_value(session, idx)); - assert(0); - // LCOV_EXCL_STOP - } - not_supported_collation = ci->name; - - if (idx == SQL_COLLATION_CONNECTION) { - ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]); - } else { - if (pgsql_thread___default_variables[idx]) { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]); - } else { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]); - } - } - - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot find character set [%s]\n", pgsql_thread___default_variables[idx]); - assert(0); - // LCOV_EXCL_STOP - } - replace_collation = ci->name; - replace_collation_nr = ci->nr; - - proxy_warning("Server doesn't support collation (%s) %s. Replacing it with the configured default (%d) %s. Client %s:%d\n", - pgsql_variables.client_get_value(session, idx), not_supported_collation, - replace_collation_nr, replace_collation, session->client_myds->addr.addr, session->client_myds->addr.port); - - ss << replace_collation_nr; - pgsql_variables.client_set_value(session, idx, ss.str()); - _rc=0; - return true; - case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT: - if (idx == SQL_COLLATION_CONNECTION) { - ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]); - } else { - if (pgsql_thread___default_variables[idx]) { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]); - } else { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]); - } - } - - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot filnd charset [%s]\n", pgsql_thread___default_variables[idx]); - assert(0); - // LCOV_EXCL_STOP - } - replace_collation_nr = ci->nr; - - ss << replace_collation_nr; - pgsql_variables.client_set_value(session, idx, ss.str()); - _rc=0; - return true; - default: - proxy_error("Wrong configuration of the handle_unknown_charset\n"); - _rc=-1; - return false; - } - } - }*/ - _rc=0; - return true; -} - bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) { bool no_quote = true; if (IS_PGTRACKED_VAR_OPTION_SET_QUOTE(pgsql_tracked_variables[idx])) no_quote = false; @@ -352,49 +242,12 @@ bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) { const char *set_var_name = pgsql_tracked_variables[idx].set_variable_name; bool ret = false; - if (!validate_charset(session, idx, _rc)) { - return false; - } - const char* value = pgsql_variables.client_get_value(session, idx); - pgsql_variables.server_set_value(session, idx, value); + pgsql_variables.server_set_value(session, idx, value, true); ret = session->handler_again___status_SETTING_GENERIC_VARIABLE(&_rc, set_var_name, value, no_quote, st); return ret; } -bool verify_set_names(PgSQL_Session* session) { - uint32_t client_charset_hash = pgsql_variables.client_get_hash(session, PGSQL_CLIENT_ENCODING); - if (client_charset_hash == 0) - return false; - - if (client_charset_hash != pgsql_variables.server_get_hash(session, PGSQL_CLIENT_ENCODING)) { - switch(session->status) { // this switch can be replaced with a simple previous_status.push(status), but it is here for readibility - case PROCESSING_QUERY: - session->previous_status.push(PROCESSING_QUERY); - break; - /* - case PROCESSING_STMT_PREPARE: - session->previous_status.push(PROCESSING_STMT_PREPARE); - break; - case PROCESSING_STMT_EXECUTE: - session->previous_status.push(PROCESSING_STMT_EXECUTE); - break; - */ - default: - // LCOV_EXCL_START - proxy_error("Wrong status %d\n", session->status); - assert(0); - break; - // LCOV_EXCL_STOP - } - session->set_status(SETTING_CHARSET); - const char* client_charset_value = pgsql_variables.client_get_value(session, PGSQL_CLIENT_ENCODING); - pgsql_variables.server_set_hash_and_value(session, PGSQL_CLIENT_ENCODING, client_charset_value, client_charset_hash); - return true; - } - return false; -} - inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash) { if (client_hash && client_hash != server_hash) { // Edge case for set charset command, because we do not know database character set @@ -427,7 +280,7 @@ inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t cli // LCOV_EXCL_STOP } session->set_status(pgsql_tracked_variables[idx].status); - pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx)); + pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx), true); return true; } return false; diff --git a/lib/PgSQL_Variables_Validator.cpp b/lib/PgSQL_Variables_Validator.cpp index f1b6b7edf..3286a72a8 100644 --- a/lib/PgSQL_Variables_Validator.cpp +++ b/lib/PgSQL_Variables_Validator.cpp @@ -422,6 +422,31 @@ bool pgsql_variable_validate_maintenance_work_mem_v3(const char* value, const pa return true; } +bool pgsql_variable_validate_client_encoding(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) { + (void)params; + if (transformed_value) *transformed_value = nullptr; + + int charset_encoding = PgSQL_Connection::char_to_encoding(value); + + if (charset_encoding == -1) { + return false; + } + + if (transformed_value) { + *transformed_value = strdup(value); + + if (*transformed_value) { // Ensure strdup succeeded + char* tmp_val = *transformed_value; + + while (*tmp_val) { + *tmp_val = toupper((unsigned char)*tmp_val); + tmp_val++; + } + } + } + + return true; +} const pgsql_variable_validator pgsql_variable_validator_bool = { .type = VARIABLE_TYPE_BOOL, @@ -482,3 +507,9 @@ const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem = { .uint_range = {.min = 1024, .max = 2147483647 } // this range is in kB } }; + +const pgsql_variable_validator pgsql_variable_validator_client_encoding = { + .type = VARIABLE_TYPE_CLIENT_ENCODING, + .validate = &pgsql_variable_validate_client_encoding, + .params = {} +};