diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index a8048c05b..470278628 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -38,7 +38,6 @@ enum PgSQL_Param_Name { PG_REQUIRE_AUTH, // Specifies the authentication method that the client requires from the server PG_CHANNEL_BINDING, // Controls the client's use of channel binding PG_CONNECT_TIMEOUT, // Maximum time to wait while connecting, in seconds - PG_CLIENT_ENCODING, // Sets the client_encoding configuration parameter for this connection PG_OPTIONS, // Specifies command-line options to send to the server at connection start PG_APPLICATION_NAME, // Specifies a value for the application_name configuration parameter PG_FALLBACK_APPLICATION_NAME, // Specifies a fallback value for the application_name configuration parameter @@ -70,120 +69,81 @@ enum PgSQL_Param_Name { PG_TARGET_SESSION_ATTRS, // Determines whether the session must have certain properties to be acceptable PG_LOAD_BALANCE_HOSTS, // Controls the order in which the client tries to connect to the available hosts and addresses // Environment Options - PG_DATESTYLE, // Sets the value of the DateStyle parameter - PG_TIMEZONE, // Sets the value of the TimeZone parameter - PG_GEQO, // Enables or disables the use of the GEQO query optimizer + //PG_DATESTYLE, // Sets the value of the DateStyle parameter + //PG_TIMEZONE, // Sets the value of the TimeZone parameter + //PG_GEQO, // Enables or disables the use of the GEQO query optimizer PG_PARAM_SIZE }; -static const char* PgSQL_Param_Name_Str[] = { - "host", - "hostaddr", - "port", - "database", - "user", - "password", - "passfile", - "require_auth", - "channel_binding", - "connect_timeout", - "client_encoding", - "options", - "application_name", - "fallback_application_name", - "keepalives", - "keepalives_idle", - "keepalives_interval", - "keepalives_count", - "tcp_user_timeout", - "replication", - "gsseencmode", - "sslmode", - "requiressl", - "sslcompression", - "sslcert", - "sslkey", - "sslpassword", - "sslcertmode", - "sslrootcert", - "sslcrl", - "sslcrldir", - "sslsni", - "requirepeer", - "ssl_min_protocol_version", - "ssl_max_protocol_version", - "krbsrvname", - "gsslib", - "gssdelegation", - "service", - "target_session_attrs", - "load_balance_hosts", - // Environment Options - "datestyle", - "timezone", - "geqo" -}; - struct Param_Name_Validation { const char** accepted_values; int default_value_idx; +}; +static const Param_Name_Validation require_auth{ (const char* []) { "password","md5","gss","sspi","scram-sha-256","none",nullptr },-1 }; +static const Param_Name_Validation replication{ (const char* []) { "true","on","yes","1","database","false","off","no","0",nullptr },-1 }; +static const Param_Name_Validation gsseencmode{ (const char* []) { "disable","prefer","require",nullptr },1 }; +static const Param_Name_Validation sslmode{ (const char* []) { "disable","allow","prefer","require","verify-ca","verify-full",nullptr },2 }; +static const Param_Name_Validation sslcertmode{ (const char* []) { "disable","allow","require",nullptr },1 }; +static const Param_Name_Validation target_session_attrs{ (const char* []) { "any","read-write","read-only","primary","standby","prefer-standby",nullptr },0 }; +static const Param_Name_Validation load_balance_hosts{ (const char* []) { "disable","random",nullptr },-1 }; + +// Excluding client_encoding since it is managed as part of the session variables +#define PARAMETER_LIST \ + PARAM("host", nullptr) \ + PARAM("hostaddr", nullptr) \ + PARAM("port", nullptr) \ + PARAM("database", nullptr) \ + PARAM("user", nullptr) \ + PARAM("password", nullptr) \ + PARAM("passfile", nullptr) \ + PARAM("require_auth", &require_auth) \ + PARAM("channel_binding", nullptr) \ + PARAM("connect_timeout", nullptr) \ + PARAM("options", nullptr) \ + PARAM("application_name", nullptr) \ + PARAM("fallback_application_name", nullptr) \ + PARAM("keepalives", nullptr) \ + PARAM("keepalives_idle", nullptr) \ + PARAM("keepalives_interval", nullptr) \ + PARAM("keepalives_count", nullptr) \ + PARAM("tcp_user_timeout", nullptr) \ + PARAM("replication", &replication) \ + PARAM("gsseencmode", &gsseencmode) \ + PARAM("sslmode", &sslmode) \ + PARAM("requiressl", nullptr) \ + PARAM("sslcompression", nullptr) \ + PARAM("sslcert", nullptr) \ + PARAM("sslkey", nullptr) \ + PARAM("sslpassword", nullptr) \ + PARAM("sslcertmode", &sslcertmode) \ + PARAM("sslrootcert", nullptr) \ + PARAM("sslcrl", nullptr) \ + PARAM("sslcrldir", nullptr) \ + PARAM("sslsni", nullptr) \ + PARAM("requirepeer", nullptr) \ + PARAM("ssl_min_protocol_version", nullptr) \ + PARAM("ssl_max_protocol_version", nullptr) \ + PARAM("krbsrvname", nullptr) \ + PARAM("gsslib", nullptr) \ + PARAM("gssdelegation", nullptr) \ + PARAM("service", nullptr) \ + PARAM("target_session_attrs", &target_session_attrs) \ + PARAM("load_balance_hosts", &load_balance_hosts) + +// Generate parameter array +constexpr const char* param_name[] = { +#define PARAM(name, val) name, + PARAMETER_LIST +#undef PARAM }; -static const Param_Name_Validation require_auth {(const char*[]){"password","md5","gss","sspi","scram-sha-256","none",nullptr},-1}; -static const Param_Name_Validation replication {(const char*[]){"true","on","yes","1","database","false","off","no","0",nullptr},-1}; -static const Param_Name_Validation gsseencmode {(const char*[]){"disable","prefer","require",nullptr},1}; -static const Param_Name_Validation sslmode {(const char*[]){"disable","allow","prefer","require","verify-ca","verify-full",nullptr},2}; -static const Param_Name_Validation sslcertmode {(const char*[]){"disable","allow","require",nullptr},1}; -static const Param_Name_Validation target_session_attrs {(const char*[]){"any","read-write","read-only","primary","standby","prefer-standby",nullptr},0 }; -static const Param_Name_Validation load_balance_hosts {(const char*[]){"disable","random",nullptr},-1}; - -static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SIZE] = { - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &require_auth, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &replication, - &gsseencmode, - &sslmode, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &sslcertmode, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - &target_session_attrs, - &load_balance_hosts, - nullptr, - nullptr, - nullptr +// make sure all the keys are in lower case +static const std::unordered_map param_name_map = { +#define PARAM(name, val) {name, val}, + PARAMETER_LIST +#undef PARAM }; #define PG_EVENT_NONE 0x00 @@ -192,83 +152,66 @@ static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SI #define PG_EVENT_EXCEPT 0x04 #define PG_EVENT_TIMEOUT 0x08 -class PgSQL_Conn_Param { -private: - bool validate(PgSQL_Param_Name key, const char* val) { - assert(val); - const Param_Name_Validation* validation = PgSQL_Param_Name_Accepted_Values[key]; - - if (validation != nullptr && validation->accepted_values) { - const char** accepted_value = validation->accepted_values; - while (accepted_value != nullptr) { - if (strcmp(val, *accepted_value) == 0) { - return true; - } - } - } else { - return true; - } - - return false; - } +/** + * @class PgSQL_Conn_Param + * @brief Stores PostgreSQL connection parameters sent by client. + * + * This class stores key-value pairs representing connection parameters + * for PostgreSQL connection. + * + */ +class PgSQL_Conn_Param { public: - PgSQL_Conn_Param() {} - ~PgSQL_Conn_Param() { - for (int i = 0; i < PG_PARAM_SIZE; i++) { - if (param_value[i]) - free(param_value[i]); - } - } - - bool set_value(PgSQL_Param_Name key, const char* val) { - if (key == -1) return false; - if (validate(key, val)) { - if (param_value[key]) { - free(param_value[key]); - } - param_value[key] = strdup(val); - param_set.push_back(key); - return true; - } - return false; - } - - bool set_value(const char* key, const char* val) { - return set_value((PgSQL_Param_Name)get_param_name(key), val); - } - - void reset_value(PgSQL_Param_Name key) { - if (param_value[key]) { - free(param_value[key]); - } - param_value[key] = nullptr; - - // this has O(n) complexity. need to fix it.... - param_set.erase(param_set.begin() + static_cast(key)); - } + PgSQL_Conn_Param() {} + ~PgSQL_Conn_Param() {} + + bool set_value(const char* key, const char* val) { + if (key == nullptr || val == nullptr) return false; + connection_parameters[key] = val; + return true; + } + + inline + const char* get_value(PgSQL_Param_Name key) const { + return get_value(param_name[key]); + } + + const char* get_value(const char* key) const { + auto it = connection_parameters.find(key); + if (it != connection_parameters.end()) { + return it->second.c_str(); + } + return nullptr; + } + + bool remove_value(const char* key) { + auto it = connection_parameters.find(key); + if (it != connection_parameters.end()) { + connection_parameters.erase(it); + return true; + } + return false; + } + + inline + bool is_empty() const { + return connection_parameters.empty(); + } + + inline + void clear() { + connection_parameters.clear(); + } - const char* get_value(PgSQL_Param_Name key) const { - return param_value[key]; - } - - int get_param_name(const char* name) { - int key = -1; - - for (int i = 0; i < PG_PARAM_SIZE; i++) { - if (strcmp(name, PgSQL_Param_Name_Str[i]) == 0) { - key = i; - break; - } - } - if (key == -1) { - proxy_warning("Unrecognized connection option. Please report this as a bug for future enhancements:%s\n", name); - } - return key; - } +private: + /** + * @brief Stores the connection parameters as key-value pairs. + */ + std::map connection_parameters; - std::vector param_set; - char* param_value[PG_PARAM_SIZE]{}; + friend class PgSQL_Session; + friend class PgSQL_Protocol; }; class PgSQL_Variable { @@ -418,12 +361,9 @@ public: PGresult* get_result(); void next_multi_statement_result(PGresult* result); bool set_single_row_mode(); - void optimize() {} void update_bytes_recv(uint64_t bytes_recv); void update_bytes_sent(uint64_t bytes_sent); void ProcessQueryAndSetStatusFlags(char* query_digest_text); - void set_charset(const char* charset); - void connect_start_SetCharset(const char* client_encoding, uint32_t hash = 0); inline const PGconn* get_pg_connection() const { return pgsql_conn; } inline int get_pg_server_version() { return PQserverVersion(pgsql_conn); } diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index db903c205..a1d2a23c1 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -986,7 +986,7 @@ private: * @note This function iterates through the key-value pairs in the startup * packet and stores them in the connection parameters object. */ - void load_conn_parameters(pgsql_hdr* pkt, bool startup); + bool load_conn_parameters(pgsql_hdr* pkt); /** * @brief Handles the client's first message in a SCRAM-SHA-256 diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index f02c1d4c9..96975f608 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -283,9 +283,6 @@ private: bool handler_again___status_SETTING_LDAP_USER_VARIABLE(int*); bool handler_again___status_SETTING_SQL_MODE(int*); bool handler_again___status_SETTING_SESSION_TRACK_GTIDS(int*); -#endif // 0 - bool handler_again___status_CHANGING_CHARSET(int* _rc); -#if 0 bool handler_again___status_CHANGING_SCHEMA(int*); #endif // 0 bool handler_again___status_CONNECTING_SERVER(int*); @@ -414,12 +411,6 @@ public: int current_hostgroup; int default_hostgroup; int previous_hostgroup; - /** - * @brief Charset directly specified by the client. Supplied and updated via 'HandshakeResponse' - * and 'COM_CHANGE_USER' packets. - * @details Used when session needs to be restored via 'COM_RESET_CONNECTION'. - */ - int default_charset; int locked_on_hostgroup; int next_query_flagIN; int mirror_hostgroup; diff --git a/include/PgSQL_Set_Stmt_Parser.h b/include/PgSQL_Set_Stmt_Parser.h index c052184b8..ee3732fa2 100644 --- a/include/PgSQL_Set_Stmt_Parser.h +++ b/include/PgSQL_Set_Stmt_Parser.h @@ -37,7 +37,9 @@ class PgSQL_Set_Stmt_Parser { void generateRE_parse1v2(); // First implemenation of the parser for TRANSACTION ISOLATION LEVEL and TRANSACTION READ/WRITE std::map> parse2(); +#if 0 std::string parse_character_set(); +#endif std::string remove_comments(const std::string& q); }; diff --git a/include/PgSQL_Variables.h b/include/PgSQL_Variables.h index 24c1cbc07..7e2daec8c 100644 --- a/include/PgSQL_Variables.h +++ b/include/PgSQL_Variables.h @@ -14,10 +14,8 @@ extern void print_backtrace(void); typedef bool (*pgsql_verify_var)(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash); typedef bool (*pgsql_update_var)(PgSQL_Session* session, int idx, int &_rc); -bool validate_charset(PgSQL_Session* session, int idx, int &_rc); bool update_server_variable(PgSQL_Session* session, int idx, int &_rc); bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash); -bool verify_set_names(PgSQL_Session* session); class PgSQL_Variables { static pgsql_verify_var verifiers[PGSQL_NAME_LAST_HIGH_WM]; @@ -31,13 +29,13 @@ public: PgSQL_Variables(); ~PgSQL_Variables(); - bool client_set_value(PgSQL_Session* session, int idx, const std::string& value); + bool client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx); bool client_set_hash_and_value(PgSQL_Session* session, int idx, const std::string& value, uint32_t hash); void client_reset_value(PgSQL_Session* session, int idx); const char* client_get_value(PgSQL_Session* session, int idx) const; uint32_t client_get_hash(PgSQL_Session* session, int idx) const; - void server_set_value(PgSQL_Session* session, int idx, const char* value); + void server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx); void server_set_hash_and_value(PgSQL_Session* session, int idx, const char* value, uint32_t hash); void server_reset_value(PgSQL_Session* session, int idx); const char* server_get_value(PgSQL_Session* session, int idx) const; diff --git a/include/PgSQL_Variables_Validator.h b/include/PgSQL_Variables_Validator.h index ae5469847..e6d1183e5 100644 --- a/include/PgSQL_Variables_Validator.h +++ b/include/PgSQL_Variables_Validator.h @@ -9,7 +9,8 @@ typedef enum { VARIABLE_TYPE_BOOL, /**< Boolean variable type. */ VARIABLE_TYPE_STRING, /**< String variable type. */ VARIABLE_TYPE_DATESTYLE, /**< DateStyle variable type. */ - VARIABLE_TYPE_MAINTENANCE_WORK_MEM + VARIABLE_TYPE_MAINTENANCE_WORK_MEM, /**< MaintenanceWorkMem variable type. */ + VARIABLE_TYPE_CLIENT_ENCODING /**< ClientEncoding variable type. */ } pgsql_variable_type_t; @@ -61,5 +62,5 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag extern const pgsql_variable_validator pgsql_variable_validator_bytea_output; extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits; extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem; - +extern const pgsql_variable_validator pgsql_variable_validator_client_encoding; #endif // PGSQL_VARIABLES_VALIDATOR_H diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 2450e8a82..4bce4d057 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -244,7 +244,7 @@ enum mysql_variable_name { }; /* NOTE: - make special ATTENTION that the order in mysql_variable_name + make special ATTENTION that the order in pgsql_variable_name and pgsql_tracked_variables[] is THE SAME */ enum pgsql_variable_name { @@ -337,8 +337,6 @@ struct pgsql_variable_st { enum session_status status; // what status should be changed after setting this variables const char* set_variable_name; // what variable name (or string) will be used when setting it to backend const char* internal_variable_name; // variable name as displayed in admin , WITHOUT "default_" - // Also used in INTERNAL SESSION - // if NULL , MySQL_Variables::MySQL_Variables will set it to set_variable_name during initialization const char* default_value; // default value uint8_t options; // options const pgsql_variable_validator* validator; // validate value @@ -1797,9 +1795,10 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag extern const pgsql_variable_validator pgsql_variable_validator_bytea_output; extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits; extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem; +extern const pgsql_variable_validator pgsql_variable_validator_client_encoding; pgsql_variable_st pgsql_tracked_variables[]{ - { PGSQL_CLIENT_ENCODING, SETTING_CHARSET, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_SET_TRANSACTION | PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), nullptr, { "names", nullptr } }, + { PGSQL_CLIENT_ENCODING, SETTING_VARIABLE, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_client_encoding, { "names", nullptr } }, { PGSQL_DATESTYLE, SETTING_VARIABLE, "datestyle", "datestyle", "ISO, MDY" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_datestyle, nullptr }, { PGSQL_INTERVALSTYLE, SETTING_VARIABLE, "intervalstyle", "intervalstyle", "postgres" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_intervalstyle, nullptr }, { PGSQL_STANDARD_CONFORMING_STRINGS, SETTING_VARIABLE, "standard_conforming_strings", "standard_conforming_strings", "on", (PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_bool, nullptr }, @@ -1819,6 +1818,7 @@ pgsql_variable_st pgsql_tracked_variables[]{ { PGSQL_MAINTENANCE_WORK_MEM, SETTING_VARIABLE, "maintenance_work_mem", "maintenance_work_mem", "64MB", (PGTRACKED_VAR_OPT_QUOTE), &pgsql_variable_validator_maintenance_work_mem, nullptr }, { PGSQL_SYNCHRONOUS_COMMIT, SETTING_VARIABLE, "synchronous_commit", "synchronous_commit", "on", (PGTRACKED_VAR_OPT_QUOTE), &pgsql_variable_validator_synchronous_commit, nullptr}, }; + #endif //EXCLUDE_TRACKING_VARAIABLES #else diff --git a/lib/PgSQL_Connection.cpp b/lib/PgSQL_Connection.cpp index 2cd818309..68baf89da 100644 --- a/lib/PgSQL_Connection.cpp +++ b/lib/PgSQL_Connection.cpp @@ -14,103 +14,6 @@ using json = nlohmann::json; #include "PgSQL_Query_Processor.h" #include "MySQL_Variables.h" - -#if 0 -// some of the code that follows is from mariadb client library memory allocator -typedef int myf; // Type of MyFlags in my_funcs -#define MYF(v) (myf) (v) -#define MY_KEEP_PREALLOC 1 -#define MY_ALIGN(A,L) (((A) + (L) - 1) & ~((L) - 1)) -#define ALIGN_SIZE(A) MY_ALIGN((A),sizeof(double)) -static void ma_free_root(MA_MEM_ROOT *root, myf MyFLAGS); -static void *ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size); -#define MAX(a,b) (((a) > (b)) ? (a) : (b)) - - -static void * ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size) -{ - size_t get_size; - void * point; - MA_USED_MEM *next= 0; - MA_USED_MEM **prev; - - Size= ALIGN_SIZE(Size); - - if ((*(prev= &mem_root->free))) - { - if ((*prev)->left < Size && - mem_root->first_block_usage++ >= 16 && - (*prev)->left < 4096) - { - next= *prev; - *prev= next->next; - next->next= mem_root->used; - mem_root->used= next; - mem_root->first_block_usage= 0; - } - for (next= *prev; next && next->left < Size; next= next->next) - prev= &next->next; - } - if (! next) - { /* Time to alloc new block */ - get_size= MAX(Size+ALIGN_SIZE(sizeof(MA_USED_MEM)), - (mem_root->block_size & ~1) * ( (mem_root->block_num >> 2) < 4 ? 4 : (mem_root->block_num >> 2) ) ); - - if (!(next = (MA_USED_MEM*) malloc(get_size))) - { - if (mem_root->error_handler) - (*mem_root->error_handler)(); - return((void *) 0); /* purecov: inspected */ - } - mem_root->block_num++; - next->next= *prev; - next->size= get_size; - next->left= get_size-ALIGN_SIZE(sizeof(MA_USED_MEM)); - *prev=next; - } - point= (void *) ((char*) next+ (next->size-next->left)); - if ((next->left-= Size) < mem_root->min_malloc) - { /* Full block */ - *prev=next->next; /* Remove block from list */ - next->next=mem_root->used; - mem_root->used=next; - mem_root->first_block_usage= 0; - } - return(point); -} - - -static void ma_free_root(MA_MEM_ROOT *root, myf MyFlags) -{ - MA_USED_MEM *next,*old; - - if (!root) - return; /* purecov: inspected */ - if (!(MyFlags & MY_KEEP_PREALLOC)) - root->pre_alloc=0; - - for ( next=root->used; next ;) - { - old=next; next= next->next ; - if (old != root->pre_alloc) - free(old); - } - for (next= root->free ; next ; ) - { - old=next; next= next->next ; - if (old != root->pre_alloc) - free(old); - } - root->used=root->free=0; - if (root->pre_alloc) - { - root->free=root->pre_alloc; - root->free->left=root->pre_alloc->size-ALIGN_SIZE(sizeof(MA_USED_MEM)); - root->free->next=0; - } -} -#endif // 0 - extern char * binary_sha1; #include "proxysql_find_charset.h" @@ -841,31 +744,22 @@ void PgSQL_Connection::connect_start() { conninfo << "sslmode='disable' "; // not supporting SSL } - if (myds && myds->sess) { - const char* charset = NULL; - uint32_t charset_hash = 0; - - // Take client character set and use it to connect to backend - charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING); - if (charset_hash != 0) - charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING); - - //if (!charset) { - // charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING]; - //} - + if (myds && myds->sess && myds->sess->client_myds) { // Client Encoding should be always set - assert(charset); - - connect_start_SetCharset(charset, charset_hash); - - escaped_str = escape_string_single_quotes_and_backslashes((char*)charset, false); + const char* client_charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING); + assert(client_charset); + uint32_t client_charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING); + assert(client_charset_hash); + const char* escaped_str = escape_string_backslash_spaces(client_charset); conninfo << "client_encoding='" << escaped_str << "' "; - if (escaped_str != charset) - free(escaped_str); + if (escaped_str != client_charset) + free((char*)escaped_str); + + // charset validation is already done + pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, client_charset, client_charset_hash); std::vector client_options; - client_options.reserve(PGSQL_NAME_LAST_LOW_WM + dynamic_variables_idx.size()); + client_options.reserve(PGSQL_NAME_LAST_LOW_WM + myds->sess->client_myds->myconn->dynamic_variables_idx.size()); // excluding PGSQL_CLIENT_ENCODING for (unsigned int idx = 1; idx < PGSQL_NAME_LAST_LOW_WM; idx++) { @@ -873,9 +767,9 @@ void PgSQL_Connection::connect_start() { client_options.push_back(idx); } - for (std::vector::const_iterator it_c = dynamic_variables_idx.begin(); it_c != dynamic_variables_idx.end(); it_c++) { - assert(pgsql_variables.client_get_hash(myds->sess, *it_c)); - client_options.push_back(*it_c); + for (uint32_t idx : myds->sess->client_myds->myconn->dynamic_variables_idx) { + assert(pgsql_variables.client_get_hash(myds->sess, idx)); + client_options.push_back(idx); } if (client_options.empty() == false || @@ -1882,31 +1776,6 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) { } } -void PgSQL_Connection::set_charset(const char* charset) { - proxy_debug(PROXY_DEBUG_MYSQL_CONNPOOL, 4, "Setting client encoding %s\n", charset); - pgsql_variables.client_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset); -} - -void PgSQL_Connection::connect_start_SetCharset(const char* charset, uint32_t hash) { - assert(charset); - - int charset_encoding = PgSQL_Connection::char_to_encoding(charset); - - if (charset_encoding == -1) { - proxy_error("Cannot find character set [%s]\n", charset); - assert(0); - } - - /* We are connecting to backend setting charset in connection parameters. - * Client already has sent us a character set and client connection variables have been already set. - * Now we store this charset in server connection variables to avoid updating this variables on backend. - */ - if (hash == 0) - pgsql_variables.server_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset); - else - pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, charset, hash); -} - // this function is identical to async_query() , with the only exception that MyRS should never be set int PgSQL_Connection::async_send_simple_command(short event, char* stmt, unsigned long length) { PROXY_TRACE(); diff --git a/lib/PgSQL_HostGroups_Manager.cpp b/lib/PgSQL_HostGroups_Manager.cpp index b3455ccc6..64b6f835a 100644 --- a/lib/PgSQL_HostGroups_Manager.cpp +++ b/lib/PgSQL_HostGroups_Manager.cpp @@ -1789,7 +1789,6 @@ void PgSQL_HostGroups_Manager::push_MyConn_to_pool(PgSQL_Connection *c, bool _lo mysrvc->ConnectionsUsed->add(c); // Add the connection back to the list of used connections destroy_MyConn_from_pool(c, false); // Destroy the connection from the pool } else {*/ - c->optimize(); mysrvc->ConnectionsFree->add(c); //} } else { diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 65084631f..3dd2b79cf 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -403,7 +403,7 @@ bool PgSQL_Protocol::generate_pkt_initial_handshake(bool send, void** _ptr, unsi if (RAND_bytes((*myds)->tmp_login_salt, sizeof((*myds)->tmp_login_salt)) != 1) { // Fallback method: using a basic pseudo-random generator srand((unsigned int)time(NULL)); - for (int i = 0; i < sizeof((*myds)->tmp_login_salt); i++) { + for (size_t i = 0; i < sizeof((*myds)->tmp_login_salt); i++) { (*myds)->tmp_login_salt[i] = rand() % 256; } } @@ -630,26 +630,34 @@ unsigned int get_string(const char* data, unsigned int len, const char** dst_p) return (nul + 1 - data); } -void PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt, bool startup) +bool PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt) { - const char* key, * val; - unsigned int read_pos = 0; - - while (1) { + uint32_t offset = 0; - int pos = get_string(((const char*)pkt->data.ptr) + read_pos, pkt->data.size - read_pos, &key); - if (pos == 0) return; + while (offset < pkt->data.size) { + char* nameptr = (char*)pkt->data.ptr + offset; + uint32_t valoffset; + char* valptr; - read_pos += pos; + if (*nameptr == '\0') + break; /* found packet terminator */ + valoffset = offset + strlen(nameptr) + 1; + if (valoffset >= pkt->data.size) + break; /* missing value, will complain below */ + valptr = (char*)pkt->data.ptr + valoffset; - pos = get_string(((const char*)pkt->data.ptr) + read_pos, pkt->data.size - read_pos, &val); - if (pos == 0) return; + (*myds)->myconn->conn_params.set_value(nameptr, valptr); - read_pos += pos; + offset = valoffset + strlen(valptr) + 1; + } - //slog_debug(server, "S: param: %s = %s", key, val); - (*myds)->myconn->conn_params.set_value(key, val); + if (offset != pkt->data.size - 1) { + proxy_error("Malformed startup packet was received from client %s:%d\n", (*myds)->addr.addr, (*myds)->addr.port); + return false; } + + return true; + } bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len, bool& ssl_request) { @@ -677,13 +685,19 @@ bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len return false; } - load_conn_parameters(&hdr, true); + if (!load_conn_parameters(&hdr)) { + proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. malformed startup packet.\n", (*myds)->sess, (*myds)); + generate_error_packet(true, false, "invalid startup packet layout: expected terminator as last byte", + PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); + return false; + } const unsigned char* user = (unsigned char*)(*myds)->myconn->conn_params.get_value(PG_USER); if (!user || *user == '\0') { - proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. no username supplied.\n", (*myds), (*myds)->sess); - generate_error_packet(true, false, "no username supplied", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); + proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. no username supplied.\n", (*myds)->sess, (*myds)); + generate_error_packet(true, false, "no PostgreSQL user name specified in startup packet", + PGSQL_ERROR_CODES::ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION, true); return false; } @@ -769,6 +783,7 @@ std::vector> PgSQL_Protocol::parse_options(c // Add key-value pair to the list if (!key.empty()) { + std::transform(key.begin(), key.end(), key.begin(), ::tolower); options_list.emplace_back(std::move(key), std::move(value)); } } @@ -808,26 +823,7 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* return EXECUTION_STATE::FAILED; } - // set charset but first verify - const char* charset = (*myds)->myconn->conn_params.get_value(PG_CLIENT_ENCODING); - - // if client does not provide client_encoding, PostgreSQL uses the default client encoding. - // We need to save the default client encoding to send it to the client in case client doesn't provide one. - if (charset == NULL) charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING]; - - assert(charset); - - int charset_encoding = (*myds)->myconn->char_to_encoding(charset); - if (charset_encoding == -1) { - proxy_error("Cannot find charset [%s]\n", charset); - proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "DS=%p , Session=%p , charset='%s'. Client charset not supported.\n", (*myds), (*myds)->sess, charset); - generate_error_packet(true, false, "Client charset not supported", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); - return EXECUTION_STATE::FAILED; - } - - (*myds)->sess->default_charset = charset_encoding; - user = (char*)(*myds)->myconn->conn_params.get_value(PG_USER); if (!user || *user == '\0') { @@ -1058,92 +1054,190 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* userinfo->username = strdup((const char*)user); userinfo->password = strdup((const char*)password); - const char* db = (*myds)->myconn->conn_params.get_value(PG_DATABASE); - userinfo->set_dbname(db ? db : userinfo->username); + std::vector> parameters; + std::vector> options_list; - assert(sess); - assert(sess->client_myds); + parameters.reserve((*myds)->myconn->conn_params.connection_parameters.size()); - PgSQL_Connection* myconn = sess->client_myds->myconn; - assert(myconn); - myconn->set_charset(charset); - sess->set_default_session_variable(PGSQL_CLIENT_ENCODING, charset); + /* Note: Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior. + PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds, + but the error is encountered when executing a query. + This is behaviour is intentional, as newer PostgreSQL versions may introduce parameters that ProxySQL is not yet aware of. + */ + // New implementation + for (const auto& [param_name, param_val] : (*myds)->myconn->conn_params.connection_parameters) { + std::string param_name_lowercase(param_name); + std::transform(param_name_lowercase.cbegin(), param_name_lowercase.cend(), param_name_lowercase.begin(), ::tolower); - // get datestyle from connection parameters - const char* datestyle_tmp = (*myds)->myconn->conn_params.get_value(PG_DATESTYLE); - std::string datestyle = datestyle_tmp ? datestyle_tmp : ""; + // check if parameter is part of connection-level parameters + auto itr = param_name_map.find(param_name_lowercase.c_str()); + if (itr != param_name_map.end()) { - if (datestyle.empty()) { - // No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable. - datestyle = pgsql_thread___default_variables[PGSQL_DATESTYLE]; - } else { - PgSQL_DateStyle_t datestyle_parsed = PgSQL_DateStyle_Util::parse_datestyle(datestyle); + if (param_name_lowercase.compare("user") == 0 || param_name_lowercase.compare("password") == 0) { + continue; + } - // If DateStyle provided in the connection parameters is incomplete, the missing parts will be taken from the default DateStyle. - if (datestyle_parsed.format == DATESTYLE_FORMAT_NONE || datestyle_parsed.order == DATESTYLE_ORDER_NONE) { - PgSQL_DateStyle_t datestyle_default = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); - datestyle = PgSQL_DateStyle_Util::datestyle_to_string(datestyle_parsed, datestyle_default); - } - } + bool is_validation_success = false; + const Param_Name_Validation* validation = itr->second; + + if (validation != nullptr && validation->accepted_values) { + const char** accepted_value = validation->accepted_values; + while (*accepted_value) { + if (strcmp(param_val.c_str(), *accepted_value) == 0) { + is_validation_success = true; + break; + } + accepted_value++; + } + } else { + is_validation_success = true; + } - assert(datestyle.empty() == false); + if (is_validation_success == false) { + char* m = NULL; + char* errmsg = NULL; + proxy_error("invalid value for parameter \"%s\": \"%s\"\n", param_name.c_str(), param_val.c_str()); + m = (char*)"invalid value for parameter \"%s\": \"%s\""; + errmsg = (char*)malloc(param_val.length() + param_name.length() + strlen(m)); + sprintf(errmsg, m, param_name.c_str(), param_val.c_str()); + generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); + free(errmsg); + ret = EXECUTION_STATE::FAILED; + + // freeing userinfo->username and userinfo->password to prevent invalid password error generation. + free(userinfo->username); + free(userinfo->password); + userinfo->username = strdup(""); + userinfo->password = strdup(""); + // + goto __exit_process_pkt_handshake_response; + } - if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str())) { - // change current datestyle - sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle); - sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str()); + if (param_name_lowercase.compare("database") == 0) { + userinfo->set_dbname(param_val.empty() ? user : param_val.c_str()); + } else if (param_name_lowercase.compare("options") == 0) { + options_list = parse_options(param_val.c_str()); + } + } else { + // session parameters/variables? + parameters.push_back(std::make_pair(param_name_lowercase, param_val)); + } } - // get timezone from connection parameters - const char* timezone = (*myds)->myconn->conn_params.get_value(PG_TIMEZONE); - if (timezone == NULL) - timezone = pgsql_thread___default_variables[PGSQL_TIMEZONE]; - - pgsql_variables.client_set_value(sess, PGSQL_TIMEZONE, timezone); - sess->set_default_session_variable(PGSQL_TIMEZONE, timezone); - - const char* intervalstyle = pgsql_thread___default_variables[PGSQL_INTERVALSTYLE]; - if (intervalstyle) { - pgsql_variables.client_set_value(sess, PGSQL_INTERVALSTYLE, intervalstyle); - sess->set_default_session_variable(PGSQL_INTERVALSTYLE, intervalstyle); + if (userinfo->dbname == nullptr) { + userinfo->set_dbname(user); } - // get standard_conforming_strings from connection parameters - const char* standard_conforming_strings = pgsql_thread___default_variables[PGSQL_STANDARD_CONFORMING_STRINGS]; - if (standard_conforming_strings) { - pgsql_variables.client_set_value(sess, PGSQL_STANDARD_CONFORMING_STRINGS, standard_conforming_strings); - sess->set_default_session_variable(PGSQL_STANDARD_CONFORMING_STRINGS, standard_conforming_strings); + // Merge options with parameters. + // Options are processed first, followed by connection parameters. + // If a parameter is specified in both, the connection parameter takes precedence + // and overwrites the previosly set value. + if (options_list.empty() == false) { + options_list.reserve(parameters.size() + options_list.size()); + options_list.insert(options_list.end(), std::make_move_iterator(parameters.begin()), std::make_move_iterator(parameters.end())); + parameters = std::move(options_list); } - const char* options = (*myds)->myconn->conn_params.get_value(PG_OPTIONS); + // assign default datestyle to current datestyle. + // This is needed by PgSQL_DateStyle_Util::parse_datestyle + sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); - auto options_list = parse_options(options); + for (const auto&[param_key, param_val] : parameters) { - for (auto& option : options_list) { int idx = PGSQL_NAME_LAST_HIGH_WM; for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { - if (i == PGSQL_NAME_LAST_LOW_WM) + if (i == PGSQL_NAME_LAST_LOW_WM) continue; - if (variable_name_exists(pgsql_tracked_variables[i], option.first.c_str()) == true) { + if (strncmp(param_key.c_str(), pgsql_tracked_variables[i].set_variable_name, + strlen(pgsql_tracked_variables[i].set_variable_name)) == 0) { idx = i; break; - } + } } if (idx != PGSQL_NAME_LAST_HIGH_WM) { - pgsql_variables.client_set_value(sess, idx, option.second.c_str()); - sess->set_default_session_variable((enum pgsql_variable_name)idx, option.second.c_str()); + std::string value_copy = param_val; + + char* transformed_value = nullptr; + if (pgsql_tracked_variables[idx].validator && pgsql_tracked_variables[idx].validator->validate && + ( + *pgsql_tracked_variables[idx].validator->validate)( + value_copy.c_str(), &pgsql_tracked_variables[idx].validator->params, sess, &transformed_value) == false + ) { + char* m = NULL; + char* errmsg = NULL; + proxy_error("invalid value for parameter \"%s\": \"%s\"\n", pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str()); + m = (char*)"invalid value for parameter \"%s\": \"%s\""; + errmsg = (char*)malloc(value_copy.length() + strlen(pgsql_tracked_variables[idx].set_variable_name) + strlen(m)); + sprintf(errmsg, m, pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str()); + generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true); + free(errmsg); + ret = EXECUTION_STATE::FAILED; + + // freeing userinfo->username and userinfo->password to prevent invalid password error generation. + free(userinfo->username); + free(userinfo->password); + userinfo->username = strdup(""); + userinfo->password = strdup(""); + // + goto __exit_process_pkt_handshake_response; + } + + if (transformed_value) { + value_copy = transformed_value; + free(transformed_value); + } + + if (idx == PGSQL_DATESTYLE) { + // get datestyle from connection parameters + std::string datestyle = value_copy.empty() == false ? value_copy : ""; + + if (datestyle.empty()) { + // No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable. + datestyle = pgsql_thread___default_variables[PGSQL_DATESTYLE]; + } + else { + PgSQL_DateStyle_t datestyle_parsed = PgSQL_DateStyle_Util::parse_datestyle(datestyle); + + // If DateStyle provided in the connection parameters is incomplete, the missing parts will be taken from the default DateStyle. + if (datestyle_parsed.format == DATESTYLE_FORMAT_NONE || datestyle_parsed.order == DATESTYLE_ORDER_NONE) { + PgSQL_DateStyle_t datestyle_default = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]); + datestyle = PgSQL_DateStyle_Util::datestyle_to_string(datestyle_parsed, datestyle_default); + } + } + + assert(datestyle.empty() == false); + + if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str(), false)) { + // change current datestyle + sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle); + sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str()); + } + } else { + pgsql_variables.client_set_value(sess, idx, value_copy.c_str(), false); + sess->set_default_session_variable((enum pgsql_variable_name)idx, value_copy.c_str()); + } } else { - const char* val = option.second.c_str(); - const char* escaped_str = escape_string_backslash_spaces(val); - sess->untracked_option_parameters = "-c " + option.first + "=" + escaped_str + " "; - if (escaped_str != val) + // parameter provided is not part of the tracked variables. Will lock on hostgroup on next query. + const char* val_cstr = param_val.c_str(); + proxy_warning("Unrecognized connection parameter. Please report this as a bug for future enhancements:%s:%s\n", param_key.c_str(), val_cstr); + const char* escaped_str = escape_string_backslash_spaces(val_cstr); + sess->untracked_option_parameters = "-c " + param_key + "=" + escaped_str + " "; + if (escaped_str != val_cstr) free((char*)escaped_str); } } - //if (charset) - // (*myds)->sess->default_charset = charset; + // fill all crtical variables with default values, if not set by client + for (int i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) { + if (pgsql_variables.client_get_hash(sess, i) != 0) + continue; + const char* val = pgsql_thread___default_variables[i]; + pgsql_variables.client_set_value(sess, i, val, false); + sess->set_default_session_variable((pgsql_variable_name)i, val); + } + + sess->client_myds->myconn->reorder_dynamic_variables_idx(); } else { // we always duplicate username and password, or crashes happen diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index cb88f8335..c09dce42c 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -81,7 +81,8 @@ static inline char is_normal_char(char c) { } */ -static const std::array pgsql_critical_variables = { +static const std::array pgsql_critical_variables = { + "client_encoding", "datestyle", "intervalstyle", "standard_conforming_strings", @@ -825,10 +826,9 @@ void PgSQL_Session::generate_proxysql_internal_session_json(json& j) { //j["conn"]["ps"]["client_stmt_to_global_ids"] = client_myds->myconn->local_stmts->client_stmt_to_global_ids; const PgSQL_Conn_Param& conn_params = client_myds->myconn->conn_params; - for (size_t i = 0; i < conn_params.param_set.size(); i++) { - if (conn_params.param_value[conn_params.param_set[i]] != NULL) { - j["client"]["conn"]["connection_options"][PgSQL_Param_Name_Str[conn_params.param_set[i]]] = conn_params.param_value[conn_params.param_set[i]]; - } + + for (const auto& [key, val] : conn_params.connection_parameters) { + j["client"]["conn"]["connection_options"][key.c_str()] = val.c_str(); } } } @@ -1508,98 +1508,6 @@ bool PgSQL_Session::handler_again___status_SETTING_INIT_CONNECT(int* _rc) { return ret; } -bool PgSQL_Session::handler_again___status_CHANGING_CHARSET(int* _rc) { - assert(mybe->server_myds->myconn); - PgSQL_Data_Stream* myds = mybe->server_myds; - PgSQL_Connection* myconn = myds->myconn; - - /* Validate that server can support client's charset */ - if (!validate_charset(this, PGSQL_CLIENT_ENCODING, *_rc)) { - return false; - } - - myds->DSS = STATE_MARIADB_QUERY; - enum session_status st = status; - if (myds->mypolls == NULL) { - thread->mypolls.add(POLLIN | POLLOUT, mybe->server_myds->fd, mybe->server_myds, thread->curtime); - } - - std::string charset = pgsql_variables.client_get_value(this, PGSQL_CLIENT_ENCODING); - std::string query = "SET CLIENT_ENCODING TO '" + charset + "'"; - - int rc = myconn->async_send_simple_command(myds->revents, (char*)query.c_str(),query.length()); - - if (rc == 0) { - __sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1); - myds->DSS = STATE_MARIADB_GENERIC; - st = previous_status.top(); - previous_status.pop(); - NEXT_IMMEDIATE_NEW(st); - } else { - if (rc == -1) { - // the command failed - const bool error_present = myconn->is_error_present(); - PgHGM->p_update_pgsql_error_counter( - p_pgsql_error_type::pgsql, - myconn->parent->myhgc->hid, - myconn->parent->address, - myconn->parent->port, - (error_present ? 9999 : ER_PROXYSQL_OFFLINE_SRV) // TOFIX: 9999 is a placeholder for the actual error code - ); - if (error_present == false || (error_present == true && myconn->is_connection_in_reusable_state() == false)) { - bool retry_conn = false; - // client error, serious - proxy_error( - "Client trying to set a charset (%s) not supported by backend (%s:%d). Changing it to %s\n", - charset.c_str(), myconn->parent->address, myconn->parent->port, pgsql_tracked_variables[PGSQL_CLIENT_ENCODING].default_value - ); - detected_broken_connection(__FILE__, __LINE__, __func__, "during SET CLIENT_ENCODING", myconn); - if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) { - retry_conn = true; - } - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd = 0; - if (retry_conn) { - myds->DSS = STATE_NOT_INITIALIZED; - NEXT_IMMEDIATE_NEW(CONNECTING_SERVER); - } - *_rc = -1; - return false; - } else { - proxy_warning("Error during SET CLIENT_ENCODING: %s\n", myconn->get_error_code_with_message().c_str()); - // we won't go back to PROCESSING_QUERY - st = previous_status.top(); - previous_status.pop(); - client_myds->myprot.generate_error_packet(true, true, myconn->get_error_message().c_str(), myconn->get_error_code(), false); - myds->destroy_MySQL_Connection_From_Pool(true); - myds->fd = 0; - RequestEnd(myds); //fix bug #682 - } - } else { - if (rc == -2) { - bool retry_conn = false; - proxy_error("Timeout during SET CLIENT_ENCODING on %s , %d\n", myconn->parent->address, myconn->parent->port); - PgHGM->p_update_pgsql_error_counter(p_pgsql_error_type::pgsql, myconn->parent->myhgc->hid, myconn->parent->address, myconn->parent->port, ER_PROXYSQL_CHANGE_USER_TIMEOUT); - if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) { - retry_conn = true; - } - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd = 0; - if (retry_conn) { - myds->DSS = STATE_NOT_INITIALIZED; - NEXT_IMMEDIATE_NEW(CONNECTING_SERVER); - } - *_rc = -1; - return false; - } - else { - // rc==1 , nothing to do for now - } - } - } - return false; -} - bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, const char* var_name, const char* var_value, bool no_quote, bool set_transaction) { bool ret = false; @@ -1650,6 +1558,9 @@ bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, co query = NULL; } if (rc == 0) { + if (strncasecmp(var_name, "client_encoding", sizeof("client_encoding")-1) == 0) { + __sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1); + } myds->revents |= POLLOUT; // we also set again POLLOUT to send a query immediately! myds->DSS = STATE_MARIADB_GENERIC; st = previous_status.top(); @@ -3306,11 +3217,6 @@ handler_again: } if (locked_on_hostgroup == -1 || locked_on_hostgroup_and_all_variables_set == false) { - // verify charset - if (verify_set_names(this)) { - goto handler_again; - } - for (auto i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) { auto client_hash = client_myds->myconn->var_hash[i]; #ifdef DEBUG @@ -3596,9 +3502,6 @@ bool PgSQL_Session::handler_again___multiple_statuses(int* rc) { //case SETTING_INIT_CONNECT: // ret = handler_again___status_SETTING_INIT_CONNECT(rc); break; - case SETTING_CHARSET: - ret = handler_again___status_CHANGING_CHARSET(rc); - break; default: break; } @@ -3926,7 +3829,7 @@ void PgSQL_Session::handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE( else { client_addr = strdup((char*)""); } - if (client_myds->myconn->userinfo->username) { + if (client_myds->myconn->userinfo->username && client_myds->myconn->userinfo->username[0] != '\0') { char* _s = (char*)malloc(strlen(client_myds->myconn->userinfo->username) + 100 + strlen(client_addr)); uint8_t _pid = 2; if (client_myds->switching_auth_stage) _pid += 2; @@ -4296,7 +4199,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C int idx = PGSQL_NAME_LAST_HIGH_WM; for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { if (variable_name_exists(pgsql_tracked_variables[i], var.c_str()) == true) { - idx = pgsql_tracked_variables[i].idx; + idx = i; break; } } @@ -4347,8 +4250,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", var.c_str(), value1.c_str()); uint32_t var_hash_int = SpookyHash::Hash32(value1.c_str(), value1.length(), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value1.c_str())) { + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value1.c_str(), true)) { return false; } if (idx == PGSQL_DATESTYLE) { @@ -4628,7 +4531,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C l_free(pkt->size, pkt->ptr); return true; - }*/ else if (match_regexes && match_regexes[3]->match(dig)) { + } else if (match_regexes && match_regexes[3]->match(dig)) { std::vector> param_status; PgSQL_Set_Stmt_Parser parser(nq); std::string charset = parser.parse_character_set(); @@ -4680,7 +4583,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C RequestEnd(NULL); l_free(pkt->size, pkt->ptr); return true; - } else { + }*/ else { unable_to_parse_set_statement(lock_hostgroup); return false; } @@ -4719,26 +4622,41 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C if (strncasecmp(nq.c_str(), "ALL", 3) == 0) { - for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { + for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) { - if (i == PGSQL_NAME_LAST_LOW_WM) - continue; + const char* name = pgsql_tracked_variables[idx].set_variable_name; + const char* value = get_default_session_variable((enum pgsql_variable_name)idx); - const char* name = pgsql_tracked_variables[i].set_variable_name; - const char* value = get_default_session_variable((enum pgsql_variable_name)i); - proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[i].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[i].idx, value)) { + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value, false)) { + return false; + } + if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { + param_status.push_back(std::make_pair(name, value)); + } + } + } + + for (int idx : client_myds->myconn->dynamic_variables_idx) { + assert(idx < PGSQL_NAME_LAST_HIGH_WM); + const char* name = pgsql_tracked_variables[idx].set_variable_name; + const char* value = get_default_session_variable((enum pgsql_variable_name)idx); + proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value); + uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value, false)) { return false; } - if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[i])) { + if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { param_status.push_back(std::make_pair(name, value)); } } } + client_myds->myconn->reorder_dynamic_variables_idx(); + } else if (std::find(pgsql_variables.ignore_vars.begin(), pgsql_variables.ignore_vars.end(), nq) != pgsql_variables.ignore_vars.end()) { // this is a variable we parse but ignore #ifdef DEBUG @@ -4763,8 +4681,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C const char* value = get_default_session_variable((enum pgsql_variable_name)idx); uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10); - if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) { - if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value)) { + if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) { + if (!pgsql_variables.client_set_value(this, idx, value, true)) { return false; } if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) { diff --git a/lib/PgSQL_Set_Stmt_Parser.cpp b/lib/PgSQL_Set_Stmt_Parser.cpp index 42218f8f2..3f92ef12a 100644 --- a/lib/PgSQL_Set_Stmt_Parser.cpp +++ b/lib/PgSQL_Set_Stmt_Parser.cpp @@ -202,6 +202,7 @@ std::map> PgSQL_Set_Stmt_Parser::parse2() { return result; } +#if 0 std::string PgSQL_Set_Stmt_Parser::parse_character_set() { #ifdef DEBUG proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 4, "Parsing query %s\n", query.c_str()); @@ -224,7 +225,7 @@ std::string PgSQL_Set_Stmt_Parser::parse_character_set() { delete opt2; return value3; } - +#endif std::string PgSQL_Set_Stmt_Parser::remove_comments(const std::string& q) { std::string result = ""; bool in_multiline_comment = false; diff --git a/lib/PgSQL_Variables.cpp b/lib/PgSQL_Variables.cpp index 83473651c..7285436c6 100644 --- a/lib/PgSQL_Variables.cpp +++ b/lib/PgSQL_Variables.cpp @@ -10,58 +10,40 @@ #include - -/*static inline char is_digit(char c) { - if(c >= '0' && c <= '9') - return 1; - return 0; -}*/ - pgsql_verify_var PgSQL_Variables::verifiers[PGSQL_NAME_LAST_HIGH_WM]; pgsql_update_var PgSQL_Variables::updaters[PGSQL_NAME_LAST_HIGH_WM]; - PgSQL_Variables::PgSQL_Variables() { // add here all the variables we want proxysql to recognize, but ignore ignore_vars.push_back("application_name"); - //ignore_vars.push_back("interactive_timeout"); - //ignore_vars.push_back("wait_timeout"); - //ignore_vars.push_back("net_read_timeout"); - //ignore_vars.push_back("net_write_timeout"); - //ignore_vars.push_back("net_buffer_length"); - //ignore_vars.push_back("read_buffer_size"); - //ignore_vars.push_back("read_rnd_buffer_size"); // NOTE: This variable has been temporarily ignored. Check issues #3442 and #3441. //ignore_vars.push_back("session_track_schema"); variables_regexp = ""; + + /* + NOTE: + make special ATTENTION that the order in pgsql_variable_name + and pgsqll_tracked_variables[] is THE SAME + NOTE: + PgSQL_Variables::PgSQL_Variables() has a built-in check to make sure that the order is correct, + and that variables are in alphabetical order + */ for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { + //Array index and enum value (idx) should be same + assert(i == pgsql_tracked_variables[i].idx); + + if (i > PGSQL_NAME_LAST_LOW_WM + 1) { + assert(strcmp(pgsql_tracked_variables[i].set_variable_name, pgsql_tracked_variables[i - 1].set_variable_name) > 0); + } + // we initialized all the internal_variable_name if set to NULL if (pgsql_tracked_variables[i].internal_variable_name == NULL) { pgsql_tracked_variables[i].internal_variable_name = pgsql_tracked_variables[i].set_variable_name; } - } -/* - NOTE: - make special ATTENTION that the order in pgsql_variable_name - and pgsqll_tracked_variables[] is THE SAME - NOTE: - PgSQL_Variables::PgSQL_Variables() has a built-in check to make sure that the order is correct, - and that variables are in alphabetical order -*/ - for (int i = PGSQL_NAME_LAST_LOW_WM; i < PGSQL_NAME_LAST_HIGH_WM; i++) { - assert(i == pgsql_tracked_variables[i].idx); - if (i > PGSQL_NAME_LAST_LOW_WM+1) { - assert(strcmp(pgsql_tracked_variables[i].set_variable_name, pgsql_tracked_variables[i-1].set_variable_name) > 0); - } - } - for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { - if (i == PGSQL_CLIENT_ENCODING) { - PgSQL_Variables::updaters[i] = NULL; - PgSQL_Variables::verifiers[i] = NULL; - } else { - PgSQL_Variables::verifiers[i] = verify_server_variable; - PgSQL_Variables::updaters[i] = update_server_variable; - } + + PgSQL_Variables::verifiers[i] = verify_server_variable; + PgSQL_Variables::updaters[i] = update_server_variable; + if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) { variables_regexp += pgsql_tracked_variables[i].set_variable_name; variables_regexp += "|"; @@ -74,6 +56,7 @@ PgSQL_Variables::PgSQL_Variables() { } } } + for (std::vector::iterator it=ignore_vars.begin(); it != ignore_vars.end(); it++) { variables_regexp += *it; variables_regexp += "|"; @@ -135,7 +118,7 @@ void PgSQL_Variables::server_set_hash_and_value(PgSQL_Session* session, int idx, session->mybe->server_myds->myconn->variables[idx].value = strdup(value); } -bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value) { +bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx) { if (!session || !session->client_myds || !session->client_myds->myconn) { proxy_warning("Session validation failed\n"); return false; @@ -147,7 +130,7 @@ bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const st } session->client_myds->myconn->variables[idx].value = strdup(value.c_str()); - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx session->client_myds->myconn->reorder_dynamic_variables_idx(); } @@ -168,7 +151,7 @@ uint32_t PgSQL_Variables::client_get_hash(PgSQL_Session* session, int idx) const return session->client_myds->myconn->var_hash[idx]; } -void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value) { +void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx) { assert(session); assert(session->mybe); assert(session->mybe->server_myds); @@ -181,7 +164,7 @@ void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const ch } session->mybe->server_myds->myconn->variables[idx].value = strdup(value); - if (idx > PGSQL_NAME_LAST_LOW_WM) { + if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) { // we now regererate dynamic_variables_idx session->mybe->server_myds->myconn->reorder_dynamic_variables_idx(); } @@ -253,98 +236,6 @@ bool PgSQL_Variables::verify_variable(PgSQL_Session* session, int idx) const { return ret; } -bool validate_charset(PgSQL_Session* session, int idx, int &_rc) { - /*if (idx == PGSQL_CLIENT_ENCODING || idx == PGSQL_SET_NAMES) { - PgSQL_Data_Stream *myds = session->mybe->server_myds; - PgSQL_Connection *myconn = myds->myconn; - char msg[128]; - const MARIADB_CHARSET_INFO *ci = NULL; - const char* replace_collation = NULL; - const char* not_supported_collation = NULL; - unsigned int replace_collation_nr = 0; - std::stringstream ss; - int charset = atoi(pgsql_variables.client_get_value(session, idx)); - if (charset >= 255 && myconn->pgsql->server_version[0] != '8') { - switch(pgsql_thread___handle_unknown_charset) { - case HANDLE_UNKNOWN_CHARSET__DISCONNECT_CLIENT: - snprintf(msg,sizeof(msg),"Can't initialize character set %s", pgsql_variables.client_get_value(session, idx)); - proxy_error("Can't initialize character set on %s, %d: Error %d (%s). Closing client connection %s:%d.\n", - myconn->parent->address, myconn->parent->port, 2019, msg, session->client_myds->addr.addr, session->client_myds->addr.port); - myds->destroy_MySQL_Connection_From_Pool(false); - myds->fd=0; - _rc=-1; - return false; - case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT_VERBOSE: - ci = proxysql_find_charset_nr(charset); - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot find character set [%s]\n", pgsql_variables.client_get_value(session, idx)); - assert(0); - // LCOV_EXCL_STOP - } - not_supported_collation = ci->name; - - if (idx == SQL_COLLATION_CONNECTION) { - ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]); - } else { - if (pgsql_thread___default_variables[idx]) { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]); - } else { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]); - } - } - - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot find character set [%s]\n", pgsql_thread___default_variables[idx]); - assert(0); - // LCOV_EXCL_STOP - } - replace_collation = ci->name; - replace_collation_nr = ci->nr; - - proxy_warning("Server doesn't support collation (%s) %s. Replacing it with the configured default (%d) %s. Client %s:%d\n", - pgsql_variables.client_get_value(session, idx), not_supported_collation, - replace_collation_nr, replace_collation, session->client_myds->addr.addr, session->client_myds->addr.port); - - ss << replace_collation_nr; - pgsql_variables.client_set_value(session, idx, ss.str()); - _rc=0; - return true; - case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT: - if (idx == SQL_COLLATION_CONNECTION) { - ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]); - } else { - if (pgsql_thread___default_variables[idx]) { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]); - } else { - ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]); - } - } - - if (!ci) { - // LCOV_EXCL_START - proxy_error("Cannot filnd charset [%s]\n", pgsql_thread___default_variables[idx]); - assert(0); - // LCOV_EXCL_STOP - } - replace_collation_nr = ci->nr; - - ss << replace_collation_nr; - pgsql_variables.client_set_value(session, idx, ss.str()); - _rc=0; - return true; - default: - proxy_error("Wrong configuration of the handle_unknown_charset\n"); - _rc=-1; - return false; - } - } - }*/ - _rc=0; - return true; -} - bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) { bool no_quote = true; if (IS_PGTRACKED_VAR_OPTION_SET_QUOTE(pgsql_tracked_variables[idx])) no_quote = false; @@ -352,49 +243,12 @@ bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) { const char *set_var_name = pgsql_tracked_variables[idx].set_variable_name; bool ret = false; - if (!validate_charset(session, idx, _rc)) { - return false; - } - const char* value = pgsql_variables.client_get_value(session, idx); - pgsql_variables.server_set_value(session, idx, value); + pgsql_variables.server_set_value(session, idx, value, true); ret = session->handler_again___status_SETTING_GENERIC_VARIABLE(&_rc, set_var_name, value, no_quote, st); return ret; } -bool verify_set_names(PgSQL_Session* session) { - uint32_t client_charset_hash = pgsql_variables.client_get_hash(session, PGSQL_CLIENT_ENCODING); - if (client_charset_hash == 0) - return false; - - if (client_charset_hash != pgsql_variables.server_get_hash(session, PGSQL_CLIENT_ENCODING)) { - switch(session->status) { // this switch can be replaced with a simple previous_status.push(status), but it is here for readibility - case PROCESSING_QUERY: - session->previous_status.push(PROCESSING_QUERY); - break; - /* - case PROCESSING_STMT_PREPARE: - session->previous_status.push(PROCESSING_STMT_PREPARE); - break; - case PROCESSING_STMT_EXECUTE: - session->previous_status.push(PROCESSING_STMT_EXECUTE); - break; - */ - default: - // LCOV_EXCL_START - proxy_error("Wrong status %d\n", session->status); - assert(0); - break; - // LCOV_EXCL_STOP - } - session->set_status(SETTING_CHARSET); - const char* client_charset_value = pgsql_variables.client_get_value(session, PGSQL_CLIENT_ENCODING); - pgsql_variables.server_set_hash_and_value(session, PGSQL_CLIENT_ENCODING, client_charset_value, client_charset_hash); - return true; - } - return false; -} - inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash) { if (client_hash && client_hash != server_hash) { // Edge case for set charset command, because we do not know database character set @@ -427,7 +281,7 @@ inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t cli // LCOV_EXCL_STOP } session->set_status(pgsql_tracked_variables[idx].status); - pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx)); + pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx), true); return true; } return false; diff --git a/lib/PgSQL_Variables_Validator.cpp b/lib/PgSQL_Variables_Validator.cpp index f1b6b7edf..3286a72a8 100644 --- a/lib/PgSQL_Variables_Validator.cpp +++ b/lib/PgSQL_Variables_Validator.cpp @@ -422,6 +422,31 @@ bool pgsql_variable_validate_maintenance_work_mem_v3(const char* value, const pa return true; } +bool pgsql_variable_validate_client_encoding(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) { + (void)params; + if (transformed_value) *transformed_value = nullptr; + + int charset_encoding = PgSQL_Connection::char_to_encoding(value); + + if (charset_encoding == -1) { + return false; + } + + if (transformed_value) { + *transformed_value = strdup(value); + + if (*transformed_value) { // Ensure strdup succeeded + char* tmp_val = *transformed_value; + + while (*tmp_val) { + *tmp_val = toupper((unsigned char)*tmp_val); + tmp_val++; + } + } + } + + return true; +} const pgsql_variable_validator pgsql_variable_validator_bool = { .type = VARIABLE_TYPE_BOOL, @@ -482,3 +507,9 @@ const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem = { .uint_range = {.min = 1024, .max = 2147483647 } // this range is in kB } }; + +const pgsql_variable_validator pgsql_variable_validator_client_encoding = { + .type = VARIABLE_TYPE_CLIENT_ENCODING, + .validate = &pgsql_variable_validate_client_encoding, + .params = {} +}; diff --git a/test/tap/tests/pgsql-connection_parameters_test-t.cpp b/test/tap/tests/pgsql-connection_parameters_test-t.cpp new file mode 100644 index 000000000..063a75e77 --- /dev/null +++ b/test/tap/tests/pgsql-connection_parameters_test-t.cpp @@ -0,0 +1,909 @@ +/** + * @file pgsql-connection_parameters_test-t-t.cpp + * @brief This TAP test validates if connection parameters are correctly handled by ProxySQL. + * + * Note: + * - Using libpq to test ProxySQL's handling of undocumented parameters isn't possible, as libpq enforces a strict subset of PostgreSQL + * connection parameters as per the official documentation, rejecting any undocumented parameters. However, actual + * PostgreSQL servers accept additional parameters (e.g., extra_float_digits) and apply them at the connection/session level. + * To test this behavior, a raw socket is used to connect to a ProxySQL server and send custom built messages to communicate + * with ProxySQL. It currently works with plain text password authentication, without ssl support. + * + * - Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior. + * PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds, + * but the error is encountered when executing a query. + * This is behaviour is intentional, as newer PostgreSQL versions may introduce new parameters that ProxySQL is not yet aware of. + * + * + */ + +#include +#include +#include +#include +#include +#include +#include "libpq-fe.h" +#include "command_line.h" +#include "tap.h" +#include "utils.h" + +CommandLine cl; + +using PGConnPtr = std::unique_ptr; + +enum ConnType { + ADMIN, + BACKEND +}; + +PGConnPtr createNewConnection(ConnType conn_type, const std::string& parameters = "", bool with_ssl = false) { + + const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host; + int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port; + const char* username = (conn_type == BACKEND) ? cl.pgsql_username : cl.admin_username; + const char* password = (conn_type == BACKEND) ? cl.pgsql_password : cl.admin_password; + + std::stringstream ss; + + ss << "host=" << host << " port=" << port; + ss << " user=" << username << " password=" << password; + ss << (with_ssl ? " sslmode=require" : " sslmode=disable"); + + if (parameters.empty() == false) { + ss << " " << parameters; + } + + PGconn* conn = PQconnectdb(ss.str().c_str()); + if (PQstatus(conn) != CONNECTION_OK) { + fprintf(stderr, "Connection failed to '%s': %s\n", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn)); + PQfinish(conn); + return PGConnPtr(nullptr, &PQfinish); + } + return PGConnPtr(conn, &PQfinish); +} + +bool executeQueries(PGconn* conn, const std::vector& queries) { + auto fnResultType = [](const char* query) -> int { + const char* fs = strchr(query, ' '); + size_t qtlen = strlen(query); + if (fs != NULL) { + qtlen = (fs - query) + 1; + } + char buf[qtlen]; + memcpy(buf, query, qtlen - 1); + buf[qtlen - 1] = 0; + + if (strncasecmp(buf, "SELECT", sizeof("SELECT") - 1) == 0) { + return PGRES_TUPLES_OK; + } + else if (strncasecmp(buf, "COPY", sizeof("COPY") - 1) == 0) { + return PGRES_COPY_OUT; + } + + return PGRES_COMMAND_OK; + }; + + + for (const auto& query : queries) { + diag("Running: %s", query.c_str()); + PGresult* res = PQexec(conn, query.c_str()); + bool success = PQresultStatus(res) == fnResultType(query.c_str()); + if (!success) { + fprintf(stderr, "Failed to execute query '%s': %s\n", + query.c_str(), PQerrorMessage(conn)); + PQclear(res); + return false; + } + PQclear(res); + } + return true; +} + +struct parameter { + std::string name; + std::string value; +}; + +struct parameter_test { + std::vector set_admin_vars; // Admin variables to set + std::vector conn_params; // Parameters in startup message + std::vector conn_options; // Options (-c flags) in startup + std::vector set_commands; // SET commands after connection + std::vector expected; // Expected SHOW values + bool reset_after; // Whether to RESET parameters + bool expect_failure; // If connection/query should fail +}; + +/** + * @struct MyPGresult + * @brief Represents the result of a PostgreSQL query. + * + * This structure holds the columns, rows, status, and error message of a PostgreSQL query result. + */ +struct MyPGresult { + std::vector columns; ///< Column names of the result set. + std::vector> rows; ///< Rows of the result set. + std::string status; ///< Status of the query execution. + std::string error; ///< Error message if the query failed. +}; + +struct PgSQLResponse { + char type; ///< Type of the response message. + int32_t length; ///< Length of the response message. + std::vector data; ///< Data of the response message. +}; + +void add_param(std::vector& msg, std::string_view key, std::string_view value) { + msg.insert(msg.end(), key.begin(), key.end()); + msg.push_back('\0'); + msg.insert(msg.end(), value.begin(), value.end()); + msg.push_back('\0'); +} + +// Function to receive data from socket +bool recv_data(int sock, void* buffer, size_t len) { + + if (recv(sock, buffer, len, 0) <= 0) { + fprintf(stderr, "Error receiving data\n"); + return false; + } + return true; +} + +// Function to send data over socket +bool send_data(int sock, const void* data, size_t len) { + if (send(sock, data, len, 0) != len) { + fprintf(stderr, "Error sending data\n"); + return false; + } + return true; +} + +/** + * @brief Builds a startup message for PostgreSQL connection. + * + * This function constructs a startup message for PostgreSQL connection using the provided + * connection parameters. + * + * @param parameters A vector of key-value pairs representing the connection parameters. + * @return A vector of characters representing the constructed startup message. + */ +std::vector build_startup_message(const std::vector& parameters) { + // Build startup message + std::vector startup_body; + int32_t protocol = htonl(0x00030000); // Protocol 3.0 + startup_body.insert(startup_body.end(), (char*)&protocol, (char*)&protocol + 4); + + // Add connection parameters + for (const auto& param : parameters) { + add_param(startup_body, param.name, param.value); + } + startup_body.push_back('\0'); + + // Prepend message length + std::vector startup_msg; + int32_t len = htonl(startup_body.size() + 4); + startup_msg.insert(startup_msg.end(), (char*)&len, (char*)&len + 4); + startup_msg.insert(startup_msg.end(), startup_body.begin(), startup_body.end()); + + return startup_msg; +} + +/** + * @brief Builds a password message for PostgreSQL authentication. + * + * This function constructs a password message to be sent to the PostgreSQL server + * during the authentication process. + * + * @param password The password to be included in the message. + * @return A vector of characters representing the constructed password message. + */ +std::vector build_password_message(std::string_view password) { + std::vector password_message; + int pass_msg_len = htonl(password.size() + 1 + 4); + password_message.push_back('p'); + password_message.insert(password_message.end(), (char*)&pass_msg_len, (char*)&pass_msg_len + 4); + password_message.insert(password_message.end(), password.begin(), password.end()); + password_message.push_back('\0'); + return password_message; +} + +/** + * @brief Connects to a PostgreSQL server. + * + * This function establishes a connection to a PostgreSQL server using the provided + * host and port number. + * + * @param host The hostname or IP address of the PostgreSQL server. + * @param port The port number of the PostgreSQL server. + * @return The socket file descriptor for the connection. + */ +int connect_server(const std::string& host, int port) { + int sock; + struct sockaddr_in server; + + // Create socket + sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == -1) { + perror("Socket creation failed"); + return 1; + } + + // Configure server address + server.sin_family = AF_INET; + server.sin_port = htons(port); + server.sin_addr.s_addr = inet_addr(host.c_str()); + + // Connect to PostgreSQL server + if (connect(sock, (struct sockaddr*)&server, sizeof(server)) < 0) { + fprintf(stderr, "Connection failed\n"); + return -1; + } + return sock; +} + +/** + * @brief Sends a startup message to the PostgreSQL server. + * + * This function constructs and sends a startup message to the PostgreSQL server + * using the provided socket and connection parameters. + * + * @param sock The socket file descriptor for the connection. + * @param params A vector of key-value pairs representing the connection parameters. + */ +void send_startup_message(int sock, const std::vector>& params) { + + int param_count = params.size(); + + char msg[512]; + int offset = 0; + uint32_t length = 4 + 4; + + for (int i = 0; i < param_count; i++) { + length += params[i].first.size() + 1 + params[i].second.size() + 1; + } + length += 1; + + uint32_t length_nbo = htonl(length); + uint32_t protocol = htonl(196608); + + memcpy(msg + offset, &length_nbo, 4); + offset += 4; + memcpy(msg + offset, &protocol, 4); + offset += 4; + + for (int i = 0; i < param_count; i++) { + strcpy(msg + offset, params[i].first.c_str()); + offset += params[i].first.size() + 1; + strcpy(msg + offset, params[i].second.c_str()); + offset += params[i].second.size() + 1; + } + msg[offset++] = '\0'; + + send(sock, msg, offset, 0); +} + +/** + * @brief Parses an error message from a PostgreSQL response payload. + * + * This function extracts and returns the error message from a given PostgreSQL response payload. + * + * @param payload A vector of characters representing the PostgreSQL response payload. + * @return A string containing the extracted error message. + */ +std::string parse_error(const std::vector& payload) { + const char* data = payload.data(); + size_t pos = 0; + std::string message; + + while (pos < payload.size()) { + if (data[pos] == '\0') break; + char field_type = data[pos++]; + + std::string field_value; + while (pos < payload.size() && data[pos] != '\0') { + field_value += data[pos++]; + } + pos++; + + if (field_type == 'M') message = field_value; + } + return message; +} + +/** + * @brief Handles cleartext password authentication for PostgreSQL. + * + * This function processes the cleartext password authentication request from the PostgreSQL server. + * + * @param sock The socket file descriptor for the connection. + * @param password The password to be sent for authentication. + * @return True if authentication is successful, false otherwise. + */ +bool handle_cleartext_auth(int sock, std::string_view password) { + + char response[1024]; + if (recv_data(sock, response, 9) == false) { // Read first 8 bytes (message type + length + auth type) + fprintf(stderr, "Error: failed to receive authentication message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + if (response[0] != 'R') { + fprintf(stderr, "Unexpected authentication response\n"); + return false; + } + + uint32_t auth_type; + memcpy(&auth_type, response + 5, 4); + auth_type = ntohl(auth_type); + + if (auth_type == 3) { // AuthenticationCleartextPassword + diag("Server requests cleartext password authentication\n"); + + std::vector password_msg = build_password_message(password); + + // Send PasswordMessage + if (send_data(sock, password_msg.data(), password_msg.size()) == false) { + fprintf(stderr, "Error: failed to send password message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + // Receive AuthenticationOK or Failure + if (recv_data(sock, response, 5) == false) { + fprintf(stderr, "Error: failed to receive authentication response in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + if (response[0] == 'E') { + uint32_t err_msg_len = ntohl(*reinterpret_cast(response + 1)); + if (recv_data(sock, response, err_msg_len - 4) == false) { + fprintf(stderr, "Error: failed to receive error message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + const std::vector payload (response, response + err_msg_len - 4); + std::string error_message = parse_error(payload); + fprintf(stderr, "%s\n", error_message.c_str()); + return false; + } + else if (response[0] != 'R') { + fprintf(stderr, "Unexpected authentication response\n"); + return false; + } + + if (recv_data(sock, response, 4) == false) { + fprintf(stderr, "Error: failed to receive error message in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + memcpy(&auth_type, response, 4); + auth_type = ntohl(auth_type); + + if (auth_type != 0) { + diag("Authentication failed!\n"); + return false; + } + + std::vector msg_list; + + while (true) { + int bytes_received = recv(sock, response, sizeof(response), MSG_DONTWAIT); + if (bytes_received == 0) { + fprintf(stderr, "Error: Connection closed in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + if (bytes_received == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; + } + fprintf(stderr, "Error: Connection closed or error occurred in file %s, line %d\n", __FILE__, __LINE__); + return false; + } + + size_t offset = 0; + while (offset < static_cast(bytes_received)) { + char messageType = response[offset]; + int32_t messageLength; + + if (offset + 5 > static_cast(bytes_received)) { + fprintf(stderr, "Incomplete message header received\n"); + return false; + } + + memcpy(&messageLength, response + offset + 1, 4); + messageLength = ntohl(messageLength); + + if (offset + messageLength > static_cast(bytes_received)) { + fprintf(stderr, "Incomplete message body received.\n"); + return false; + } + + PgSQLResponse msg; + msg.type = messageType; + msg.length = messageLength; + msg.data.assign(response + offset + 5, response + offset + messageLength + 1); + msg_list.emplace_back(std::move(msg)); + offset += messageLength + 1; + } + } + + if (msg_list.back().type == 'E') { + const std::vector payload = msg_list.back().data; + std::string error_message = parse_error(payload); + fprintf(stderr, "%s\n", error_message.c_str()); + return false; + } else if (msg_list.back().type != 'Z') { + fprintf(stderr, "Unexpected message type %c\n", msg_list.back().type); + return false; + } + + } else { + diag("Unexpected authentication method: %d\n", auth_type); + return false; + } + diag("Authentication successful!\n"); + return true; +} + +/** + * @brief Executes a query on the PostgreSQL server. + * + * This function sends a query to the PostgreSQL server and processes the response. + * + * @param sock The socket file descriptor for the connection. + * @param query The SQL query to be executed. + * @return A unique pointer to a MyPGresult structure containing the query result. + */ +std::unique_ptr execute_query(int sock, const std::string& query) { + std::vector msg; + + // Build query message + msg.push_back('Q'); + uint32_t length = htonl(query.size() + 5); // 1 byte for Q + 4 for length + query + null + msg.insert(msg.end(), reinterpret_cast(&length), reinterpret_cast(&length) + 4); + msg.insert(msg.end(), query.begin(), query.end()); + msg.push_back('\0'); + + // Send Query message + if (send_data(sock, msg.data(), msg.size()) == false) { + fprintf(stderr, "Error: failed to send query in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + std::unique_ptr result(new MyPGresult); + + while (true) { + char msg_type; + if (recv_data(sock, reinterpret_cast(&msg_type), 1) == false) { + fprintf(stderr, "Error: failed to receive message in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + uint32_t msg_length; + if (recv_data(sock, reinterpret_cast(&msg_length), 4) == false) { + fprintf(stderr, "Error: failed to receive message length in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + msg_length = ntohl(msg_length); + + std::vector payload(msg_length - 4); + if (msg_length > 4) { + if (recv_data(sock, payload.data(), msg_length - 4) == false) { + fprintf(stderr, "Error: failed to receive message payload in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + } + + switch (msg_type) { + case 'T': { // Row Description + const char* data = payload.data(); + uint16_t num_columns; + memcpy(&num_columns, data, 2); + num_columns = ntohs(num_columns); + + size_t pos = 2; + result->columns.reserve(num_columns); + for (int i = 0; i < num_columns; i++) { + std::string col_name(data + pos); + pos += col_name.length() + 1 + 22; // Skip field metadata + result->columns.push_back(col_name); + } + break; + } + + case 'D': { // Data Row + const char* data = payload.data(); + uint16_t num_fields; + memcpy(&num_fields, data, 2); + num_fields = ntohs(num_fields); + + std::vector row; + size_t pos = 2; + for (int i = 0; i < num_fields; i++) { + int32_t field_len; + memcpy(&field_len, data + pos, 4); + pos += 4; + field_len = ntohl(field_len); + + if (field_len == -1) { + row.push_back("NULL"); + } + else { + row.emplace_back(data + pos, field_len); + pos += field_len; + } + } + result->rows.push_back(row); + break; + } + + case 'C': // Command Complete + result->status = std::string(payload.data(), payload.size()); + break; + + case 'E': // Error + result->error = parse_error(payload); + // Consume remaining messages + while (msg_type != 'Z') { + if (recv_data(sock, &msg_type, 1) == false) { + fprintf(stderr, "Error: failed to receive message in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + if (recv_data(sock, reinterpret_cast(&msg_length), 4) == false) { + fprintf(stderr, "Error: failed to receive message length in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + + msg_length = ntohl(msg_length); + if (msg_length > 4) { + std::vector temp(msg_length - 4); + if (recv_data(sock, temp.data(), msg_length - 4) == false) { + fprintf(stderr, "Error: failed to receive message payload in file %s, line %d\n", __FILE__, __LINE__); + return nullptr; + } + } + } + return std::move(result); + + case 'Z': // Ready for Query + return std::move(result); + + default: + break; + } + } + return nullptr; +} + +const char* escape_string_backslash_spaces(const char* input) { + const char* c; + int input_len = 0; + int escape_count = 0; + + for (c = input; *c != '\0'; c++) { + if ((*c == ' ')) { + escape_count += 3; + } + else if ((*c == '\\')) { + escape_count += 2; + } + input_len++; + } + + if (escape_count == 0) + return input; + + char* output = (char*)malloc(input_len + escape_count + 1); + char* p = output; + + for (c = input; *c != '\0'; c++) { + if ((*c == ' ')) { + memcpy(p, "\\\\", 2); + p += 2; + } + else if (*c == '\\') { + *(p++) = '\\'; + } + *(p++) = *c; + } + *(p++) = '\0'; + return output; +} + +bool test_parameters(PGconn* admin_conn, const parameter_test& test) { + char buffer[512]; + bool ret = false; + for (const auto& parameter : test.set_admin_vars) { + snprintf(buffer, sizeof(buffer), "SET %s='%s'", parameter.name.c_str(), parameter.value.c_str()); + + if (executeQueries(admin_conn, { buffer, "LOAD PGSQL VARIABLES TO RUNTIME" }) == false) { + BAIL_OUT("Error: failed to set admin variable in file %s, line %d", __FILE__, __LINE__); + return false; + } + } + + int sock = connect_server(cl.pgsql_host, cl.pgsql_port); + if (sock == -1) { + BAIL_OUT("Error: failed to connect to the server in file %s, line %d", __FILE__, __LINE__); + return false; + } + + // Build startup message + std::vector parameters = { + { "user", cl.pgsql_username }, + }; + + parameters.reserve(test.conn_params.size() + test.conn_options.size() + 1); + + for (const auto& param : test.conn_params) { + if (param.value.empty() == true) continue; + parameters.push_back(param); + } + + if (test.conn_options.empty() == false) { + std::string options_value; + + for (size_t i = 0; i < test.conn_options.size(); i++) { + + options_value += " -c " + test.conn_params[i].name; + options_value += "="; + + const char* value = test.conn_options[i].c_str(); + const char* escaped_value = escape_string_backslash_spaces(value); + options_value += escaped_value; + + if (value != escaped_value) + free((void*)escaped_value); + } + + parameters.push_back({ "options", options_value }); + } + + // print parameters + for (const auto& param : parameters) { + diag("Parameter: %s = %s\n", param.name.c_str(), param.value.c_str()); + } + + std::vector startup_msg = build_startup_message(parameters); + + // Send StartupMessage + if (send_data(sock, startup_msg.data(), startup_msg.size()) == false) { + BAIL_OUT("Error: failed to send startup message in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + + // Receive AuthenticationRequest + if (handle_cleartext_auth(sock, cl.pgsql_password) == false) { + if (test.expect_failure) { + ok(true, "Authentication should fail"); + ret = true; + } else { + diag("Error: failed to handle cleartext authentication in file %s, line %d", __FILE__, __LINE__); + } + goto cleanup; + } + + for (size_t i = 0; i < test.set_commands.size(); i++) { + snprintf(buffer, sizeof(buffer), "SET %s='%s'", test.conn_params[i].name.c_str(), test.set_commands[i].c_str()); + diag("Executing: %s\n", buffer); + auto result = execute_query(sock, buffer); + + if (result == nullptr) { + BAIL_OUT("Error: failed to execute query in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + + if (result->error.empty() == false) { + ok(test.expect_failure, "Query '%s' should fail. %s", buffer, result->error.c_str()); + } + } + + if (test.reset_after) { + for (const auto& param : test.conn_params) { + std::string reset_cmd = "RESET " + param.name; + diag("Executing: %s\n", reset_cmd.c_str()); + auto result = execute_query(sock, reset_cmd); + if (result == nullptr) { + BAIL_OUT("Error: failed to reset parameter in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + if (result->error.empty() == false) { + ok(test.expect_failure, "Query '%s' should fail. %s", reset_cmd.c_str(), result->error.c_str()); + } + } + } + + for (int i = 0; i < test.conn_params.size(); i++) { + const auto& param = test.conn_params[i]; + std::string show_cmd = "SHOW " + param.name; + diag("Executing: %s\n", show_cmd.c_str()); + auto result = execute_query(sock, show_cmd); + if (result == nullptr) { + BAIL_OUT("Error: failed to execute query in file %s, line %d", __FILE__, __LINE__); + goto cleanup; + } + if (test.expect_failure == false && result->error.empty()) { + ok(result->rows.size() == 1, "Number of rows should be 1"); + ok(result->rows[0][0] == test.expected[i], "Parameter '%s' value should be '%s'. Actual: '%s'", + param.name.c_str(), test.expected[i].c_str(), result->rows[0][0].c_str()); + } else { + ok(test.expect_failure, "Query '%s' should fail. %s", show_cmd.c_str(), result->error.c_str()); + } + } + ret = true; +cleanup: + close(sock); + + return ret; +} + +std::vector test_cases = { + // check if connection parameters validation is working correctly + { {}, + {{"sslmode", "test"}}, + {}, + {}, + {""}, + false, + true + }, + // check if session parameter validation is working correctly + { {}, + {{"extra_float_digits", "20"}}, + {}, + {}, + {""}, + false, + true + }, + // check if options parameters validation is working correctly + { {}, + {{"extra_float_digits", "1"}}, + {"19"}, + {}, + {""}, + false, + true + }, + { {}, + {{"ENABLE_HASHJOIN", "off"}, {"enable_seqscan", "on"}}, + {"on", "off"}, + {}, + {"off", "on"}, + false, + false + }, + { + {}, + {{"extra_float_digits", "1"}}, + {"2"}, + {}, + {"1"}, + false, + false + }, + { + {}, + {{"enable_hashjoin", "off"}, {"enable_seqscan", "on"}}, + {"on", "off"}, + {"on", "off"}, + {"off", "on"}, + true, + false + }, + { + {{"pgsql-default_datestyle", "ISO, MDY"}}, + {{"datestyle", ""}}, + {}, + {"Postgres"}, + {"Postgres, MDY"}, + false, // Reset both + false + }, + { + {{"pgsql-default_datestyle", "ISO, MDY"}}, + {{"datestyle", ""}}, + {}, + {"Postgres"}, + {"ISO, MDY"}, + true, // Reset both + false + }, + { + {}, + {{"escape_string_warning", "on"}, {"standard_conforming_strings", "on"}}, + {}, + {"off", "off"}, + {"off", "off"}, + false, + false + }, + { + {}, + {{"client_encoding", "UTF8"}}, + {"LATIN1"}, + {}, + {"UTF8"}, + false, + false + }, + { + {}, + {{"client_encoding", "UTF8"}}, + {"LATIN1"}, + {"LATIN1"}, + {"LATIN1"}, + false, + false + }, + { + {{"pgsql-default_client_encoding", "utf8"}}, + {{"client_encoding", "UTF8"}}, + {"LATIN1"}, + {"LATIN1"}, + {"UTF8"}, + true, + false + }, + { + {}, + {{"invalid_param", "invalid"}}, + {}, + {}, + {"invalid"}, + false, + true + } +}; + + +int main(int argc, char** argv) { + + int test_count = 0; + + for (const auto& test_case : test_cases) { + + if (test_case.expect_failure) { + int case_count = 1; + + if (test_case.set_commands.empty() == false) + case_count++; + if (test_case.reset_after) + case_count++; + + test_count += test_case.conn_params.size() * case_count; + } else + test_count += test_case.conn_params.size() * 2; + } + + plan(test_count); + + if (cl.getEnv()) + return exit_status(); + + auto admin_conn = createNewConnection(ConnType::ADMIN, "", false); + + if (!admin_conn || PQstatus(admin_conn.get()) != CONNECTION_OK) { + BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + + if (executeQueries(admin_conn.get(), { "SET pgsql-authentication_method=1", + "LOAD PGSQL VARIABLES TO RUNTIME" }) == false) { + BAIL_OUT("Error: failed to set pgsql-authentication_method=1 in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + + for (const auto& test_case : test_cases) { + + if (test_parameters(admin_conn.get(), test_case) == false) { + BAIL_OUT("Error: failed to test parameters in file %s, line %d", __FILE__, __LINE__); + return exit_status(); + } + } + + return exit_status(); +}