From fbb9500cb9826515a39e7140be4803df436974d6 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Mon, 31 Mar 2025 10:57:04 +0500 Subject: [PATCH 1/7] Refactored connection parameters * Accept all parameters sent by the client, mirroring PostgreSQL's permissive handling. * Validate and apply parameters only after successful authentication. This avoids wasting resources on invalid connections. --- include/PgSQL_Connection.h | 220 +++++++++++++++++++-------- include/PgSQL_Protocol.h | 2 +- include/PgSQL_Session.h | 6 - lib/PgSQL_HostGroups_Manager.cpp | 2 +- lib/PgSQL_Protocol.cpp | 246 +++++++++++++++++++++---------- lib/PgSQL_Session.cpp | 9 +- 6 files changed, 338 insertions(+), 147 deletions(-) diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index a8048c05b..0d5e17bd5 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -70,74 +70,131 @@ enum PgSQL_Param_Name { PG_TARGET_SESSION_ATTRS, // Determines whether the session must have certain properties to be acceptable PG_LOAD_BALANCE_HOSTS, // Controls the order in which the client tries to connect to the available hosts and addresses // Environment Options - PG_DATESTYLE, // Sets the value of the DateStyle parameter - PG_TIMEZONE, // Sets the value of the TimeZone parameter - PG_GEQO, // Enables or disables the use of the GEQO query optimizer + //PG_DATESTYLE, // Sets the value of the DateStyle parameter + //PG_TIMEZONE, // Sets the value of the TimeZone parameter + //PG_GEQO, // Enables or disables the use of the GEQO query optimizer PG_PARAM_SIZE }; -static const char* PgSQL_Param_Name_Str[] = { - "host", - "hostaddr", - "port", - "database", - "user", - "password", - "passfile", - "require_auth", - "channel_binding", - "connect_timeout", - "client_encoding", - "options", - "application_name", - "fallback_application_name", - "keepalives", - "keepalives_idle", - "keepalives_interval", - "keepalives_count", - "tcp_user_timeout", - "replication", - "gsseencmode", - "sslmode", - "requiressl", - "sslcompression", - "sslcert", - "sslkey", - "sslpassword", - "sslcertmode", - "sslrootcert", - "sslcrl", - "sslcrldir", - "sslsni", - "requirepeer", - "ssl_min_protocol_version", - "ssl_max_protocol_version", - "krbsrvname", - "gsslib", - "gssdelegation", - "service", - "target_session_attrs", - "load_balance_hosts", - // Environment Options - "datestyle", - "timezone", - "geqo" -}; - struct Param_Name_Validation { const char** accepted_values; int default_value_idx; +}; +static const Param_Name_Validation require_auth{ (const char* []) { "password","md5","gss","sspi","scram-sha-256","none",nullptr },-1 }; +static const Param_Name_Validation replication{ (const char* []) { "true","on","yes","1","database","false","off","no","0",nullptr },-1 }; +static const Param_Name_Validation gsseencmode{ (const char* []) { "disable","prefer","require",nullptr },1 }; +static const Param_Name_Validation sslmode{ (const char* []) { "disable","allow","prefer","require","verify-ca","verify-full",nullptr },2 }; +static const Param_Name_Validation sslcertmode{ (const char* []) { "disable","allow","require",nullptr },1 }; +static const Param_Name_Validation target_session_attrs{ (const char* []) { "any","read-write","read-only","primary","standby","prefer-standby",nullptr },0 }; +static const Param_Name_Validation load_balance_hosts{ (const char* []) { "disable","random",nullptr },-1 }; + + +#define PARAMETER_LIST \ + PARAM("host", nullptr) \ + PARAM("hostaddr", nullptr) \ + PARAM("port", nullptr) \ + PARAM("database", nullptr) \ + PARAM("user", nullptr) \ + PARAM("password", nullptr) \ + PARAM("passfile", nullptr) \ + PARAM("require_auth", &require_auth) \ + PARAM("channel_binding", nullptr) \ + PARAM("connect_timeout", nullptr) \ + PARAM("client_encoding", nullptr) \ + PARAM("options", nullptr) \ + PARAM("application_name", nullptr) \ + PARAM("fallback_application_name", nullptr) \ + PARAM("keepalives", nullptr) \ + PARAM("keepalives_idle", nullptr) \ + PARAM("keepalives_interval", nullptr) \ + PARAM("keepalives_count", nullptr) \ + PARAM("tcp_user_timeout", nullptr) \ + PARAM("replication", &replication) \ + PARAM("gsseencmode", &gsseencmode) \ + PARAM("sslmode", &sslmode) \ + PARAM("requiressl", nullptr) \ + PARAM("sslcompression", nullptr) \ + PARAM("sslcert", nullptr) \ + PARAM("sslkey", nullptr) \ + PARAM("sslpassword", nullptr) \ + PARAM("sslcertmode", &sslcertmode) \ + PARAM("sslrootcert", nullptr) \ + PARAM("sslcrl", nullptr) \ + PARAM("sslcrldir", nullptr) \ + PARAM("sslsni", nullptr) \ + PARAM("requirepeer", nullptr) \ + PARAM("ssl_min_protocol_version", nullptr) \ + PARAM("ssl_max_protocol_version", nullptr) \ + PARAM("krbsrvname", nullptr) \ + PARAM("gsslib", nullptr) \ + PARAM("gssdelegation", nullptr) \ + PARAM("service", nullptr) \ + PARAM("target_session_attrs", &target_session_attrs) \ + PARAM("load_balance_hosts", &load_balance_hosts) + +// Generate parameter array +constexpr const char* param_name[] = { +#define PARAM(name, val) name, + PARAMETER_LIST +#undef PARAM }; -static const Param_Name_Validation require_auth {(const char*[]){"password","md5","gss","sspi","scram-sha-256","none",nullptr},-1}; -static const Param_Name_Validation replication {(const char*[]){"true","on","yes","1","database","false","off","no","0",nullptr},-1}; -static const Param_Name_Validation gsseencmode {(const char*[]){"disable","prefer","require",nullptr},1}; -static const Param_Name_Validation sslmode {(const char*[]){"disable","allow","prefer","require","verify-ca","verify-full",nullptr},2}; -static const Param_Name_Validation sslcertmode {(const char*[]){"disable","allow","require",nullptr},1}; -static const Param_Name_Validation target_session_attrs {(const char*[]){"any","read-write","read-only","primary","standby","prefer-standby",nullptr},0 }; -static const Param_Name_Validation load_balance_hosts {(const char*[]){"disable","random",nullptr},-1}; +static const std::unordered_map PgSQL_Param_Name_Str = { +#define PARAM(name, val) {name, val}, + PARAMETER_LIST +#undef PARAM +}; + +#if 0 +static const std::unordered_map PgSQL_Param_Name_Str = { + { "host", nullptr }, + { "hostaddr", nullptr}, + { "port", nullptr }, + { "database", nullptr}, + { "user", nullptr}, + { "password", nullptr}, + { "passfile", nullptr}, + { "require_auth", &require_auth}, + { "channel_binding", nullptr}, + { "connect_timeout", nullptr}, + { "client_encoding", nullptr}, + { "options", nullptr}, + { "application_name", nullptr}, + { "fallback_application_name", nullptr}, + { "keepalives", nullptr}, + { "keepalives_idle", nullptr}, + { "keepalives_interval", nullptr}, + { "keepalives_count", nullptr}, + { "tcp_user_timeout", nullptr}, + { "replication", &replication}, + { "gsseencmode", &gsseencmode}, + { "sslmode", &sslmode}, + { "requiressl", nullptr}, + { "sslcompression", nullptr}, + { "sslcert", nullptr}, + { "sslkey", nullptr}, + { "sslpassword", nullptr}, + { "sslcertmode", &sslcertmode}, + { "sslrootcert", nullptr}, + { "sslcrl", nullptr}, + { "sslcrldir", nullptr}, + { "sslsni", nullptr}, + { "requirepeer", nullptr}, + { "ssl_min_protocol_version", nullptr}, + { "ssl_max_protocol_version", nullptr}, + { "krbsrvname", nullptr}, + { "gsslib", nullptr}, + { "gssdelegation", nullptr}, + { "service", nullptr}, + { "target_session_attrs", &target_session_attrs}, + { "load_balance_hosts", &load_balance_hosts}, + // Environment Options + { "datestyle", nullptr}, + { "timezone", nullptr}, + { "geqo", nullptr} +}; static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SIZE] = { nullptr, @@ -185,6 +242,7 @@ static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SI nullptr, nullptr }; +#endif #define PG_EVENT_NONE 0x00 #define PG_EVENT_READ 0x01 @@ -192,6 +250,7 @@ static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SI #define PG_EVENT_EXCEPT 0x04 #define PG_EVENT_TIMEOUT 0x08 +#if 0 class PgSQL_Conn_Param { private: bool validate(PgSQL_Param_Name key, const char* val) { @@ -270,6 +329,50 @@ public: std::vector param_set; char* param_value[PG_PARAM_SIZE]{}; }; +#endif + +class PgSQL_Conn_Param { +public: + PgSQL_Conn_Param() {} + ~PgSQL_Conn_Param() {} + bool set_value(const char* key, const char* val) { + if (key == nullptr || val == nullptr) return false; + connection_parameters[key] = val; + return true; + } + + inline + const char* get_value(PgSQL_Param_Name key) const { + return get_value(param_name[key]); + } + + const char* get_value(const char* key) const { + auto it = connection_parameters.find(key); + if (it != connection_parameters.end()) { + return it->second.c_str(); + } + return nullptr; + } + bool remove_value(const char* key) { + auto it = connection_parameters.find(key); + if (it != connection_parameters.end()) { + connection_parameters.erase(it); + return true; + } + return false; + } + bool is_empty() const { + return connection_parameters.empty(); + } + void clear() { + connection_parameters.clear(); + } + +private: + std::map connection_parameters; + friend class PgSQL_Session; + friend class PgSQL_Protocol; +}; class PgSQL_Variable { public: @@ -418,7 +521,6 @@ public: PGresult* get_result(); void next_multi_statement_result(PGresult* result); bool set_single_row_mode(); - void optimize() {} void update_bytes_recv(uint64_t bytes_recv); void update_bytes_sent(uint64_t bytes_sent); void ProcessQueryAndSetStatusFlags(char* query_digest_text); diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index db903c205..a1d2a23c1 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -986,7 +986,7 @@ private: * @note This function iterates through the key-value pairs in the startup * packet and stores them in the connection parameters object. */ - void load_conn_parameters(pgsql_hdr* pkt, bool startup); + bool load_conn_parameters(pgsql_hdr* pkt); /** * @brief Handles the client's first message in a SCRAM-SHA-256 diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index cfc6d9335..aa1ec86b5 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -414,12 +414,6 @@ public: int current_hostgroup; int default_hostgroup; int previous_hostgroup; - /** - * @brief Charset directly specified by the client. Supplied and updated via 'HandshakeResponse' - * and 'COM_CHANGE_USER' packets. - * @details Used when session needs to be restored via 'COM_RESET_CONNECTION'. - */ - int default_charset; int locked_on_hostgroup; int next_query_flagIN; int mirror_hostgroup; diff --git a/lib/PgSQL_HostGroups_Manager.cpp b/lib/PgSQL_HostGroups_Manager.cpp index b3455ccc6..bd892e3e6 100644 --- a/lib/PgSQL_HostGroups_Manager.cpp +++ b/lib/PgSQL_HostGroups_Manager.cpp @@ -1789,7 +1789,7 @@ void PgSQL_HostGroups_Manager::push_MyConn_to_pool(PgSQL_Connection *c, bool _lo mysrvc->ConnectionsUsed->add(c); // Add the connection back to the list of used connections destroy_MyConn_from_pool(c, false); // Destroy the connection from the pool } else {*/ - c->optimize(); + //c->optimize(); mysrvc->ConnectionsFree->add(c); //} } else { diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 65084631f..08e7d2c70 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -630,26 +630,34 @@ unsigned int get_string(const char* data, unsigned int len, const char** dst_p) return (nul + 1 - data); } -void PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt, bool startup) +bool PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt) { - const char* key, * val; - unsigned int read_pos = 0; - - while (1) { + int32_t offset = 0; - int pos = get_string(((const char*)pkt->data.ptr) + read_pos, pkt->data.size - read_pos, &key); - if (pos == 0) return; + while (offset < pkt->data.size) { + char* nameptr = (char*)pkt->data.ptr + offset; + int32_t valoffset; + char* valptr; - read_pos += pos; + if (*nameptr == '\0') + break; /* found packet terminator */ + valoffset = offset + strlen(nameptr) + 1; + if (valoffset >= pkt->data.size) + break; /* missing value, will complain below */ + valptr = (char*)pkt->data.ptr + valoffset; - pos = get_string(((const char*)pkt->data.ptr) + read_pos, pkt->data.size - read_pos, &val); - if (pos == 0) return; + (*myds)->myconn->conn_params.set_value(nameptr, valptr); - read_pos += pos; + offset = valoffset + strlen(valptr) + 1; + } - //slog_debug(server, "S: param: %s = %s", key, val); - (*myds)->myconn->conn_params.set_value(key, val); + if (offset != pkt->data.size - 1) { + proxy_error("Malformed startup packet was received from client %s:%d\n", (*myds)->addr.addr, (*myds)->addr.port); + return false; } + + return true; + } bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len, bool& ssl_request) { @@ -677,13 +685,19 @@ bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len return false; } - load_conn_parameters(&hdr, true); + if (!load_conn_parameters(&hdr)) { + proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. malformed startup packet.\n", (*myds)->sess, (*myds)); + generate_error_packet(true, false, "invalid startup packet layout: expected terminator as last byte", + PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); + return false; + } const unsigned char* user = (unsigned char*)(*myds)->myconn->conn_params.get_value(PG_USER); if (!user || *user == '\0') { - proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. no username supplied.\n", (*myds), (*myds)->sess); - generate_error_packet(true, false, "no username supplied", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); + proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. no username supplied.\n", (*myds)->sess, (*myds)); + generate_error_packet(true, false, "no PostgreSQL user name specified in startup packet", + PGSQL_ERROR_CODES::ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION, true); return false; } @@ -825,8 +839,6 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* generate_error_packet(true, false, "Client charset not supported", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); return EXECUTION_STATE::FAILED; } - - (*myds)->sess->default_charset = charset_encoding; user = (char*)(*myds)->myconn->conn_params.get_value(PG_USER); @@ -1058,92 +1070,176 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* userinfo->username = strdup((const char*)user); userinfo->password = strdup((const char*)password); - const char* db = (*myds)->myconn->conn_params.get_value(PG_DATABASE); - userinfo->set_dbname(db ? db : userinfo->username); + std::vector> parameters; + std::vector> options_list; - assert(sess); - assert(sess->client_myds); + parameters.reserve((*myds)->myconn->conn_params.connection_parameters.size()); - PgSQL_Connection* myconn = sess->client_myds->myconn; - assert(myconn); - myconn->set_charset(charset); - sess->set_default_session_variable(PGSQL_CLIENT_ENCODING, charset); + // new implementation + for (const auto& [param_name, param_val] : (*myds)->myconn->conn_params.connection_parameters) { + std::string param_name_lowercase(param_name); + std::transform(param_name_lowercase.cbegin(), param_name_lowercase.cend(), param_name_lowercase.begin(), ::tolower); - // get datestyle from connection parameters - const char* datestyle_tmp = (*myds)->myconn->conn_params.get_value(PG_DATESTYLE); - std::string datestyle = datestyle_tmp ? datestyle_tmp : ""; + auto it = PgSQL_Param_Name_Str.find(param_name_lowercase.c_str()); + if (it != PgSQL_Param_Name_Str.end()) { - if (datestyle.empty()) { - // No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable. - datestyle = pgsql_thread___default_variables[PGSQL_DATESTYLE]; - } else { - PgSQL_DateStyle_t datestyle_parsed = PgSQL_DateStyle_Util::parse_datestyle(datestyle); - - // If DateStyle provided in the connection parameters is incomplete, the missing parts will be taken from the default DateStyle. - if (datestyle_parsed.format == DATESTYLE_FORMAT_NONE || datestyle_parsed.order == DATESTYLE_ORDER_NONE) { - PgSQL_DateStyle_t datestyle_default = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); - datestyle = PgSQL_DateStyle_Util::datestyle_to_string(datestyle_parsed, datestyle_default); - } - } + if (param_name_lowercase.compare("user") == 0 || param_name_lowercase.compare("password") == 0) { + continue; + } - assert(datestyle.empty() == false); + bool is_validation_success = false; + const Param_Name_Validation* validation = it->second; - if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str())) { - // change current datestyle - sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle); - sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str()); - } + if (validation != nullptr && validation->accepted_values) { + const char** accepted_value = validation->accepted_values; + while (accepted_value != nullptr) { + if (strcmp(param_val.c_str(), *accepted_value) == 0) { + is_validation_success = true; + break; + } + } + } else { + is_validation_success = true; + } - // get timezone from connection parameters - const char* timezone = (*myds)->myconn->conn_params.get_value(PG_TIMEZONE); - if (timezone == NULL) - timezone = pgsql_thread___default_variables[PGSQL_TIMEZONE]; + if (is_validation_success == false) { + char* m = NULL; + char* errmsg = NULL; + proxy_error("invalid value for parameter \"%s\": \"%s\"\n", param_name.c_str(), param_val.c_str()); + m = (char*)"invalid value for parameter \"%s\": \"%s\""; + errmsg = (char*)malloc(param_val.length() + param_name.length() + strlen(m)); + sprintf(errmsg, m, param_name.c_str(), param_val.c_str()); + generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); + free(errmsg); + ret = EXECUTION_STATE::FAILED; + free(userinfo->username); + free(userinfo->password); + userinfo->username = strdup(""); + userinfo->password = strdup(""); + goto __exit_process_pkt_handshake_response; + } - pgsql_variables.client_set_value(sess, PGSQL_TIMEZONE, timezone); - sess->set_default_session_variable(PGSQL_TIMEZONE, timezone); + if (param_name_lowercase.compare("database") == 0) { + userinfo->set_dbname(param_val.empty() ? user : param_val.c_str()); + } else if (param_name_lowercase.compare("client_encoding") == 0) { + assert(sess); + assert(sess->client_myds); + PgSQL_Connection* myconn = sess->client_myds->myconn; + assert(myconn); + myconn->set_charset(charset); + sess->set_default_session_variable(PGSQL_CLIENT_ENCODING, charset); + } else if (param_name_lowercase.compare("options") == 0) { + options_list = parse_options(param_val.c_str()); + } + } else { + parameters.push_back(std::make_pair(param_name_lowercase, param_val)); + } + } - const char* intervalstyle = pgsql_thread___default_variables[PGSQL_INTERVALSTYLE]; - if (intervalstyle) { - pgsql_variables.client_set_value(sess, PGSQL_INTERVALSTYLE, intervalstyle); - sess->set_default_session_variable(PGSQL_INTERVALSTYLE, intervalstyle); + if (userinfo->dbname == nullptr) { + userinfo->set_dbname(user); } - // get standard_conforming_strings from connection parameters - const char* standard_conforming_strings = pgsql_thread___default_variables[PGSQL_STANDARD_CONFORMING_STRINGS]; - if (standard_conforming_strings) { - pgsql_variables.client_set_value(sess, PGSQL_STANDARD_CONFORMING_STRINGS, standard_conforming_strings); - sess->set_default_session_variable(PGSQL_STANDARD_CONFORMING_STRINGS, standard_conforming_strings); + if (options_list.empty() == false) { + options_list.reserve(parameters.size() + options_list.size()); + options_list.insert(options_list.end(), std::make_move_iterator(parameters.begin()), std::make_move_iterator(parameters.end())); + parameters = std::move(options_list); } - const char* options = (*myds)->myconn->conn_params.get_value(PG_OPTIONS); + // assign default datestyle to current datestyle + sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); - auto options_list = parse_options(options); + for (const auto&[key, val] : parameters) { - for (auto& option : options_list) { int idx = PGSQL_NAME_LAST_HIGH_WM; for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { - if (i == PGSQL_NAME_LAST_LOW_WM) + if (i == PGSQL_NAME_LAST_LOW_WM) continue; - if (variable_name_exists(pgsql_tracked_variables[i], option.first.c_str()) == true) { + if (strcmp(pgsql_tracked_variables[i].set_variable_name, key.c_str()) == 0) { idx = i; break; - } + } } if (idx != PGSQL_NAME_LAST_HIGH_WM) { - pgsql_variables.client_set_value(sess, idx, option.second.c_str()); - sess->set_default_session_variable((enum pgsql_variable_name)idx, option.second.c_str()); + std::string value1 = val; + + char* transformed_value = nullptr; + if (pgsql_tracked_variables[idx].validator && pgsql_tracked_variables[idx].validator->validate && + ( + *pgsql_tracked_variables[idx].validator->validate)( + value1.c_str(), &pgsql_tracked_variables[idx].validator->params, sess, &transformed_value) == false + ) { + char* m = NULL; + char* errmsg = NULL; + proxy_error("invalid value for parameter \"%s\": \"%s\"\n", pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); + m = (char*)"invalid value for parameter \"%s\": \"%s\""; + errmsg = (char*)malloc(value1.length() + strlen(pgsql_tracked_variables[idx].set_variable_name) + strlen(m)); + sprintf(errmsg, m, pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); + generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); + free(errmsg); + ret = EXECUTION_STATE::FAILED; + free(userinfo->username); + free(userinfo->password); + userinfo->username = strdup(""); + userinfo->password = strdup(""); + goto __exit_process_pkt_handshake_response; + } + + if (transformed_value) { + value1 = transformed_value; + free(transformed_value); + } + + if (idx == PGSQL_DATESTYLE) { + // get datestyle from connection parameters + std::string datestyle = val.empty() == false ? val : ""; + + if (datestyle.empty()) { + // No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable. + datestyle = pgsql_thread___default_variables[PGSQL_DATESTYLE]; + } + else { + PgSQL_DateStyle_t datestyle_parsed = PgSQL_DateStyle_Util::parse_datestyle(datestyle); + + // If DateStyle provided in the connection parameters is incomplete, the missing parts will be taken from the default DateStyle. + if (datestyle_parsed.format == DATESTYLE_FORMAT_NONE || datestyle_parsed.order == DATESTYLE_ORDER_NONE) { + PgSQL_DateStyle_t datestyle_default = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); + datestyle = PgSQL_DateStyle_Util::datestyle_to_string(datestyle_parsed, datestyle_default); + } + } + + assert(datestyle.empty() == false); + + if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str())) { + // change current datestyle + sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle); + sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str()); + } + } else { + pgsql_variables.client_set_value(sess, idx, val.c_str()); + sess->set_default_session_variable((enum pgsql_variable_name)idx, val.c_str()); + } } else { - const char* val = option.second.c_str(); - const char* escaped_str = escape_string_backslash_spaces(val); - sess->untracked_option_parameters = "-c " + option.first + "=" + escaped_str + " "; + proxy_warning("Unrecognized connection parameter. Please report this as a bug for future enhancements:%s:%s\n", key.c_str(), val.c_str()); + const char* escaped_str = escape_string_backslash_spaces(val.c_str()); + sess->untracked_option_parameters = "-c " + key + "=" + escaped_str + " "; if (escaped_str != val) free((char*)escaped_str); } } - //if (charset) - // (*myds)->sess->default_charset = charset; + // set mandatory variables if not sent by client + for (int i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) { + //if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[i])) { + // continue; + //} + if (pgsql_variables.client_get_hash(sess, i) != 0) + continue; + const char* val = pgsql_thread___default_variables[i]; + pgsql_variables.client_set_value(sess, i, val); + sess->set_default_session_variable((pgsql_variable_name)i, val); + } } else { // we always duplicate username and password, or crashes happen diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 217898d4e..712631fe9 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -825,10 +825,9 @@ void PgSQL_Session::generate_proxysql_internal_session_json(json& j) { //j["conn"]["ps"]["client_stmt_to_global_ids"] = client_myds->myconn->local_stmts->client_stmt_to_global_ids; const PgSQL_Conn_Param& conn_params = client_myds->myconn->conn_params; - for (size_t i = 0; i < conn_params.param_set.size(); i++) { - if (conn_params.param_value[conn_params.param_set[i]] != NULL) { - j["client"]["conn"]["connection_options"][PgSQL_Param_Name_Str[conn_params.param_set[i]]] = conn_params.param_value[conn_params.param_set[i]]; - } + + for (const auto& [key, val] : conn_params.connection_parameters) { + j["client"]["conn"]["connection_options"][key.c_str()] = val.c_str(); } } } @@ -3928,7 +3927,7 @@ void PgSQL_Session::handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE( else { client_addr = strdup((char*)""); } - if (client_myds->myconn->userinfo->username) { + if (client_myds->myconn->userinfo->username && client_myds->myconn->userinfo->username[0] != '\0') { char* _s = (char*)malloc(strlen(client_myds->myconn->userinfo->username) + 100 + strlen(client_addr)); uint8_t _pid = 2; if (client_myds->switching_auth_stage) _pid += 2; From db442ba8cfc00977fe3a6f2eb195b366baef56c9 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 1 Apr 2025 02:25:06 +0500 Subject: [PATCH 2/7] Treat client_encoding as normal server parameter/variable. Optimized RESET ALL command handling. Few optimizations. Code cleanup --- include/PgSQL_Connection.h | 193 ++-------------------------- include/PgSQL_Session.h | 3 - include/PgSQL_Set_Stmt_Parser.h | 2 + include/PgSQL_Variables.h | 6 +- include/PgSQL_Variables_Validator.h | 5 +- include/proxysql_structs.h | 3 +- lib/PgSQL_Connection.cpp | 161 +++-------------------- lib/PgSQL_Protocol.cpp | 97 +++++++------- lib/PgSQL_Session.cpp | 147 +++++---------------- lib/PgSQL_Set_Stmt_Parser.cpp | 3 +- lib/PgSQL_Variables.cpp | 169 ++---------------------- lib/PgSQL_Variables_Validator.cpp | 31 +++++ 12 files changed, 157 insertions(+), 663 deletions(-) diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index 0d5e17bd5..e2271c4b3 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -38,7 +38,6 @@ enum PgSQL_Param_Name { PG_REQUIRE_AUTH, // Specifies the authentication method that the client requires from the server PG_CHANNEL_BINDING, // Controls the client's use of channel binding PG_CONNECT_TIMEOUT, // Maximum time to wait while connecting, in seconds - PG_CLIENT_ENCODING, // Sets the client_encoding configuration parameter for this connection PG_OPTIONS, // Specifies command-line options to send to the server at connection start PG_APPLICATION_NAME, // Specifies a value for the application_name configuration parameter PG_FALLBACK_APPLICATION_NAME, // Specifies a fallback value for the application_name configuration parameter @@ -90,7 +89,7 @@ static const Param_Name_Validation sslcertmode{ (const char* []) { "disable","al static const Param_Name_Validation target_session_attrs{ (const char* []) { "any","read-write","read-only","primary","standby","prefer-standby",nullptr },0 }; static const Param_Name_Validation load_balance_hosts{ (const char* []) { "disable","random",nullptr },-1 }; - +// Excluding client_encoding since it is managed as part of the session variables #define PARAMETER_LIST \ PARAM("host", nullptr) \ PARAM("hostaddr", nullptr) \ @@ -102,7 +101,6 @@ static const Param_Name_Validation load_balance_hosts{ (const char* []) { "disab PARAM("require_auth", &require_auth) \ PARAM("channel_binding", nullptr) \ PARAM("connect_timeout", nullptr) \ - PARAM("client_encoding", nullptr) \ PARAM("options", nullptr) \ PARAM("application_name", nullptr) \ PARAM("fallback_application_name", nullptr) \ @@ -141,200 +139,24 @@ constexpr const char* param_name[] = { #undef PARAM }; -static const std::unordered_map PgSQL_Param_Name_Str = { +// make sure all the keys are in lower case +static const std::unordered_map param_name_map = { #define PARAM(name, val) {name, val}, PARAMETER_LIST #undef PARAM }; -#if 0 -static const std::unordered_map PgSQL_Param_Name_Str = { - { "host", nullptr }, - { "hostaddr", nullptr}, - { "port", nullptr }, - { "database", nullptr}, - { "user", nullptr}, - { "password", nullptr}, - { "passfile", nullptr}, - { "require_auth", &require_auth}, - { "channel_binding", nullptr}, - { "connect_timeout", nullptr}, - { "client_encoding", nullptr}, - { "options", nullptr}, - { "application_name", nullptr}, - { "fallback_application_name", nullptr}, - { "keepalives", nullptr}, - { "keepalives_idle", nullptr}, - { "keepalives_interval", nullptr}, - { "keepalives_count", nullptr}, - { "tcp_user_timeout", nullptr}, - { "replication", &replication}, - { "gsseencmode", &gsseencmode}, - { "sslmode", &sslmode}, - { "requiressl", nullptr}, - { "sslcompression", nullptr}, - { "sslcert", nullptr}, - { "sslkey", nullptr}, - { "sslpassword", nullptr}, - { "sslcertmode", &sslcertmode}, - { "sslrootcert", nullptr}, - { "sslcrl", nullptr}, - { "sslcrldir", nullptr}, - { "sslsni", nullptr}, - { "requirepeer", nullptr}, - { "ssl_min_protocol_version", nullptr}, - { "ssl_max_protocol_version", nullptr}, - { "krbsrvname", nullptr}, - { "gsslib", nullptr}, - { "gssdelegation", nullptr}, - { "service", nullptr}, - { "target_session_attrs", &target_session_attrs}, - { "load_balance_hosts", &load_balance_hosts}, - // Environment Options - { "datestyle", nullptr}, - { "timezone", nullptr}, - { "geqo", nullptr} -}; - -static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SIZE] = { - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &require_auth, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &replication, - &gsseencmode, - &sslmode, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &sslcertmode, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &target_session_attrs, - &load_balance_hosts, - nullptr, - nullptr, - nullptr -}; -#endif - #define PG_EVENT_NONE 0x00 #define PG_EVENT_READ 0x01 #define PG_EVENT_WRITE 0x02 #define PG_EVENT_EXCEPT 0x04 #define PG_EVENT_TIMEOUT 0x08 -#if 0 -class PgSQL_Conn_Param { -private: - bool validate(PgSQL_Param_Name key, const char* val) { - assert(val); - const Param_Name_Validation* validation = PgSQL_Param_Name_Accepted_Values[key]; - - if (validation != nullptr && validation->accepted_values) { - const char** accepted_value = validation->accepted_values; - while (accepted_value != nullptr) { - if (strcmp(val, *accepted_value) == 0) { - return true; - } - } - } else { - return true; - } - - return false; - } - -public: - PgSQL_Conn_Param() {} - ~PgSQL_Conn_Param() { - for (int i = 0; i < PG_PARAM_SIZE; i++) { - if (param_value[i]) - free(param_value[i]); - } - } - - bool set_value(PgSQL_Param_Name key, const char* val) { - if (key == -1) return false; - if (validate(key, val)) { - if (param_value[key]) { - free(param_value[key]); - } - param_value[key] = strdup(val); - param_set.push_back(key); - return true; - } - return false; - } - - bool set_value(const char* key, const char* val) { - return set_value((PgSQL_Param_Name)get_param_name(key), val); - } - - void reset_value(PgSQL_Param_Name key) { - if (param_value[key]) { - free(param_value[key]); - } - param_value[key] = nullptr; - - // this has O(n) complexity. need to fix it.... - param_set.erase(param_set.begin() + static_cast(key)); - } - - const char* get_value(PgSQL_Param_Name key) const { - return param_value[key]; - } - - int get_param_name(const char* name) { - int key = -1; - - for (int i = 0; i < PG_PARAM_SIZE; i++) { - if (strcmp(name, PgSQL_Param_Name_Str[i]) == 0) { - key = i; - break; - } - } - if (key == -1) { - proxy_warning("Unrecognized connection option. Please report this as a bug for future enhancements:%s\n", name); - } - return key; - } - - std::vector param_set; - char* param_value[PG_PARAM_SIZE]{}; -}; -#endif - class PgSQL_Conn_Param { public: PgSQL_Conn_Param() {} ~PgSQL_Conn_Param() {} + bool set_value(const char* key, const char* val) { if (key == nullptr || val == nullptr) return false; connection_parameters[key] = val; @@ -353,6 +175,7 @@ public: } return nullptr; } + bool remove_value(const char* key) { auto it = connection_parameters.find(key); if (it != connection_parameters.end()) { @@ -361,9 +184,13 @@ public: } return false; } + + inline bool is_empty() const { return connection_parameters.empty(); } + + inline void clear() { connection_parameters.clear(); } @@ -524,8 +351,6 @@ public: void update_bytes_recv(uint64_t bytes_recv); void update_bytes_sent(uint64_t bytes_sent); void ProcessQueryAndSetStatusFlags(char* query_digest_text); - void set_charset(const char* charset); - void connect_start_SetCharset(const char* client_encoding, uint32_t hash = 0); inline const PGconn* get_pg_connection() const { return pgsql_conn; } inline int get_pg_server_version() { return PQserverVersion(pgsql_conn); } diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index aa1ec86b5..41c340705 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -283,9 +283,6 @@ private: bool handler_again___status_SETTING_LDAP_USER_VARIABLE(int*); bool handler_again___status_SETTING_SQL_MODE(int*); bool handler_again___status_SETTING_SESSION_TRACK_GTIDS(int*); -#endif // 0 - bool handler_again___status_CHANGING_CHARSET(int* _rc); -#if 0 bool handler_again___status_CHANGING_SCHEMA(int*); #endif // 0 bool handler_again___status_CONNECTING_SERVER(int*); diff --git a/include/PgSQL_Set_Stmt_Parser.h b/include/PgSQL_Set_Stmt_Parser.h index c052184b8..ee3732fa2 100644 --- a/include/PgSQL_Set_Stmt_Parser.h +++ b/include/PgSQL_Set_Stmt_Parser.h @@ -37,7 +37,9 @@ class PgSQL_Set_Stmt_Parser { void generateRE_parse1v2(); // First implemenation of the parser for TRANSACTION ISOLATION LEVEL and TRANSACTION READ/WRITE std::map> parse2(); +#if 0 std::string parse_character_set(); +#endif std::string remove_comments(const std::string& q); }; diff --git a/include/PgSQL_Variables.h b/include/PgSQL_Variables.h index 24c1cbc07..7e2daec8c 100644 --- a/include/PgSQL_Variables.h +++ b/include/PgSQL_Variables.h @@ -14,10 +14,8 @@ extern void print_backtrace(void); typedef bool (*pgsql_verify_var)(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash); typedef bool (*pgsql_update_var)(PgSQL_Session* session, int idx, int &_rc); -bool validate_charset(PgSQL_Session* session, int idx, int &_rc); bool update_server_variable(PgSQL_Session* session, int idx, int &_rc); bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash); -bool verify_set_names(PgSQL_Session* session); class PgSQL_Variables { static pgsql_verify_var verifiers[PGSQL_NAME_LAST_HIGH_WM]; @@ -31,13 +29,13 @@ public: PgSQL_Variables(); ~PgSQL_Variables(); - bool client_set_value(PgSQL_Session* session, int idx, const std::string& value); + bool client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx); bool client_set_hash_and_value(PgSQL_Session* session, int idx, const std::string& value, uint32_t hash); void client_reset_value(PgSQL_Session* session, int idx); const char* client_get_value(PgSQL_Session* session, int idx) const; uint32_t client_get_hash(PgSQL_Session* session, int idx) const; - void server_set_value(PgSQL_Session* session, int idx, const char* value); + void server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx); void server_set_hash_and_value(PgSQL_Session* session, int idx, const char* value, uint32_t hash); void server_reset_value(PgSQL_Session* session, int idx); const char* server_get_value(PgSQL_Session* session, int idx) const; diff --git a/include/PgSQL_Variables_Validator.h b/include/PgSQL_Variables_Validator.h index ae5469847..e6d1183e5 100644 --- a/include/PgSQL_Variables_Validator.h +++ b/include/PgSQL_Variables_Validator.h @@ -9,7 +9,8 @@ typedef enum { VARIABLE_TYPE_BOOL, /**< Boolean variable type. */ VARIABLE_TYPE_STRING, /**< String variable type. */ VARIABLE_TYPE_DATESTYLE, /**< DateStyle variable type. */ - VARIABLE_TYPE_MAINTENANCE_WORK_MEM + VARIABLE_TYPE_MAINTENANCE_WORK_MEM, /**< MaintenanceWorkMem variable type. */ + VARIABLE_TYPE_CLIENT_ENCODING /**< ClientEncoding variable type. */ } pgsql_variable_type_t; @@ -61,5 +62,5 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag extern const pgsql_variable_validator pgsql_variable_validator_bytea_output; extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits; extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem; - +extern const pgsql_variable_validator pgsql_variable_validator_client_encoding; #endif // PGSQL_VARIABLES_VALIDATOR_H diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 32f8e8a53..84a950768 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -1791,9 +1791,10 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag extern const pgsql_variable_validator pgsql_variable_validator_bytea_output; extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits; extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem; +extern const pgsql_variable_validator pgsql_variable_validator_client_encoding; pgsql_variable_st pgsql_tracked_variables[]{ - { PGSQL_CLIENT_ENCODING, SETTING_CHARSET, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_SET_TRANSACTION | PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), nullptr, { "names", nullptr } }, + { PGSQL_CLIENT_ENCODING, SETTING_VARIABLE, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_client_encoding, { "names", nullptr } }, { PGSQL_DATESTYLE, SETTING_VARIABLE, "datestyle", "datestyle", "ISO, MDY" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_datestyle, nullptr }, { PGSQL_INTERVALSTYLE, SETTING_VARIABLE, "intervalstyle", "intervalstyle", "postgres" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_intervalstyle, nullptr }, { PGSQL_STANDARD_CONFORMING_STRINGS, SETTING_VARIABLE, "standard_conforming_strings", "standard_conforming_strings", "on", (PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_bool, nullptr }, diff --git a/lib/PgSQL_Connection.cpp b/lib/PgSQL_Connection.cpp index 2cd818309..68baf89da 100644 --- a/lib/PgSQL_Connection.cpp +++ b/lib/PgSQL_Connection.cpp @@ -14,103 +14,6 @@ using json = nlohmann::json; #include "PgSQL_Query_Processor.h" #include "MySQL_Variables.h" - -#if 0 -// some of the code that follows is from mariadb client library memory allocator -typedef int myf; // Type of MyFlags in my_funcs -#define MYF(v) (myf) (v) -#define MY_KEEP_PREALLOC 1 -#define MY_ALIGN(A,L) (((A) + (L) - 1) & ~((L) - 1)) -#define ALIGN_SIZE(A) MY_ALIGN((A),sizeof(double)) -static void ma_free_root(MA_MEM_ROOT *root, myf MyFLAGS); -static void *ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size); -#define MAX(a,b) (((a) > (b)) ? (a) : (b)) - - -static void * ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size) -{ - size_t get_size; - void * point; - MA_USED_MEM *next= 0; - MA_USED_MEM **prev; - - Size= ALIGN_SIZE(Size); - - if ((*(prev= &mem_root->free))) - { - if ((*prev)->left < Size && - mem_root->first_block_usage++ >= 16 && - (*prev)->left < 4096) - { - next= *prev; - *prev= next->next; - next->next= mem_root->used; - mem_root->used= next; - mem_root->first_block_usage= 0; - } - for (next= *prev; next && next->left < Size; next= next->next) - prev= &next->next; - } - if (! next) - { /* Time to alloc new block */ - get_size= MAX(Size+ALIGN_SIZE(sizeof(MA_USED_MEM)), - (mem_root->block_size & ~1) * ( (mem_root->block_num >> 2) < 4 ? 4 : (mem_root->block_num >> 2) ) ); - - if (!(next = (MA_USED_MEM*) malloc(get_size))) - { - if (mem_root->error_handler) - (*mem_root->error_handler)(); - return((void *) 0); /* purecov: inspected */ - } - mem_root->block_num++; - next->next= *prev; - next->size= get_size; - next->left= get_size-ALIGN_SIZE(sizeof(MA_USED_MEM)); - *prev=next; - } - point= (void *) ((char*) next+ (next->size-next->left)); - if ((next->left-= Size) < mem_root->min_malloc) - { /* Full block */ - *prev=next->next; /* Remove block from list */ - next->next=mem_root->used; - mem_root->used=next; - mem_root->first_block_usage= 0; - } - return(point); -} - - -static void ma_free_root(MA_MEM_ROOT *root, myf MyFlags) -{ - MA_USED_MEM *next,*old; - - if (!root) - return; /* purecov: inspected */ - if (!(MyFlags & MY_KEEP_PREALLOC)) - root->pre_alloc=0; - - for ( next=root->used; next ;) - { - old=next; next= next->next ; - if (old != root->pre_alloc) - free(old); - } - for (next= root->free ; next ; ) - { - old=next; next= next->next ; - if (old != root->pre_alloc) - free(old); - } - root->used=root->free=0; - if (root->pre_alloc) - { - root->free=root->pre_alloc; - root->free->left=root->pre_alloc->size-ALIGN_SIZE(sizeof(MA_USED_MEM)); - root->free->next=0; - } -} -#endif // 0 - extern char * binary_sha1; #include "proxysql_find_charset.h" @@ -841,31 +744,22 @@ void PgSQL_Connection::connect_start() { conninfo << "sslmode='disable' "; // not supporting SSL } - if (myds && myds->sess) { - const char* charset = NULL; - uint32_t charset_hash = 0; - - // Take client character set and use it to connect to backend - charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING); - if (charset_hash != 0) - charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING); - - //if (!charset) { - // charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING]; - //} - + if (myds && myds->sess && myds->sess->client_myds) { // Client Encoding should be always set - assert(charset); - - connect_start_SetCharset(charset, charset_hash); - - escaped_str = escape_string_single_quotes_and_backslashes((char*)charset, false); + const char* client_charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING); + assert(client_charset); + uint32_t client_charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING); + assert(client_charset_hash); + const char* escaped_str = escape_string_backslash_spaces(client_charset); conninfo << "client_encoding='" << escaped_str << "' "; - if (escaped_str != charset) - free(escaped_str); + if (escaped_str != client_charset) + free((char*)escaped_str); + + // charset validation is already done + pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, client_charset, client_charset_hash); std::vector client_options; - client_options.reserve(PGSQL_NAME_LAST_LOW_WM + dynamic_variables_idx.size()); + client_options.reserve(PGSQL_NAME_LAST_LOW_WM + myds->sess->client_myds->myconn->dynamic_variables_idx.size()); // excluding PGSQL_CLIENT_ENCODING for (unsigned int idx = 1; idx < PGSQL_NAME_LAST_LOW_WM; idx++) { @@ -873,9 +767,9 @@ void PgSQL_Connection::connect_start() { client_options.push_back(idx); } - for (std::vector::const_iterator it_c = dynamic_variables_idx.begin(); it_c != dynamic_variables_idx.end(); it_c++) { - assert(pgsql_variables.client_get_hash(myds->sess, *it_c)); - client_options.push_back(*it_c); + for (uint32_t idx : myds->sess->client_myds->myconn->dynamic_variables_idx) { + assert(pgsql_variables.client_get_hash(myds->sess, idx)); + client_options.push_back(idx); } if (client_options.empty() == false || @@ -1882,31 +1776,6 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) { } } -void PgSQL_Connection::set_charset(const char* charset) { - proxy_debug(PROXY_DEBUG_MYSQL_CONNPOOL, 4, "Setting client encoding %s\n", charset); - pgsql_variables.client_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset); -} - -void PgSQL_Connection::connect_start_SetCharset(const char* charset, uint32_t hash) { - assert(charset); - - int charset_encoding = PgSQL_Connection::char_to_encoding(charset); - - if (charset_encoding == -1) { - proxy_error("Cannot find character set [%s]\n", charset); - assert(0); - } - - /* We are connecting to backend setting charset in connection parameters. - * Client already has sent us a character set and client connection variables have been already set. - * Now we store this charset in server connection variables to avoid updating this variables on backend. - */ - if (hash == 0) - pgsql_variables.server_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset); - else - pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, charset, hash); -} - // this function is identical to async_query() , with the only exception that MyRS should never be set int PgSQL_Connection::async_send_simple_command(short event, char* stmt, unsigned long length) { PROXY_TRACE(); diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 08e7d2c70..63aabf8de 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -783,6 +783,7 @@ std::vector> PgSQL_Protocol::parse_options(c // Add key-value pair to the list if (!key.empty()) { + std::transform(key.begin(), key.end(), key.begin(), ::tolower); options_list.emplace_back(std::move(key), std::move(value)); } } @@ -822,24 +823,7 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* return EXECUTION_STATE::FAILED; } - // set charset but first verify - const char* charset = (*myds)->myconn->conn_params.get_value(PG_CLIENT_ENCODING); - // if client does not provide client_encoding, PostgreSQL uses the default client encoding. - // We need to save the default client encoding to send it to the client in case client doesn't provide one. - if (charset == NULL) charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING]; - - assert(charset); - - int charset_encoding = (*myds)->myconn->char_to_encoding(charset); - - if (charset_encoding == -1) { - proxy_error("Cannot find charset [%s]\n", charset); - proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "DS=%p , Session=%p , charset='%s'. Client charset not supported.\n", (*myds), (*myds)->sess, charset); - generate_error_packet(true, false, "Client charset not supported", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); - return EXECUTION_STATE::FAILED; - } - user = (char*)(*myds)->myconn->conn_params.get_value(PG_USER); if (!user || *user == '\0') { @@ -1075,20 +1059,26 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* parameters.reserve((*myds)->myconn->conn_params.connection_parameters.size()); - // new implementation + /* Note: Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior. + PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds, + but the error is encountered when executing a query. + This is behavious is intentional, as newer PostgreSQL versions may introduce parameters that ProxySQL is not yet aware of. + */ + // New implementation for (const auto& [param_name, param_val] : (*myds)->myconn->conn_params.connection_parameters) { std::string param_name_lowercase(param_name); std::transform(param_name_lowercase.cbegin(), param_name_lowercase.cend(), param_name_lowercase.begin(), ::tolower); - auto it = PgSQL_Param_Name_Str.find(param_name_lowercase.c_str()); - if (it != PgSQL_Param_Name_Str.end()) { + // check if parameter is part of connection-level parameters + auto itr = param_name_map.find(param_name_lowercase.c_str()); + if (itr != param_name_map.end()) { if (param_name_lowercase.compare("user") == 0 || param_name_lowercase.compare("password") == 0) { continue; } bool is_validation_success = false; - const Param_Name_Validation* validation = it->second; + const Param_Name_Validation* validation = itr->second; if (validation != nullptr && validation->accepted_values) { const char** accepted_value = validation->accepted_values; @@ -1112,26 +1102,23 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); free(errmsg); ret = EXECUTION_STATE::FAILED; + + // freeing userinfo->username and userinfo->password to prevent invalid password error generation. free(userinfo->username); free(userinfo->password); userinfo->username = strdup(""); userinfo->password = strdup(""); + // goto __exit_process_pkt_handshake_response; } if (param_name_lowercase.compare("database") == 0) { userinfo->set_dbname(param_val.empty() ? user : param_val.c_str()); - } else if (param_name_lowercase.compare("client_encoding") == 0) { - assert(sess); - assert(sess->client_myds); - PgSQL_Connection* myconn = sess->client_myds->myconn; - assert(myconn); - myconn->set_charset(charset); - sess->set_default_session_variable(PGSQL_CLIENT_ENCODING, charset); } else if (param_name_lowercase.compare("options") == 0) { options_list = parse_options(param_val.c_str()); } } else { + // session parameters/variables? parameters.push_back(std::make_pair(param_name_lowercase, param_val)); } } @@ -1140,60 +1127,69 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* userinfo->set_dbname(user); } + // Merge options with parameters. + // Options are processed first, followed by connection parameters. + // If a parameter is specified in both, the connection parameter takes precedence + // and overwrites the previosly set value. if (options_list.empty() == false) { options_list.reserve(parameters.size() + options_list.size()); options_list.insert(options_list.end(), std::make_move_iterator(parameters.begin()), std::make_move_iterator(parameters.end())); parameters = std::move(options_list); } - // assign default datestyle to current datestyle + // assign default datestyle to current datestyle. + // This is needed by PgSQL_DateStyle_Util::parse_datestyle sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); - for (const auto&[key, val] : parameters) { + for (const auto&[param_key, param_val] : parameters) { int idx = PGSQL_NAME_LAST_HIGH_WM; for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { if (i == PGSQL_NAME_LAST_LOW_WM) continue; - if (strcmp(pgsql_tracked_variables[i].set_variable_name, key.c_str()) == 0) { + if (strncmp(param_key.c_str(), pgsql_tracked_variables[i].set_variable_name, + strlen(pgsql_tracked_variables[i].set_variable_name)) == 0) { idx = i; break; } } if (idx != PGSQL_NAME_LAST_HIGH_WM) { - std::string value1 = val; + std::string value_copy = param_val; char* transformed_value = nullptr; if (pgsql_tracked_variables[idx].validator && pgsql_tracked_variables[idx].validator->validate && ( *pgsql_tracked_variables[idx].validator->validate)( - value1.c_str(), &pgsql_tracked_variables[idx].validator->params, sess, &transformed_value) == false + value_copy.c_str(), &pgsql_tracked_variables[idx].validator->params, sess, &transformed_value) == false ) { char* m = NULL; char* errmsg = NULL; - proxy_error("invalid value for parameter \"%s\": \"%s\"\n", pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); + proxy_error("invalid value for parameter \"%s\": \"%s\"\n", pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str()); m = (char*)"invalid value for parameter \"%s\": \"%s\""; - errmsg = (char*)malloc(value1.length() + strlen(pgsql_tracked_variables[idx].set_variable_name) + strlen(m)); - sprintf(errmsg, m, pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); + errmsg = (char*)malloc(value_copy.length() + strlen(pgsql_tracked_variables[idx].set_variable_name) + strlen(m)); + sprintf(errmsg, m, pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str()); generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); free(errmsg); ret = EXECUTION_STATE::FAILED; + + // freeing userinfo->username and userinfo->password to prevent invalid password error generation. free(userinfo->username); free(userinfo->password); userinfo->username = strdup(""); userinfo->password = strdup(""); + // goto __exit_process_pkt_handshake_response; } if (transformed_value) { - value1 = transformed_value; + value_copy = transformed_value; free(transformed_value); } if (idx == PGSQL_DATESTYLE) { // get datestyle from connection parameters - std::string datestyle = val.empty() == false ? val : ""; + std::string datestyle = value_copy.empty() == false ? value_copy : ""; if (datestyle.empty()) { // No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable. @@ -1211,35 +1207,36 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* assert(datestyle.empty() == false); - if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str())) { + if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str(), false)) { // change current datestyle sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle); sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str()); } } else { - pgsql_variables.client_set_value(sess, idx, val.c_str()); - sess->set_default_session_variable((enum pgsql_variable_name)idx, val.c_str()); + pgsql_variables.client_set_value(sess, idx, value_copy.c_str(), false); + sess->set_default_session_variable((enum pgsql_variable_name)idx, value_copy.c_str()); } } else { - proxy_warning("Unrecognized connection parameter. Please report this as a bug for future enhancements:%s:%s\n", key.c_str(), val.c_str()); - const char* escaped_str = escape_string_backslash_spaces(val.c_str()); - sess->untracked_option_parameters = "-c " + key + "=" + escaped_str + " "; - if (escaped_str != val) + // parameter provided is not part of the tracked variables. Will lock on hostgroup on next query. + const char* val_cstr = param_val.c_str(); + proxy_warning("Unrecognized connection parameter. Please report this as a bug for future enhancements:%s:%s\n", param_key.c_str(), val_cstr); + const char* escaped_str = escape_string_backslash_spaces(val_cstr); + sess->untracked_option_parameters = "-c " + param_key + "=" + escaped_str + " "; + if (escaped_str != val_cstr) free((char*)escaped_str); } } - // set mandatory variables if not sent by client + // fill all crtical variables with default values, if not set by client for (int i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) { - //if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[i])) { - // continue; - //} if (pgsql_variables.client_get_hash(sess, i) != 0) continue; const char* val = pgsql_thread___default_variables[i]; - pgsql_variables.client_set_value(sess, i, val); + pgsql_variables.client_set_value(sess, i, val, false); sess->set_default_session_variable((pgsql_variable_name)i, val); } + + sess->client_myds->myconn->reorder_dynamic_variables_idx(); } else { // we always duplicate username and password, or crashes happen diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 712631fe9..42ef7896a 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -81,7 +81,8 @@ static inline char is_normal_char(char c) { } */ -static const std::array pgsql_critical_variables = { +static const std::array pgsql_critical_variables = { + "client_encoding", "datestyle", "intervalstyle", "standard_conforming_strings", @@ -1507,98 +1508,6 @@ bool PgSQL_Session::handler_again___status_SETTING_INIT_CONNECT(int* _rc) { return ret; } -bool PgSQL_Session::handler_again___status_CHANGING_CHARSET(int* _rc) { - assert(mybe->server_myds->myconn); - PgSQL_Data_Stream* myds = mybe->server_myds; - PgSQL_Connection* myconn = myds->myconn; - - /* Validate that server can support client's charset */ - if (!validate_charset(this, PGSQL_CLIENT_ENCODING, *_rc)) { - return false; - } - - myds->DSS = STATE_MARIADB_QUERY; - enum session_status st = status; - if (myds->mypolls == NULL) { - thread->mypolls.add(POLLIN | POLLOUT, mybe->server_myds->fd, mybe->server_myds, thread->curtime); - } - - std::string charset = pgsql_variables.client_get_value(this, PGSQL_CLIENT_ENCODING); - std::string query = "SET CLIENT_ENCODING TO '" + charset + "'"; - - int rc = myconn->async_send_simple_command(myds->revents, (char*)query.c_str(),query.length()); - - if (rc == 0) { - __sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1); - myds->DSS = STATE_MARIADB_GENERIC; - st = previous_status.top(); - previous_status.pop(); - NEXT_IMMEDIATE_NEW(st); - } else { - if (rc == -1) { - // the command failed - const bool error_present = myconn->is_error_present(); - PgHGM->p_update_pgsql_error_counter( - p_pgsql_error_type::pgsql, - myconn->parent->myhgc->hid, - myconn->parent->address, - myconn->parent->port, - (error_present ? 9999 : ER_PROXYSQL_OFFLINE_SRV) // TOFIX: 9999 is a placeholder for the actual error code - ); - if (error_present == false || (error_present == true && myconn->is_connection_in_reusable_state() == false)) { - bool retry_conn = false; - // client error, serious - proxy_error( - "Client trying to set a charset (%s) not supported by backend (%s:%d). Changing it to %s\n", - charset.c_str(), myconn->parent->address, myconn->parent->port, pgsql_tracked_variables[PGSQL_CLIENT_ENCODING].default_value - ); - detected_broken_connection(__FILE__, __LINE__, __func__, "during SET CLIENT_ENCODING", myconn); - if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) { - retry_conn = true; - } - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd = 0; - if (retry_conn) { - myds->DSS = STATE_NOT_INITIALIZED; - NEXT_IMMEDIATE_NEW(CONNECTING_SERVER); - } - *_rc = -1; - return false; - } else { - proxy_warning("Error during SET CLIENT_ENCODING: %s\n", myconn->get_error_code_with_message().c_str()); - // we won't go back to PROCESSING_QUERY - st = previous_status.top(); - previous_status.pop(); - client_myds->myprot.generate_error_packet(true, true, myconn->get_error_message().c_str(), myconn->get_error_code(), false); - myds->destroy_MySQL_Connection_From_Pool(true); - myds->fd = 0; - RequestEnd(myds); //fix bug #682 - } - } else { - if (rc == -2) { - bool retry_conn = false; - proxy_error("Timeout during SET CLIENT_ENCODING on %s , %d\n", myconn->parent->address, myconn->parent->port); - PgHGM->p_update_pgsql_error_counter(p_pgsql_error_type::pgsql, myconn->parent->myhgc->hid, myconn->parent->address, myconn->parent->port, ER_PROXYSQL_CHANGE_USER_TIMEOUT); - if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) { - retry_conn = true; - } - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd = 0; - if (retry_conn) { - myds->DSS = STATE_NOT_INITIALIZED; - NEXT_IMMEDIATE_NEW(CONNECTING_SERVER); - } - *_rc = -1; - return false; - } - else { - // rc==1 , nothing to do for now - } - } - } - return false; -} - bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, const char* var_name, const char* var_value, bool no_quote, bool set_transaction) { bool ret = false; @@ -1649,6 +1558,9 @@ bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, co query = NULL; } if (rc == 0) { + if (strncasecmp(var_name, "client_encoding", sizeof("client_encoding")-1) == 0) { + __sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1); + } myds->revents |= POLLOUT; // we also set again POLLOUT to send a query immediately! myds->DSS = STATE_MARIADB_GENERIC; st = previous_status.top(); @@ -3313,11 +3225,6 @@ handler_again: } if (locked_on_hostgroup == -1 || locked_on_hostgroup_and_all_variables_set == false) { - // verify charset - if (verify_set_names(this)) { - goto handler_again; - } - for (auto i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) { auto client_hash = client_myds->myconn->var_hash[i]; #ifdef DEBUG @@ -3597,9 +3504,6 @@ bool PgSQL_Session::handler_again___multiple_statuses(int* rc) { //case SETTING_INIT_CONNECT: // ret = handler_again___status_SETTING_INIT_CONNECT(rc); break; - case SETTING_CHARSET: - ret = handler_again___status_CHANGING_CHARSET(rc); - break; default: break; } @@ -4349,7 +4253,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", var.c_str(), value1.c_str()); uint32_t var_hash_int = SpookyHash::Hash32(value1.c_str(), value1.length(), 10); if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value1.c_str())) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value1.c_str(), true)) { return false; } if (idx == PGSQL_DATESTYLE) { @@ -4629,7 +4533,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C l_free(pkt->size, pkt->ptr); return true; - }*/ else if (match_regexes && match_regexes[3]->match(dig)) { + } else if (match_regexes && match_regexes[3]->match(dig)) { std::vector> param_status; PgSQL_Set_Stmt_Parser parser(nq); std::string charset = parser.parse_character_set(); @@ -4681,7 +4585,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C RequestEnd(NULL); l_free(pkt->size, pkt->ptr); return true; - } else { + }*/ else { unable_to_parse_set_statement(lock_hostgroup); return false; } @@ -4720,26 +4624,41 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C if (strncasecmp(nq.c_str(), "ALL", 3) == 0) { - for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { + for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) { - if (i == PGSQL_NAME_LAST_LOW_WM) - continue; + const char* name = pgsql_tracked_variables[idx].set_variable_name; + const char* value = get_default_session_variable((enum pgsql_variable_name)idx); - const char* name = pgsql_tracked_variables[i].set_variable_name; - const char* value = get_default_session_variable((enum pgsql_variable_name)i); - proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[i].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[i].idx, value)) { + if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, false)) { + return false; + } + if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { + param_status.push_back(std::make_pair(name, value)); + } + } + } + + for (int idx : client_myds->myconn->dynamic_variables_idx) { + assert(idx < PGSQL_NAME_LAST_HIGH_WM); + const char* name = pgsql_tracked_variables[idx].set_variable_name; + const char* value = get_default_session_variable((enum pgsql_variable_name)idx); + proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); + uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); + if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, false)) { return false; } - if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[i])) { + if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { param_status.push_back(std::make_pair(name, value)); } } } + client_myds->myconn->reorder_dynamic_variables_idx(); + } else if (std::find(pgsql_variables.ignore_vars.begin(), pgsql_variables.ignore_vars.end(), nq) != pgsql_variables.ignore_vars.end()) { // this is a variable we parse but ignore #ifdef DEBUG @@ -4765,7 +4684,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value)) { + if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, true)) { return false; } if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { diff --git a/lib/PgSQL_Set_Stmt_Parser.cpp b/lib/PgSQL_Set_Stmt_Parser.cpp index 42218f8f2..3f92ef12a 100644 --- a/lib/PgSQL_Set_Stmt_Parser.cpp +++ b/lib/PgSQL_Set_Stmt_Parser.cpp @@ -202,6 +202,7 @@ std::map> PgSQL_Set_Stmt_Parser::parse2() { return result; } +#if 0 std::string PgSQL_Set_Stmt_Parser::parse_character_set() { #ifdef DEBUG proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 4, "Parsing query %s\n", query.c_str()); @@ -224,7 +225,7 @@ std::string PgSQL_Set_Stmt_Parser::parse_character_set() { delete opt2; return value3; } - +#endif std::string PgSQL_Set_Stmt_Parser::remove_comments(const std::string& q) { std::string result = ""; bool in_multiline_comment = false; diff --git a/lib/PgSQL_Variables.cpp b/lib/PgSQL_Variables.cpp index 83473651c..af1f5ba6a 100644 --- a/lib/PgSQL_Variables.cpp +++ b/lib/PgSQL_Variables.cpp @@ -10,27 +10,12 @@ #include - -/*static inline char is_digit(char c) { - if(c >= '0' && c <= '9') - return 1; - return 0; -}*/ - pgsql_verify_var PgSQL_Variables::verifiers[PGSQL_NAME_LAST_HIGH_WM]; pgsql_update_var PgSQL_Variables::updaters[PGSQL_NAME_LAST_HIGH_WM]; - PgSQL_Variables::PgSQL_Variables() { // add here all the variables we want proxysql to recognize, but ignore ignore_vars.push_back("application_name"); - //ignore_vars.push_back("interactive_timeout"); - //ignore_vars.push_back("wait_timeout"); - //ignore_vars.push_back("net_read_timeout"); - //ignore_vars.push_back("net_write_timeout"); - //ignore_vars.push_back("net_buffer_length"); - //ignore_vars.push_back("read_buffer_size"); - //ignore_vars.push_back("read_rnd_buffer_size"); // NOTE: This variable has been temporarily ignored. Check issues #3442 and #3441. //ignore_vars.push_back("session_track_schema"); variables_regexp = ""; @@ -55,14 +40,11 @@ PgSQL_Variables::PgSQL_Variables() { } } for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { - if (i == PGSQL_CLIENT_ENCODING) { - PgSQL_Variables::updaters[i] = NULL; - PgSQL_Variables::verifiers[i] = NULL; - } else { - PgSQL_Variables::verifiers[i] = verify_server_variable; - PgSQL_Variables::updaters[i] = update_server_variable; - } - if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) { + + PgSQL_Variables::verifiers[i] = verify_server_variable; + PgSQL_Variables::updaters[i] = update_server_variable; + + if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) { variables_regexp += pgsql_tracked_variables[i].set_variable_name; variables_regexp += "|"; @@ -135,7 +117,7 @@ void PgSQL_Variables::server_set_hash_and_value(PgSQL_Session* session, int idx, session->mybe->server_myds->myconn->variables[idx].value = strdup(value); } -bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value) { +bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx) { if (!session || !session->client_myds || !session->client_myds->myconn) { proxy_warning("Session validation failed\n"); return false; @@ -147,7 +129,7 @@ bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const st } session->client_myds->myconn->variables[idx].value = strdup(value.c_str()); - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx session->client_myds->myconn->reorder_dynamic_variables_idx(); } @@ -168,7 +150,7 @@ uint32_t PgSQL_Variables::client_get_hash(PgSQL_Session* session, int idx) const return session->client_myds->myconn->var_hash[idx]; } -void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value) { +void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx) { assert(session); assert(session->mybe); assert(session->mybe->server_myds); @@ -181,7 +163,7 @@ void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const ch } session->mybe->server_myds->myconn->variables[idx].value = strdup(value); - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx session->mybe->server_myds->myconn->reorder_dynamic_variables_idx(); } @@ -253,98 +235,6 @@ bool PgSQL_Variables::verify_variable(PgSQL_Session* session, int idx) const { return ret; } -bool validate_charset(PgSQL_Session* session, int idx, int &_rc) { - /*if (idx == PGSQL_CLIENT_ENCODING || idx == PGSQL_SET_NAMES) { - PgSQL_Data_Stream *myds = session->mybe->server_myds; - PgSQL_Connection *myconn = myds->myconn; - char msg[128]; - const MARIADB_CHARSET_INFO *ci = NULL; - const char* replace_collation = NULL; - const char* not_supported_collation = NULL; - unsigned int replace_collation_nr = 0; - std::stringstream ss; - int charset = atoi(pgsql_variables.client_get_value(session, idx)); - if (charset >= 255 && myconn->pgsql->server_version[0] != '8') { - switch(pgsql_thread___handle_unknown_charset) { - case HANDLE_UNKNOWN_CHARSET__DISCONNECT_CLIENT: - snprintf(msg,sizeof(msg),"Can't initialize character set %s", pgsql_variables.client_get_value(session, idx)); - proxy_error("Can't initialize character set on %s, %d: Error %d (%s). Closing client connection %s:%d.\n", - myconn->parent->address, myconn->parent->port, 2019, msg, session->client_myds->addr.addr, session->client_myds->addr.port); - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd=0; - _rc=-1; - return false; - case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT_VERBOSE: - ci = proxysql_find_charset_nr(charset); - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot find character set [%s]\n", pgsql_variables.client_get_value(session, idx)); - assert(0); - // LCOV_EXCL_STOP - } - not_supported_collation = ci->name; - - if (idx == SQL_COLLATION_CONNECTION) { - ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]); - } else { - if (pgsql_thread___default_variables[idx]) { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]); - } else { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]); - } - } - - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot find character set [%s]\n", pgsql_thread___default_variables[idx]); - assert(0); - // LCOV_EXCL_STOP - } - replace_collation = ci->name; - replace_collation_nr = ci->nr; - - proxy_warning("Server doesn't support collation (%s) %s. Replacing it with the configured default (%d) %s. Client %s:%d\n", - pgsql_variables.client_get_value(session, idx), not_supported_collation, - replace_collation_nr, replace_collation, session->client_myds->addr.addr, session->client_myds->addr.port); - - ss << replace_collation_nr; - pgsql_variables.client_set_value(session, idx, ss.str()); - _rc=0; - return true; - case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT: - if (idx == SQL_COLLATION_CONNECTION) { - ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]); - } else { - if (pgsql_thread___default_variables[idx]) { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]); - } else { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]); - } - } - - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot filnd charset [%s]\n", pgsql_thread___default_variables[idx]); - assert(0); - // LCOV_EXCL_STOP - } - replace_collation_nr = ci->nr; - - ss << replace_collation_nr; - pgsql_variables.client_set_value(session, idx, ss.str()); - _rc=0; - return true; - default: - proxy_error("Wrong configuration of the handle_unknown_charset\n"); - _rc=-1; - return false; - } - } - }*/ - _rc=0; - return true; -} - bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) { bool no_quote = true; if (IS_PGTRACKED_VAR_OPTION_SET_QUOTE(pgsql_tracked_variables[idx])) no_quote = false; @@ -352,49 +242,12 @@ bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) { const char *set_var_name = pgsql_tracked_variables[idx].set_variable_name; bool ret = false; - if (!validate_charset(session, idx, _rc)) { - return false; - } - const char* value = pgsql_variables.client_get_value(session, idx); - pgsql_variables.server_set_value(session, idx, value); + pgsql_variables.server_set_value(session, idx, value, true); ret = session->handler_again___status_SETTING_GENERIC_VARIABLE(&_rc, set_var_name, value, no_quote, st); return ret; } -bool verify_set_names(PgSQL_Session* session) { - uint32_t client_charset_hash = pgsql_variables.client_get_hash(session, PGSQL_CLIENT_ENCODING); - if (client_charset_hash == 0) - return false; - - if (client_charset_hash != pgsql_variables.server_get_hash(session, PGSQL_CLIENT_ENCODING)) { - switch(session->status) { // this switch can be replaced with a simple previous_status.push(status), but it is here for readibility - case PROCESSING_QUERY: - session->previous_status.push(PROCESSING_QUERY); - break; - /* - case PROCESSING_STMT_PREPARE: - session->previous_status.push(PROCESSING_STMT_PREPARE); - break; - case PROCESSING_STMT_EXECUTE: - session->previous_status.push(PROCESSING_STMT_EXECUTE); - break; - */ - default: - // LCOV_EXCL_START - proxy_error("Wrong status %d\n", session->status); - assert(0); - break; - // LCOV_EXCL_STOP - } - session->set_status(SETTING_CHARSET); - const char* client_charset_value = pgsql_variables.client_get_value(session, PGSQL_CLIENT_ENCODING); - pgsql_variables.server_set_hash_and_value(session, PGSQL_CLIENT_ENCODING, client_charset_value, client_charset_hash); - return true; - } - return false; -} - inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash) { if (client_hash && client_hash != server_hash) { // Edge case for set charset command, because we do not know database character set @@ -427,7 +280,7 @@ inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t cli // LCOV_EXCL_STOP } session->set_status(pgsql_tracked_variables[idx].status); - pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx)); + pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx), true); return true; } return false; diff --git a/lib/PgSQL_Variables_Validator.cpp b/lib/PgSQL_Variables_Validator.cpp index f1b6b7edf..3286a72a8 100644 --- a/lib/PgSQL_Variables_Validator.cpp +++ b/lib/PgSQL_Variables_Validator.cpp @@ -422,6 +422,31 @@ bool pgsql_variable_validate_maintenance_work_mem_v3(const char* value, const pa return true; } +bool pgsql_variable_validate_client_encoding(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) { + (void)params; + if (transformed_value) *transformed_value = nullptr; + + int charset_encoding = PgSQL_Connection::char_to_encoding(value); + + if (charset_encoding == -1) { + return false; + } + + if (transformed_value) { + *transformed_value = strdup(value); + + if (*transformed_value) { // Ensure strdup succeeded + char* tmp_val = *transformed_value; + + while (*tmp_val) { + *tmp_val = toupper((unsigned char)*tmp_val); + tmp_val++; + } + } + } + + return true; +} const pgsql_variable_validator pgsql_variable_validator_bool = { .type = VARIABLE_TYPE_BOOL, @@ -482,3 +507,9 @@ const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem = { .uint_range = {.min = 1024, .max = 2147483647 } // this range is in kB } }; + +const pgsql_variable_validator pgsql_variable_validator_client_encoding = { + .type = VARIABLE_TYPE_CLIENT_ENCODING, + .validate = &pgsql_variable_validate_client_encoding, + .params = {} +}; From b8ffb7d56e59339c88ce1ee04604e37ff7d9061e Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Tue, 1 Apr 2025 19:11:44 +0500 Subject: [PATCH 3/7] Fixed connection parameter validation --- lib/PgSQL_Protocol.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 63aabf8de..13ed7a3e6 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -1062,7 +1062,7 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* /* Note: Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior. PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds, but the error is encountered when executing a query. - This is behavious is intentional, as newer PostgreSQL versions may introduce parameters that ProxySQL is not yet aware of. + This is behaviour is intentional, as newer PostgreSQL versions may introduce parameters that ProxySQL is not yet aware of. */ // New implementation for (const auto& [param_name, param_val] : (*myds)->myconn->conn_params.connection_parameters) { @@ -1082,11 +1082,12 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* if (validation != nullptr && validation->accepted_values) { const char** accepted_value = validation->accepted_values; - while (accepted_value != nullptr) { + while (*accepted_value) { if (strcmp(param_val.c_str(), *accepted_value) == 0) { is_validation_success = true; break; } + accepted_value++; } } else { is_validation_success = true; From 4c918490a4370dd8538dda340ba6390fbc80e3ba Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 2 Apr 2025 01:30:28 +0500 Subject: [PATCH 4/7] Added TAP test Note: - Using libpq to test ProxySQL's handling of undocumented parameters isn't possible, as libpq enforces a strict subset of PostgreSQL connection parameters as per the official documentation, rejecting any undocumented parameters. However, actual PostgreSQL servers accept additional parameters (e.g., extra_float_digits) and apply them at the connection/session level. To test this behavior, a raw socket is used to connect to a ProxySQL server and send custom built messages to communicate with ProxySQL. It currently works with plain text password authentication, without ssl support. - Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior. PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds, but the error is encountered when executing a query. This is behaviour is intentional, as newer PostgreSQL versions may introduce new parameters that ProxySQL is not yet aware of. --- .../pgsql-connection_parameters_test-t.cpp | 909 ++++++++++++++++++ 1 file changed, 909 insertions(+) create mode 100644 test/tap/tests/pgsql-connection_parameters_test-t.cpp diff --git a/test/tap/tests/pgsql-connection_parameters_test-t.cpp b/test/tap/tests/pgsql-connection_parameters_test-t.cpp new file mode 100644 index 000000000..063a75e77 --- /dev/null +++ b/test/tap/tests/pgsql-connection_parameters_test-t.cpp @@ -0,0 +1,909 @@ +/** + * @file pgsql-connection_parameters_test-t-t.cpp + * @brief This TAP test validates if connection parameters are correctly handled by ProxySQL. + * + * Note: + * - Using libpq to test ProxySQL's handling of undocumented parameters isn't possible, as libpq enforces a strict subset of PostgreSQL + * connection parameters as per the official documentation, rejecting any undocumented parameters. However, actual + * PostgreSQL servers accept additional parameters (e.g., extra_float_digits) and apply them at the connection/session level. + * To test this behavior, a raw socket is used to connect to a ProxySQL server and send custom built messages to communicate + * with ProxySQL. It currently works with plain text password authentication, without ssl support. + * + * - Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior. + * PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds, + * but the error is encountered when executing a query. + * This is behaviour is intentional, as newer PostgreSQL versions may introduce new parameters that ProxySQL is not yet aware of. + * + * + */ + +#include +#include +#include +#include +#include +#include +#include "libpq-fe.h" +#include "command_line.h" +#include "tap.h" +#include "utils.h" + +CommandLine cl; + +using PGConnPtr = std::unique_ptr; + +enum ConnType { + ADMIN, + BACKEND +}; + +PGConnPtr createNewConnection(ConnType conn_type, const std::string& parameters = "", bool with_ssl = false) { + + const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host; + int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port; + const char* username = (conn_type == BACKEND) ? cl.pgsql_username : cl.admin_username; + const char* password = (conn_type == BACKEND) ? cl.pgsql_password : cl.admin_password; + + std::stringstream ss; + + ss << "host=" << host << " port=" << port; + ss << " user=" << username << " password=" << password; + ss << (with_ssl ? " sslmode=require" : " sslmode=disable"); + + if (parameters.empty() == false) { + ss << " " << parameters; + } + + PGconn* conn = PQconnectdb(ss.str().c_str()); + if (PQstatus(conn) != CONNECTION_OK) { + fprintf(stderr, "Connection failed to '%s': %s\n", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn)); + PQfinish(conn); + return PGConnPtr(nullptr, &PQfinish); + } + return PGConnPtr(conn, &PQfinish); +} + +bool executeQueries(PGconn* conn, const std::vector& queries) { + auto fnResultType = [](const char* query) -> int { + const char* fs = strchr(query, ' '); + size_t qtlen = strlen(query); + if (fs != NULL) { + qtlen = (fs - query) + 1; + } + char buf[qtlen]; + memcpy(buf, query, qtlen - 1); + buf[qtlen - 1] = 0; + + if (strncasecmp(buf, "SELECT", sizeof("SELECT") - 1) == 0) { + return PGRES_TUPLES_OK; + } + else if (strncasecmp(buf, "COPY", sizeof("COPY") - 1) == 0) { + return PGRES_COPY_OUT; + } + + return PGRES_COMMAND_OK; + }; + + + for (const auto& query : queries) { + diag("Running: %s", query.c_str()); + PGresult* res = PQexec(conn, query.c_str()); + bool success = PQresultStatus(res) == fnResultType(query.c_str()); + if (!success) { + fprintf(stderr, "Failed to execute query '%s': %s\n", + query.c_str(), PQerrorMessage(conn)); + PQclear(res); + return false; + } + PQclear(res); + } + return true; +} + +struct parameter { + std::string name; + std::string value; +}; + +struct parameter_test { + std::vector set_admin_vars; // Admin variables to set + std::vector conn_params; // Parameters in startup message + std::vector conn_options; // Options (-c flags) in startup + std::vector set_commands; // SET commands after connection + std::vector expected; // Expected SHOW values + bool reset_after; // Whether to RESET parameters + bool expect_failure; // If connection/query should fail +}; + +/** + * @struct MyPGresult + * @brief Represents the result of a PostgreSQL query. + * + * This structure holds the columns, rows, status, and error message of a PostgreSQL query result. + */ +struct MyPGresult { + std::vector columns; ///< Column names of the result set. + std::vector> rows; ///< Rows of the result set. + std::string status; ///< Status of the query execution. + std::string error; ///< Error message if the query failed. +}; + +struct PgSQLResponse { + char type; ///< Type of the response message. + int32_t length; ///< Length of the response message. + std::vector data; ///< Data of the response message. +}; + +void add_param(std::vector& msg, std::string_view key, std::string_view value) { + msg.insert(msg.end(), key.begin(), key.end()); + msg.push_back('\0'); + msg.insert(msg.end(), value.begin(), value.end()); + msg.push_back('\0'); +} + +// Function to receive data from socket +bool recv_data(int sock, void* buffer, size_t len) { + + if (recv(sock, buffer, len, 0) <= 0) { + fprintf(stderr, "Error receiving data\n"); + return false; + } + return true; +} + +// Function to send data over socket +bool send_data(int sock, const void* data, size_t len) { + if (send(sock, data, len, 0) != len) { + fprintf(stderr, "Error sending data\n"); + return false; + } + return true; +} + +/** + * @brief Builds a startup message for PostgreSQL connection. + * + * This function constructs a startup message for PostgreSQL connection using the provided + * connection parameters. + * + * @param parameters A vector of key-value pairs representing the connection parameters. + * @return A vector of characters representing the constructed startup message. + */ +std::vector build_startup_message(const std::vector& parameters) { + // Build startup message + std::vector startup_body; + int32_t protocol = htonl(0x00030000); // Protocol 3.0 + startup_body.insert(startup_body.end(), (char*)&protocol, (char*)&protocol + 4); + + // Add connection parameters + for (const auto& param : parameters) { + add_param(startup_body, param.name, param.value); + } + startup_body.push_back('\0'); + + // Prepend message length + std::vector startup_msg; + int32_t len = htonl(startup_body.size() + 4); + startup_msg.insert(startup_msg.end(), (char*)&len, (char*)&len + 4); + startup_msg.insert(startup_msg.end(), startup_body.begin(), startup_body.end()); + + return startup_msg; +} + +/** + * @brief Builds a password message for PostgreSQL authentication. + * + * This function constructs a password message to be sent to the PostgreSQL server + * during the authentication process. + * + * @param password The password to be included in the message. + * @return A vector of characters representing the constructed password message. + */ +std::vector build_password_message(std::string_view password) { + std::vector password_message; + int pass_msg_len = htonl(password.size() + 1 + 4); + password_message.push_back('p'); + password_message.insert(password_message.end(), (char*)&pass_msg_len, (char*)&pass_msg_len + 4); + password_message.insert(password_message.end(), password.begin(), password.end()); + password_message.push_back('\0'); + return password_message; +} + +/** + * @brief Connects to a PostgreSQL server. + * + * This function establishes a connection to a PostgreSQL server using the provided + * host and port number. + * + * @param host The hostname or IP address of the PostgreSQL server. + * @param port The port number of the PostgreSQL server. + * @return The socket file descriptor for the connection. + */ +int connect_server(const std::string& host, int port) { + int sock; + struct sockaddr_in server; + + // Create socket + sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == -1) { + perror("Socket creation failed"); + return 1; + } + + // Configure server address + server.sin_family = AF_INET; + server.sin_port = htons(port); + server.sin_addr.s_addr = inet_addr(host.c_str()); + + // Connect to PostgreSQL server + if (connect(sock, (struct sockaddr*)&server, sizeof(server)) < 0) { + fprintf(stderr, "Connection failed\n"); + return -1; + } + return sock; +} + +/** + * @brief Sends a startup message to the PostgreSQL server. + * + * This function constructs and sends a startup message to the PostgreSQL server + * using the provided socket and connection parameters. + * + * @param sock The socket file descriptor for the connection. + * @param params A vector of key-value pairs representing the connection parameters. + */ +void send_startup_message(int sock, const std::vector>& params) { + + int param_count = params.size(); + + char msg[512]; + int offset = 0; + uint32_t length = 4 + 4; + + for (int i = 0; i < param_count; i++) { + length += params[i].first.size() + 1 + params[i].second.size() + 1; + } + length += 1; + + uint32_t length_nbo = htonl(length); + uint32_t protocol = htonl(196608); + + memcpy(msg + offset, &length_nbo, 4); + offset += 4; + memcpy(msg + offset, &protocol, 4); + offset += 4; + + for (int i = 0; i < param_count; i++) { + strcpy(msg + offset, params[i].first.c_str()); + offset += params[i].first.size() + 1; + strcpy(msg + offset, params[i].second.c_str()); + offset += params[i].second.size() + 1; + } + msg[offset++] = '\0'; + + send(sock, msg, offset, 0); +} + +/** + * @brief Parses an error message from a PostgreSQL response payload. + * + * This function extracts and returns the error message from a given PostgreSQL response payload. + * + * @param payload A vector of characters representing the PostgreSQL response payload. + * @return A string containing the extracted error message. + */ +std::string parse_error(const std::vector& payload) { + const char* data = payload.data(); + size_t pos = 0; + std::string message; + + while (pos < payload.size()) { + if (data[pos] == '\0') break; + char field_type = data[pos++]; + + std::string field_value; + while (pos < payload.size() && data[pos] != '\0') { + field_value += data[pos++]; + } + pos++; + + if (field_type == 'M') message = field_value; + } + return message; +} + +/** + * @brief Handles cleartext password authentication for PostgreSQL. + * + * This function processes the cleartext password authentication request from the PostgreSQL server. + * + * @param sock The socket file descriptor for the connection. + * @param password The password to be sent for authentication. + * @return True if authentication is successful, false otherwise. + */ +bool handle_cleartext_auth(int sock, std::string_view password) { + + char response[1024]; + if (recv_data(sock, response, 9) == false) { // Read first 8 bytes (message type + length + auth type) + fprintf(stderr, "Error: failed to receive authentication message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + if (response[0] != 'R') { + fprintf(stderr, "Unexpected authentication response\n"); + return false; + } + + uint32_t auth_type; + memcpy(&auth_type, response + 5, 4); + auth_type = ntohl(auth_type); + + if (auth_type == 3) { // AuthenticationCleartextPassword + diag("Server requests cleartext password authentication\n"); + + std::vector password_msg = build_password_message(password); + + // Send PasswordMessage + if (send_data(sock, password_msg.data(), password_msg.size()) == false) { + fprintf(stderr, "Error: failed to send password message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + // Receive AuthenticationOK or Failure + if (recv_data(sock, response, 5) == false) { + fprintf(stderr, "Error: failed to receive authentication response in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + if (response[0] == 'E') { + uint32_t err_msg_len = ntohl(*reinterpret_cast(response + 1)); + if (recv_data(sock, response, err_msg_len - 4) == false) { + fprintf(stderr, "Error: failed to receive error message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + const std::vector payload (response, response + err_msg_len - 4); + std::string error_message = parse_error(payload); + fprintf(stderr, "%s\n", error_message.c_str()); + return false; + } + else if (response[0] != 'R') { + fprintf(stderr, "Unexpected authentication response\n"); + return false; + } + + if (recv_data(sock, response, 4) == false) { + fprintf(stderr, "Error: failed to receive error message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + memcpy(&auth_type, response, 4); + auth_type = ntohl(auth_type); + + if (auth_type != 0) { + diag("Authentication failed!\n"); + return false; + } + + std::vector msg_list; + + while (true) { + int bytes_received = recv(sock, response, sizeof(response), MSG_DONTWAIT); + if (bytes_received == 0) { + fprintf(stderr, "Error: Connection closed in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + if (bytes_received == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; + } + fprintf(stderr, "Error: Connection closed or error occurred in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + size_t offset = 0; + while (offset < static_cast(bytes_received)) { + char messageType = response[offset]; + int32_t messageLength; + + if (offset + 5 > static_cast(bytes_received)) { + fprintf(stderr, "Incomplete message header received\n"); + return false; + } + + memcpy(&messageLength, response + offset + 1, 4); + messageLength = ntohl(messageLength); + + if (offset + messageLength > static_cast(bytes_received)) { + fprintf(stderr, "Incomplete message body received.\n"); + return false; + } + + PgSQLResponse msg; + msg.type = messageType; + msg.length = messageLength; + msg.data.assign(response + offset + 5, response + offset + messageLength + 1); + msg_list.emplace_back(std::move(msg)); + offset += messageLength + 1; + } + } + + if (msg_list.back().type == 'E') { + const std::vector payload = msg_list.back().data; + std::string error_message = parse_error(payload); + fprintf(stderr, "%s\n", error_message.c_str()); + return false; + } else if (msg_list.back().type != 'Z') { + fprintf(stderr, "Unexpected message type %c\n", msg_list.back().type); + return false; + } + + } else { + diag("Unexpected authentication method: %d\n", auth_type); + return false; + } + diag("Authentication successful!\n"); + return true; +} + +/** + * @brief Executes a query on the PostgreSQL server. + * + * This function sends a query to the PostgreSQL server and processes the response. + * + * @param sock The socket file descriptor for the connection. + * @param query The SQL query to be executed. + * @return A unique pointer to a MyPGresult structure containing the query result. + */ +std::unique_ptr execute_query(int sock, const std::string& query) { + std::vector msg; + + // Build query message + msg.push_back('Q'); + uint32_t length = htonl(query.size() + 5); // 1 byte for Q + 4 for length + query + null + msg.insert(msg.end(), reinterpret_cast(&length), reinterpret_cast(&length) + 4); + msg.insert(msg.end(), query.begin(), query.end()); + msg.push_back('\0'); + + // Send Query message + if (send_data(sock, msg.data(), msg.size()) == false) { + fprintf(stderr, "Error: failed to send query in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + std::unique_ptr result(new MyPGresult); + + while (true) { + char msg_type; + if (recv_data(sock, reinterpret_cast(&msg_type), 1) == false) { + fprintf(stderr, "Error: failed to receive message in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + uint32_t msg_length; + if (recv_data(sock, reinterpret_cast(&msg_length), 4) == false) { + fprintf(stderr, "Error: failed to receive message length in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + msg_length = ntohl(msg_length); + + std::vector payload(msg_length - 4); + if (msg_length > 4) { + if (recv_data(sock, payload.data(), msg_length - 4) == false) { + fprintf(stderr, "Error: failed to receive message payload in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + } + + switch (msg_type) { + case 'T': { // Row Description + const char* data = payload.data(); + uint16_t num_columns; + memcpy(&num_columns, data, 2); + num_columns = ntohs(num_columns); + + size_t pos = 2; + result->columns.reserve(num_columns); + for (int i = 0; i < num_columns; i++) { + std::string col_name(data + pos); + pos += col_name.length() + 1 + 22; // Skip field metadata + result->columns.push_back(col_name); + } + break; + } + + case 'D': { // Data Row + const char* data = payload.data(); + uint16_t num_fields; + memcpy(&num_fields, data, 2); + num_fields = ntohs(num_fields); + + std::vector row; + size_t pos = 2; + for (int i = 0; i < num_fields; i++) { + int32_t field_len; + memcpy(&field_len, data + pos, 4); + pos += 4; + field_len = ntohl(field_len); + + if (field_len == -1) { + row.push_back("NULL"); + } + else { + row.emplace_back(data + pos, field_len); + pos += field_len; + } + } + result->rows.push_back(row); + break; + } + + case 'C': // Command Complete + result->status = std::string(payload.data(), payload.size()); + break; + + case 'E': // Error + result->error = parse_error(payload); + // Consume remaining messages + while (msg_type != 'Z') { + if (recv_data(sock, &msg_type, 1) == false) { + fprintf(stderr, "Error: failed to receive message in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + if (recv_data(sock, reinterpret_cast(&msg_length), 4) == false) { + fprintf(stderr, "Error: failed to receive message length in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + msg_length = ntohl(msg_length); + if (msg_length > 4) { + std::vector temp(msg_length - 4); + if (recv_data(sock, temp.data(), msg_length - 4) == false) { + fprintf(stderr, "Error: failed to receive message payload in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + } + } + return std::move(result); + + case 'Z': // Ready for Query + return std::move(result); + + default: + break; + } + } + return nullptr; +} + +const char* escape_string_backslash_spaces(const char* input) { + const char* c; + int input_len = 0; + int escape_count = 0; + + for (c = input; *c != '\0'; c++) { + if ((*c == ' ')) { + escape_count += 3; + } + else if ((*c == '\\')) { + escape_count += 2; + } + input_len++; + } + + if (escape_count == 0) + return input; + + char* output = (char*)malloc(input_len + escape_count + 1); + char* p = output; + + for (c = input; *c != '\0'; c++) { + if ((*c == ' ')) { + memcpy(p, "\\\\", 2); + p += 2; + } + else if (*c == '\\') { + *(p++) = '\\'; + } + *(p++) = *c; + } + *(p++) = '\0'; + return output; +} + +bool test_parameters(PGconn* admin_conn, const parameter_test& test) { + char buffer[512]; + bool ret = false; + for (const auto& parameter : test.set_admin_vars) { + snprintf(buffer, sizeof(buffer), "SET %s='%s'", parameter.name.c_str(), parameter.value.c_str()); + + if (executeQueries(admin_conn, { buffer, "LOAD PGSQL VARIABLES TO RUNTIME" }) == false) { + BAIL_OUT("Error: failed to set admin variable in file %s, line %d", __FILE__, __LINE__); + return false; + } + } + + int sock = connect_server(cl.pgsql_host, cl.pgsql_port); + if (sock == -1) { + BAIL_OUT("Error: failed to connect to the server in file %s, line %d", __FILE__, __LINE__); + return false; + } + + // Build startup message + std::vector parameters = { + { "user", cl.pgsql_username }, + }; + + parameters.reserve(test.conn_params.size() + test.conn_options.size() + 1); + + for (const auto& param : test.conn_params) { + if (param.value.empty() == true) continue; + parameters.push_back(param); + } + + if (test.conn_options.empty() == false) { + std::string options_value; + + for (size_t i = 0; i < test.conn_options.size(); i++) { + + options_value += " -c " + test.conn_params[i].name; + options_value += "="; + + const char* value = test.conn_options[i].c_str(); + const char* escaped_value = escape_string_backslash_spaces(value); + options_value += escaped_value; + + if (value != escaped_value) + free((void*)escaped_value); + } + + parameters.push_back({ "options", options_value }); + } + + // print parameters + for (const auto& param : parameters) { + diag("Parameter: %s = %s\n", param.name.c_str(), param.value.c_str()); + } + + std::vector startup_msg = build_startup_message(parameters); + + // Send StartupMessage + if (send_data(sock, startup_msg.data(), startup_msg.size()) == false) { + BAIL_OUT("Error: failed to send startup message in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + + // Receive AuthenticationRequest + if (handle_cleartext_auth(sock, cl.pgsql_password) == false) { + if (test.expect_failure) { + ok(true, "Authentication should fail"); + ret = true; + } else { + diag("Error: failed to handle cleartext authentication in file %s, line %d", __FILE__, __LINE__); + } + goto cleanup; + } + + for (size_t i = 0; i < test.set_commands.size(); i++) { + snprintf(buffer, sizeof(buffer), "SET %s='%s'", test.conn_params[i].name.c_str(), test.set_commands[i].c_str()); + diag("Executing: %s\n", buffer); + auto result = execute_query(sock, buffer); + + if (result == nullptr) { + BAIL_OUT("Error: failed to execute query in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + + if (result->error.empty() == false) { + ok(test.expect_failure, "Query '%s' should fail. %s", buffer, result->error.c_str()); + } + } + + if (test.reset_after) { + for (const auto& param : test.conn_params) { + std::string reset_cmd = "RESET " + param.name; + diag("Executing: %s\n", reset_cmd.c_str()); + auto result = execute_query(sock, reset_cmd); + if (result == nullptr) { + BAIL_OUT("Error: failed to reset parameter in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + if (result->error.empty() == false) { + ok(test.expect_failure, "Query '%s' should fail. %s", reset_cmd.c_str(), result->error.c_str()); + } + } + } + + for (int i = 0; i < test.conn_params.size(); i++) { + const auto& param = test.conn_params[i]; + std::string show_cmd = "SHOW " + param.name; + diag("Executing: %s\n", show_cmd.c_str()); + auto result = execute_query(sock, show_cmd); + if (result == nullptr) { + BAIL_OUT("Error: failed to execute query in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + if (test.expect_failure == false && result->error.empty()) { + ok(result->rows.size() == 1, "Number of rows should be 1"); + ok(result->rows[0][0] == test.expected[i], "Parameter '%s' value should be '%s'. Actual: '%s'", + param.name.c_str(), test.expected[i].c_str(), result->rows[0][0].c_str()); + } else { + ok(test.expect_failure, "Query '%s' should fail. %s", show_cmd.c_str(), result->error.c_str()); + } + } + ret = true; +cleanup: + close(sock); + + return ret; +} + +std::vector test_cases = { + // check if connection parameters validation is working correctly + { {}, + {{"sslmode", "test"}}, + {}, + {}, + {""}, + false, + true + }, + // check if session parameter validation is working correctly + { {}, + {{"extra_float_digits", "20"}}, + {}, + {}, + {""}, + false, + true + }, + // check if options parameters validation is working correctly + { {}, + {{"extra_float_digits", "1"}}, + {"19"}, + {}, + {""}, + false, + true + }, + { {}, + {{"ENABLE_HASHJOIN", "off"}, {"enable_seqscan", "on"}}, + {"on", "off"}, + {}, + {"off", "on"}, + false, + false + }, + { + {}, + {{"extra_float_digits", "1"}}, + {"2"}, + {}, + {"1"}, + false, + false + }, + { + {}, + {{"enable_hashjoin", "off"}, {"enable_seqscan", "on"}}, + {"on", "off"}, + {"on", "off"}, + {"off", "on"}, + true, + false + }, + { + {{"pgsql-default_datestyle", "ISO, MDY"}}, + {{"datestyle", ""}}, + {}, + {"Postgres"}, + {"Postgres, MDY"}, + false, // Reset both + false + }, + { + {{"pgsql-default_datestyle", "ISO, MDY"}}, + {{"datestyle", ""}}, + {}, + {"Postgres"}, + {"ISO, MDY"}, + true, // Reset both + false + }, + { + {}, + {{"escape_string_warning", "on"}, {"standard_conforming_strings", "on"}}, + {}, + {"off", "off"}, + {"off", "off"}, + false, + false + }, + { + {}, + {{"client_encoding", "UTF8"}}, + {"LATIN1"}, + {}, + {"UTF8"}, + false, + false + }, + { + {}, + {{"client_encoding", "UTF8"}}, + {"LATIN1"}, + {"LATIN1"}, + {"LATIN1"}, + false, + false + }, + { + {{"pgsql-default_client_encoding", "utf8"}}, + {{"client_encoding", "UTF8"}}, + {"LATIN1"}, + {"LATIN1"}, + {"UTF8"}, + true, + false + }, + { + {}, + {{"invalid_param", "invalid"}}, + {}, + {}, + {"invalid"}, + false, + true + } +}; + + +int main(int argc, char** argv) { + + int test_count = 0; + + for (const auto& test_case : test_cases) { + + if (test_case.expect_failure) { + int case_count = 1; + + if (test_case.set_commands.empty() == false) + case_count++; + if (test_case.reset_after) + case_count++; + + test_count += test_case.conn_params.size() * case_count; + } else + test_count += test_case.conn_params.size() * 2; + } + + plan(test_count); + + if (cl.getEnv()) + return exit_status(); + + auto admin_conn = createNewConnection(ConnType::ADMIN, "", false); + + if (!admin_conn || PQstatus(admin_conn.get()) != CONNECTION_OK) { + BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + + if (executeQueries(admin_conn.get(), { "SET pgsql-authentication_method=1", + "LOAD PGSQL VARIABLES TO RUNTIME" }) == false) { + BAIL_OUT("Error: failed to set pgsql-authentication_method=1 in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + + for (const auto& test_case : test_cases) { + + if (test_parameters(admin_conn.get(), test_case) == false) { + BAIL_OUT("Error: failed to test parameters in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + } + + return exit_status(); +} From c2f2ae5aa4c85e6a72daa3700f7bc72724beda2e Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 2 Apr 2025 01:40:48 +0500 Subject: [PATCH 5/7] * Updated code to use pgsql_tracked_variables[idx] consistently, removing direct usage of pgsql_tracked_variables[idx].idx. This change improves consistency and reduces potential human error. * Fixed some warning --- include/proxysql_structs.h | 5 ++--- lib/PgSQL_Protocol.cpp | 6 +++--- lib/PgSQL_Session.cpp | 18 +++++++++--------- lib/PgSQL_Variables.cpp | 37 +++++++++++++++++++------------------ 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 84a950768..4a6f8c810 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -244,7 +244,7 @@ enum mysql_variable_name { }; /* NOTE: - make special ATTENTION that the order in mysql_variable_name + make special ATTENTION that the order in pgsql_variable_name and pgsql_tracked_variables[] is THE SAME */ enum pgsql_variable_name { @@ -337,8 +337,6 @@ struct pgsql_variable_st { enum session_status status; // what status should be changed after setting this variables const char* set_variable_name; // what variable name (or string) will be used when setting it to backend const char* internal_variable_name; // variable name as displayed in admin , WITHOUT "default_" - // Also used in INTERNAL SESSION - // if NULL , MySQL_Variables::MySQL_Variables will set it to set_variable_name during initialization const char* default_value; // default value uint8_t options; // options const pgsql_variable_validator* validator; // validate value @@ -1814,6 +1812,7 @@ pgsql_variable_st pgsql_tracked_variables[]{ { PGSQL_MAINTENANCE_WORK_MEM, SETTING_VARIABLE, "maintenance_work_mem", "maintenance_work_mem", "64MB", (PGTRACKED_VAR_OPT_QUOTE), &pgsql_variable_validator_maintenance_work_mem, nullptr }, { PGSQL_SYNCHRONOUS_COMMIT, SETTING_VARIABLE, "synchronous_commit", "synchronous_commit", "on", (PGTRACKED_VAR_OPT_QUOTE), &pgsql_variable_validator_synchronous_commit, nullptr}, }; + #endif //EXCLUDE_TRACKING_VARAIABLES #else diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 13ed7a3e6..3dd2b79cf 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -403,7 +403,7 @@ bool PgSQL_Protocol::generate_pkt_initial_handshake(bool send, void** _ptr, unsi if (RAND_bytes((*myds)->tmp_login_salt, sizeof((*myds)->tmp_login_salt)) != 1) { // Fallback method: using a basic pseudo-random generator srand((unsigned int)time(NULL)); - for (int i = 0; i < sizeof((*myds)->tmp_login_salt); i++) { + for (size_t i = 0; i < sizeof((*myds)->tmp_login_salt); i++) { (*myds)->tmp_login_salt[i] = rand() % 256; } } @@ -632,11 +632,11 @@ unsigned int get_string(const char* data, unsigned int len, const char** dst_p) bool PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt) { - int32_t offset = 0; + uint32_t offset = 0; while (offset < pkt->data.size) { char* nameptr = (char*)pkt->data.ptr + offset; - int32_t valoffset; + uint32_t valoffset; char* valptr; if (*nameptr == '\0') diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 42ef7896a..e3c429972 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -4201,7 +4201,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C int idx = PGSQL_NAME_LAST_HIGH_WM; for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { if (variable_name_exists(pgsql_tracked_variables[i], var.c_str()) == true) { - idx = pgsql_tracked_variables[i].idx; + idx = i; break; } } @@ -4252,8 +4252,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", var.c_str(), value1.c_str()); uint32_t var_hash_int = SpookyHash::Hash32(value1.c_str(), value1.length(), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value1.c_str(), true)) { + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value1.c_str(), true)) { return false; } if (idx == PGSQL_DATESTYLE) { @@ -4631,8 +4631,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, false)) { + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value, false)) { return false; } if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { @@ -4647,8 +4647,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C const char* value = get_default_session_variable((enum pgsql_variable_name)idx); proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, false)) { + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value, false)) { return false; } if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { @@ -4683,8 +4683,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C const char* value = get_default_session_variable((enum pgsql_variable_name)idx); uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value, true)) { + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value, true)) { return false; } if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { diff --git a/lib/PgSQL_Variables.cpp b/lib/PgSQL_Variables.cpp index af1f5ba6a..7285436c6 100644 --- a/lib/PgSQL_Variables.cpp +++ b/lib/PgSQL_Variables.cpp @@ -19,32 +19,32 @@ PgSQL_Variables::PgSQL_Variables() { // NOTE: This variable has been temporarily ignored. Check issues #3442 and #3441. //ignore_vars.push_back("session_track_schema"); variables_regexp = ""; + + /* + NOTE: + make special ATTENTION that the order in pgsql_variable_name + and pgsqll_tracked_variables[] is THE SAME + NOTE: + PgSQL_Variables::PgSQL_Variables() has a built-in check to make sure that the order is correct, + and that variables are in alphabetical order + */ for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { + //Array index and enum value (idx) should be same + assert(i == pgsql_tracked_variables[i].idx); + + if (i > PGSQL_NAME_LAST_LOW_WM + 1) { + assert(strcmp(pgsql_tracked_variables[i].set_variable_name, pgsql_tracked_variables[i - 1].set_variable_name) > 0); + } + // we initialized all the internal_variable_name if set to NULL if (pgsql_tracked_variables[i].internal_variable_name == NULL) { pgsql_tracked_variables[i].internal_variable_name = pgsql_tracked_variables[i].set_variable_name; } - } -/* - NOTE: - make special ATTENTION that the order in pgsql_variable_name - and pgsqll_tracked_variables[] is THE SAME - NOTE: - PgSQL_Variables::PgSQL_Variables() has a built-in check to make sure that the order is correct, - and that variables are in alphabetical order -*/ - for (int i = PGSQL_NAME_LAST_LOW_WM; i < PGSQL_NAME_LAST_HIGH_WM; i++) { - assert(i == pgsql_tracked_variables[i].idx); - if (i > PGSQL_NAME_LAST_LOW_WM+1) { - assert(strcmp(pgsql_tracked_variables[i].set_variable_name, pgsql_tracked_variables[i-1].set_variable_name) > 0); - } - } - for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { PgSQL_Variables::verifiers[i] = verify_server_variable; PgSQL_Variables::updaters[i] = update_server_variable; - - if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) { + + if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) { variables_regexp += pgsql_tracked_variables[i].set_variable_name; variables_regexp += "|"; @@ -56,6 +56,7 @@ PgSQL_Variables::PgSQL_Variables() { } } } + for (std::vector::iterator it=ignore_vars.begin(); it != ignore_vars.end(); it++) { variables_regexp += *it; variables_regexp += "|"; From 8a5dfc2ecbe613fa554281ed5c02111f17f97141 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Thu, 10 Apr 2025 11:57:37 +0500 Subject: [PATCH 6/7] Added comment --- include/PgSQL_Connection.h | 99 +++++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 43 deletions(-) diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index e2271c4b3..470278628 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -152,53 +152,66 @@ static const std::unordered_map #define PG_EVENT_EXCEPT 0x04 #define PG_EVENT_TIMEOUT 0x08 + +/** + * @class PgSQL_Conn_Param + * @brief Stores PostgreSQL connection parameters sent by client. + * + * This class stores key-value pairs representing connection parameters + * for PostgreSQL connection. + * + */ class PgSQL_Conn_Param { public: - PgSQL_Conn_Param() {} - ~PgSQL_Conn_Param() {} - - bool set_value(const char* key, const char* val) { - if (key == nullptr || val == nullptr) return false; - connection_parameters[key] = val; - return true; - } - - inline - const char* get_value(PgSQL_Param_Name key) const { - return get_value(param_name[key]); - } - - const char* get_value(const char* key) const { - auto it = connection_parameters.find(key); - if (it != connection_parameters.end()) { - return it->second.c_str(); - } - return nullptr; - } - - bool remove_value(const char* key) { - auto it = connection_parameters.find(key); - if (it != connection_parameters.end()) { - connection_parameters.erase(it); - return true; - } - return false; - } - - inline - bool is_empty() const { - return connection_parameters.empty(); - } - - inline - void clear() { - connection_parameters.clear(); - } + PgSQL_Conn_Param() {} + ~PgSQL_Conn_Param() {} + + bool set_value(const char* key, const char* val) { + if (key == nullptr || val == nullptr) return false; + connection_parameters[key] = val; + return true; + } + + inline + const char* get_value(PgSQL_Param_Name key) const { + return get_value(param_name[key]); + } + + const char* get_value(const char* key) const { + auto it = connection_parameters.find(key); + if (it != connection_parameters.end()) { + return it->second.c_str(); + } + return nullptr; + } + + bool remove_value(const char* key) { + auto it = connection_parameters.find(key); + if (it != connection_parameters.end()) { + connection_parameters.erase(it); + return true; + } + return false; + } + + inline + bool is_empty() const { + return connection_parameters.empty(); + } + + inline + void clear() { + connection_parameters.clear(); + } private: - std::map connection_parameters; - friend class PgSQL_Session; - friend class PgSQL_Protocol; + /** + * @brief Stores the connection parameters as key-value pairs. + */ + std::map connection_parameters; + + friend class PgSQL_Session; + friend class PgSQL_Protocol; }; class PgSQL_Variable { From 982be8b08a31c0da23619631b3d47d2b68ad6d2d Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Thu, 10 Apr 2025 11:58:04 +0500 Subject: [PATCH 7/7] Code cleanup --- lib/PgSQL_HostGroups_Manager.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/PgSQL_HostGroups_Manager.cpp b/lib/PgSQL_HostGroups_Manager.cpp index bd892e3e6..64b6f835a 100644 --- a/lib/PgSQL_HostGroups_Manager.cpp +++ b/lib/PgSQL_HostGroups_Manager.cpp @@ -1789,7 +1789,6 @@ void PgSQL_HostGroups_Manager::push_MyConn_to_pool(PgSQL_Connection *c, bool _lo mysrvc->ConnectionsUsed->add(c); // Add the connection back to the list of used connections destroy_MyConn_from_pool(c, false); // Destroy the connection from the pool } else {*/ - //c->optimize(); mysrvc->ConnectionsFree->add(c); //} } else {