Merge pull request #4899 from sysown/v3.0_refactor_connection_info_param

Refactor PostgreSQL Connection-Level Parameters Handling
pull/4914/head
René Cannaò 1 year ago committed by GitHub
commit 6df3d6b84b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -38,7 +38,6 @@ enum PgSQL_Param_Name {
PG_REQUIRE_AUTH, // Specifies the authentication method that the client requires from the server
PG_CHANNEL_BINDING, // Controls the client's use of channel binding
PG_CONNECT_TIMEOUT, // Maximum time to wait while connecting, in seconds
PG_CLIENT_ENCODING, // Sets the client_encoding configuration parameter for this connection
PG_OPTIONS, // Specifies command-line options to send to the server at connection start
PG_APPLICATION_NAME, // Specifies a value for the application_name configuration parameter
PG_FALLBACK_APPLICATION_NAME, // Specifies a fallback value for the application_name configuration parameter
@ -70,120 +69,81 @@ enum PgSQL_Param_Name {
PG_TARGET_SESSION_ATTRS, // Determines whether the session must have certain properties to be acceptable
PG_LOAD_BALANCE_HOSTS, // Controls the order in which the client tries to connect to the available hosts and addresses
// Environment Options
PG_DATESTYLE, // Sets the value of the DateStyle parameter
PG_TIMEZONE, // Sets the value of the TimeZone parameter
PG_GEQO, // Enables or disables the use of the GEQO query optimizer
//PG_DATESTYLE, // Sets the value of the DateStyle parameter
//PG_TIMEZONE, // Sets the value of the TimeZone parameter
//PG_GEQO, // Enables or disables the use of the GEQO query optimizer
PG_PARAM_SIZE
};
static const char* PgSQL_Param_Name_Str[] = {
"host",
"hostaddr",
"port",
"database",
"user",
"password",
"passfile",
"require_auth",
"channel_binding",
"connect_timeout",
"client_encoding",
"options",
"application_name",
"fallback_application_name",
"keepalives",
"keepalives_idle",
"keepalives_interval",
"keepalives_count",
"tcp_user_timeout",
"replication",
"gsseencmode",
"sslmode",
"requiressl",
"sslcompression",
"sslcert",
"sslkey",
"sslpassword",
"sslcertmode",
"sslrootcert",
"sslcrl",
"sslcrldir",
"sslsni",
"requirepeer",
"ssl_min_protocol_version",
"ssl_max_protocol_version",
"krbsrvname",
"gsslib",
"gssdelegation",
"service",
"target_session_attrs",
"load_balance_hosts",
// Environment Options
"datestyle",
"timezone",
"geqo"
};
struct Param_Name_Validation {
const char** accepted_values;
int default_value_idx;
};
static const Param_Name_Validation require_auth{ (const char* []) { "password","md5","gss","sspi","scram-sha-256","none",nullptr },-1 };
static const Param_Name_Validation replication{ (const char* []) { "true","on","yes","1","database","false","off","no","0",nullptr },-1 };
static const Param_Name_Validation gsseencmode{ (const char* []) { "disable","prefer","require",nullptr },1 };
static const Param_Name_Validation sslmode{ (const char* []) { "disable","allow","prefer","require","verify-ca","verify-full",nullptr },2 };
static const Param_Name_Validation sslcertmode{ (const char* []) { "disable","allow","require",nullptr },1 };
static const Param_Name_Validation target_session_attrs{ (const char* []) { "any","read-write","read-only","primary","standby","prefer-standby",nullptr },0 };
static const Param_Name_Validation load_balance_hosts{ (const char* []) { "disable","random",nullptr },-1 };
// Excluding client_encoding since it is managed as part of the session variables
#define PARAMETER_LIST \
PARAM("host", nullptr) \
PARAM("hostaddr", nullptr) \
PARAM("port", nullptr) \
PARAM("database", nullptr) \
PARAM("user", nullptr) \
PARAM("password", nullptr) \
PARAM("passfile", nullptr) \
PARAM("require_auth", &require_auth) \
PARAM("channel_binding", nullptr) \
PARAM("connect_timeout", nullptr) \
PARAM("options", nullptr) \
PARAM("application_name", nullptr) \
PARAM("fallback_application_name", nullptr) \
PARAM("keepalives", nullptr) \
PARAM("keepalives_idle", nullptr) \
PARAM("keepalives_interval", nullptr) \
PARAM("keepalives_count", nullptr) \
PARAM("tcp_user_timeout", nullptr) \
PARAM("replication", &replication) \
PARAM("gsseencmode", &gsseencmode) \
PARAM("sslmode", &sslmode) \
PARAM("requiressl", nullptr) \
PARAM("sslcompression", nullptr) \
PARAM("sslcert", nullptr) \
PARAM("sslkey", nullptr) \
PARAM("sslpassword", nullptr) \
PARAM("sslcertmode", &sslcertmode) \
PARAM("sslrootcert", nullptr) \
PARAM("sslcrl", nullptr) \
PARAM("sslcrldir", nullptr) \
PARAM("sslsni", nullptr) \
PARAM("requirepeer", nullptr) \
PARAM("ssl_min_protocol_version", nullptr) \
PARAM("ssl_max_protocol_version", nullptr) \
PARAM("krbsrvname", nullptr) \
PARAM("gsslib", nullptr) \
PARAM("gssdelegation", nullptr) \
PARAM("service", nullptr) \
PARAM("target_session_attrs", &target_session_attrs) \
PARAM("load_balance_hosts", &load_balance_hosts)
// Generate parameter array
constexpr const char* param_name[] = {
#define PARAM(name, val) name,
PARAMETER_LIST
#undef PARAM
};
static const Param_Name_Validation require_auth {(const char*[]){"password","md5","gss","sspi","scram-sha-256","none",nullptr},-1};
static const Param_Name_Validation replication {(const char*[]){"true","on","yes","1","database","false","off","no","0",nullptr},-1};
static const Param_Name_Validation gsseencmode {(const char*[]){"disable","prefer","require",nullptr},1};
static const Param_Name_Validation sslmode {(const char*[]){"disable","allow","prefer","require","verify-ca","verify-full",nullptr},2};
static const Param_Name_Validation sslcertmode {(const char*[]){"disable","allow","require",nullptr},1};
static const Param_Name_Validation target_session_attrs {(const char*[]){"any","read-write","read-only","primary","standby","prefer-standby",nullptr},0 };
static const Param_Name_Validation load_balance_hosts {(const char*[]){"disable","random",nullptr},-1};
static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SIZE] = {
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&require_auth,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&replication,
&gsseencmode,
&sslmode,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&sslcertmode,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&target_session_attrs,
&load_balance_hosts,
nullptr,
nullptr,
nullptr
// make sure all the keys are in lower case
static const std::unordered_map<std::string_view, const Param_Name_Validation*> param_name_map = {
#define PARAM(name, val) {name, val},
PARAMETER_LIST
#undef PARAM
};
#define PG_EVENT_NONE 0x00
@ -192,83 +152,66 @@ static const Param_Name_Validation* PgSQL_Param_Name_Accepted_Values[PG_PARAM_SI
#define PG_EVENT_EXCEPT 0x04
#define PG_EVENT_TIMEOUT 0x08
class PgSQL_Conn_Param {
private:
bool validate(PgSQL_Param_Name key, const char* val) {
assert(val);
const Param_Name_Validation* validation = PgSQL_Param_Name_Accepted_Values[key];
if (validation != nullptr && validation->accepted_values) {
const char** accepted_value = validation->accepted_values;
while (accepted_value != nullptr) {
if (strcmp(val, *accepted_value) == 0) {
return true;
}
}
} else {
return true;
}
return false;
}
/**
* @class PgSQL_Conn_Param
* @brief Stores PostgreSQL connection parameters sent by client.
*
* This class stores key-value pairs representing connection parameters
* for PostgreSQL connection.
*
*/
class PgSQL_Conn_Param {
public:
PgSQL_Conn_Param() {}
~PgSQL_Conn_Param() {
for (int i = 0; i < PG_PARAM_SIZE; i++) {
if (param_value[i])
free(param_value[i]);
}
}
bool set_value(PgSQL_Param_Name key, const char* val) {
if (key == -1) return false;
if (validate(key, val)) {
if (param_value[key]) {
free(param_value[key]);
}
param_value[key] = strdup(val);
param_set.push_back(key);
return true;
}
return false;
}
bool set_value(const char* key, const char* val) {
return set_value((PgSQL_Param_Name)get_param_name(key), val);
}
void reset_value(PgSQL_Param_Name key) {
if (param_value[key]) {
free(param_value[key]);
}
param_value[key] = nullptr;
// this has O(n) complexity. need to fix it....
param_set.erase(param_set.begin() + static_cast<int>(key));
}
PgSQL_Conn_Param() {}
~PgSQL_Conn_Param() {}
bool set_value(const char* key, const char* val) {
if (key == nullptr || val == nullptr) return false;
connection_parameters[key] = val;
return true;
}
inline
const char* get_value(PgSQL_Param_Name key) const {
return get_value(param_name[key]);
}
const char* get_value(const char* key) const {
auto it = connection_parameters.find(key);
if (it != connection_parameters.end()) {
return it->second.c_str();
}
return nullptr;
}
bool remove_value(const char* key) {
auto it = connection_parameters.find(key);
if (it != connection_parameters.end()) {
connection_parameters.erase(it);
return true;
}
return false;
}
inline
bool is_empty() const {
return connection_parameters.empty();
}
inline
void clear() {
connection_parameters.clear();
}
const char* get_value(PgSQL_Param_Name key) const {
return param_value[key];
}
int get_param_name(const char* name) {
int key = -1;
for (int i = 0; i < PG_PARAM_SIZE; i++) {
if (strcmp(name, PgSQL_Param_Name_Str[i]) == 0) {
key = i;
break;
}
}
if (key == -1) {
proxy_warning("Unrecognized connection option. Please report this as a bug for future enhancements:%s\n", name);
}
return key;
}
private:
/**
* @brief Stores the connection parameters as key-value pairs.
*/
std::map<std::string, std::string> connection_parameters;
std::vector<PgSQL_Param_Name> param_set;
char* param_value[PG_PARAM_SIZE]{};
friend class PgSQL_Session;
friend class PgSQL_Protocol;
};
class PgSQL_Variable {
@ -418,12 +361,9 @@ public:
PGresult* get_result();
void next_multi_statement_result(PGresult* result);
bool set_single_row_mode();
void optimize() {}
void update_bytes_recv(uint64_t bytes_recv);
void update_bytes_sent(uint64_t bytes_sent);
void ProcessQueryAndSetStatusFlags(char* query_digest_text);
void set_charset(const char* charset);
void connect_start_SetCharset(const char* client_encoding, uint32_t hash = 0);
inline const PGconn* get_pg_connection() const { return pgsql_conn; }
inline int get_pg_server_version() { return PQserverVersion(pgsql_conn); }

@ -986,7 +986,7 @@ private:
* @note This function iterates through the key-value pairs in the startup
* packet and stores them in the connection parameters object.
*/
void load_conn_parameters(pgsql_hdr* pkt, bool startup);
bool load_conn_parameters(pgsql_hdr* pkt);
/**
* @brief Handles the client's first message in a SCRAM-SHA-256

@ -283,9 +283,6 @@ private:
bool handler_again___status_SETTING_LDAP_USER_VARIABLE(int*);
bool handler_again___status_SETTING_SQL_MODE(int*);
bool handler_again___status_SETTING_SESSION_TRACK_GTIDS(int*);
#endif // 0
bool handler_again___status_CHANGING_CHARSET(int* _rc);
#if 0
bool handler_again___status_CHANGING_SCHEMA(int*);
#endif // 0
bool handler_again___status_CONNECTING_SERVER(int*);
@ -414,12 +411,6 @@ public:
int current_hostgroup;
int default_hostgroup;
int previous_hostgroup;
/**
* @brief Charset directly specified by the client. Supplied and updated via 'HandshakeResponse'
* and 'COM_CHANGE_USER' packets.
* @details Used when session needs to be restored via 'COM_RESET_CONNECTION'.
*/
int default_charset;
int locked_on_hostgroup;
int next_query_flagIN;
int mirror_hostgroup;

@ -37,7 +37,9 @@ class PgSQL_Set_Stmt_Parser {
void generateRE_parse1v2();
// First implemenation of the parser for TRANSACTION ISOLATION LEVEL and TRANSACTION READ/WRITE
std::map<std::string, std::vector<std::string>> parse2();
#if 0
std::string parse_character_set();
#endif
std::string remove_comments(const std::string& q);
};

@ -14,10 +14,8 @@ extern void print_backtrace(void);
typedef bool (*pgsql_verify_var)(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash);
typedef bool (*pgsql_update_var)(PgSQL_Session* session, int idx, int &_rc);
bool validate_charset(PgSQL_Session* session, int idx, int &_rc);
bool update_server_variable(PgSQL_Session* session, int idx, int &_rc);
bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash);
bool verify_set_names(PgSQL_Session* session);
class PgSQL_Variables {
static pgsql_verify_var verifiers[PGSQL_NAME_LAST_HIGH_WM];
@ -31,13 +29,13 @@ public:
PgSQL_Variables();
~PgSQL_Variables();
bool client_set_value(PgSQL_Session* session, int idx, const std::string& value);
bool client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx);
bool client_set_hash_and_value(PgSQL_Session* session, int idx, const std::string& value, uint32_t hash);
void client_reset_value(PgSQL_Session* session, int idx);
const char* client_get_value(PgSQL_Session* session, int idx) const;
uint32_t client_get_hash(PgSQL_Session* session, int idx) const;
void server_set_value(PgSQL_Session* session, int idx, const char* value);
void server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx);
void server_set_hash_and_value(PgSQL_Session* session, int idx, const char* value, uint32_t hash);
void server_reset_value(PgSQL_Session* session, int idx);
const char* server_get_value(PgSQL_Session* session, int idx) const;

@ -9,7 +9,8 @@ typedef enum {
VARIABLE_TYPE_BOOL, /**< Boolean variable type. */
VARIABLE_TYPE_STRING, /**< String variable type. */
VARIABLE_TYPE_DATESTYLE, /**< DateStyle variable type. */
VARIABLE_TYPE_MAINTENANCE_WORK_MEM
VARIABLE_TYPE_MAINTENANCE_WORK_MEM, /**< MaintenanceWorkMem variable type. */
VARIABLE_TYPE_CLIENT_ENCODING /**< ClientEncoding variable type. */
} pgsql_variable_type_t;
@ -61,5 +62,5 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag
extern const pgsql_variable_validator pgsql_variable_validator_bytea_output;
extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits;
extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem;
extern const pgsql_variable_validator pgsql_variable_validator_client_encoding;
#endif // PGSQL_VARIABLES_VALIDATOR_H

@ -244,7 +244,7 @@ enum mysql_variable_name {
};
/* NOTE:
make special ATTENTION that the order in mysql_variable_name
make special ATTENTION that the order in pgsql_variable_name
and pgsql_tracked_variables[] is THE SAME
*/
enum pgsql_variable_name {
@ -337,8 +337,6 @@ struct pgsql_variable_st {
enum session_status status; // what status should be changed after setting this variables
const char* set_variable_name; // what variable name (or string) will be used when setting it to backend
const char* internal_variable_name; // variable name as displayed in admin , WITHOUT "default_"
// Also used in INTERNAL SESSION
// if NULL , MySQL_Variables::MySQL_Variables will set it to set_variable_name during initialization
const char* default_value; // default value
uint8_t options; // options
const pgsql_variable_validator* validator; // validate value
@ -1797,9 +1795,10 @@ extern const pgsql_variable_validator pgsql_variable_validator_client_min_messag
extern const pgsql_variable_validator pgsql_variable_validator_bytea_output;
extern const pgsql_variable_validator pgsql_variable_validator_extra_float_digits;
extern const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem;
extern const pgsql_variable_validator pgsql_variable_validator_client_encoding;
pgsql_variable_st pgsql_tracked_variables[]{
{ PGSQL_CLIENT_ENCODING, SETTING_CHARSET, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_SET_TRANSACTION | PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), nullptr, { "names", nullptr } },
{ PGSQL_CLIENT_ENCODING, SETTING_VARIABLE, "client_encoding", "client_encoding", "UTF8", (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_client_encoding, { "names", nullptr } },
{ PGSQL_DATESTYLE, SETTING_VARIABLE, "datestyle", "datestyle", "ISO, MDY" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_datestyle, nullptr },
{ PGSQL_INTERVALSTYLE, SETTING_VARIABLE, "intervalstyle", "intervalstyle", "postgres" , (PGTRACKED_VAR_OPT_QUOTE | PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_intervalstyle, nullptr },
{ PGSQL_STANDARD_CONFORMING_STRINGS, SETTING_VARIABLE, "standard_conforming_strings", "standard_conforming_strings", "on", (PGTRACKED_VAR_OPT_PARAM_STATUS), &pgsql_variable_validator_bool, nullptr },
@ -1819,6 +1818,7 @@ pgsql_variable_st pgsql_tracked_variables[]{
{ PGSQL_MAINTENANCE_WORK_MEM, SETTING_VARIABLE, "maintenance_work_mem", "maintenance_work_mem", "64MB", (PGTRACKED_VAR_OPT_QUOTE), &pgsql_variable_validator_maintenance_work_mem, nullptr },
{ PGSQL_SYNCHRONOUS_COMMIT, SETTING_VARIABLE, "synchronous_commit", "synchronous_commit", "on", (PGTRACKED_VAR_OPT_QUOTE), &pgsql_variable_validator_synchronous_commit, nullptr},
};
#endif //EXCLUDE_TRACKING_VARAIABLES
#else

@ -14,103 +14,6 @@ using json = nlohmann::json;
#include "PgSQL_Query_Processor.h"
#include "MySQL_Variables.h"
#if 0
// some of the code that follows is from mariadb client library memory allocator
typedef int myf; // Type of MyFlags in my_funcs
#define MYF(v) (myf) (v)
#define MY_KEEP_PREALLOC 1
#define MY_ALIGN(A,L) (((A) + (L) - 1) & ~((L) - 1))
#define ALIGN_SIZE(A) MY_ALIGN((A),sizeof(double))
static void ma_free_root(MA_MEM_ROOT *root, myf MyFLAGS);
static void *ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size);
#define MAX(a,b) (((a) > (b)) ? (a) : (b))
static void * ma_alloc_root(MA_MEM_ROOT *mem_root, size_t Size)
{
size_t get_size;
void * point;
MA_USED_MEM *next= 0;
MA_USED_MEM **prev;
Size= ALIGN_SIZE(Size);
if ((*(prev= &mem_root->free)))
{
if ((*prev)->left < Size &&
mem_root->first_block_usage++ >= 16 &&
(*prev)->left < 4096)
{
next= *prev;
*prev= next->next;
next->next= mem_root->used;
mem_root->used= next;
mem_root->first_block_usage= 0;
}
for (next= *prev; next && next->left < Size; next= next->next)
prev= &next->next;
}
if (! next)
{ /* Time to alloc new block */
get_size= MAX(Size+ALIGN_SIZE(sizeof(MA_USED_MEM)),
(mem_root->block_size & ~1) * ( (mem_root->block_num >> 2) < 4 ? 4 : (mem_root->block_num >> 2) ) );
if (!(next = (MA_USED_MEM*) malloc(get_size)))
{
if (mem_root->error_handler)
(*mem_root->error_handler)();
return((void *) 0); /* purecov: inspected */
}
mem_root->block_num++;
next->next= *prev;
next->size= get_size;
next->left= get_size-ALIGN_SIZE(sizeof(MA_USED_MEM));
*prev=next;
}
point= (void *) ((char*) next+ (next->size-next->left));
if ((next->left-= Size) < mem_root->min_malloc)
{ /* Full block */
*prev=next->next; /* Remove block from list */
next->next=mem_root->used;
mem_root->used=next;
mem_root->first_block_usage= 0;
}
return(point);
}
static void ma_free_root(MA_MEM_ROOT *root, myf MyFlags)
{
MA_USED_MEM *next,*old;
if (!root)
return; /* purecov: inspected */
if (!(MyFlags & MY_KEEP_PREALLOC))
root->pre_alloc=0;
for ( next=root->used; next ;)
{
old=next; next= next->next ;
if (old != root->pre_alloc)
free(old);
}
for (next= root->free ; next ; )
{
old=next; next= next->next ;
if (old != root->pre_alloc)
free(old);
}
root->used=root->free=0;
if (root->pre_alloc)
{
root->free=root->pre_alloc;
root->free->left=root->pre_alloc->size-ALIGN_SIZE(sizeof(MA_USED_MEM));
root->free->next=0;
}
}
#endif // 0
extern char * binary_sha1;
#include "proxysql_find_charset.h"
@ -841,31 +744,22 @@ void PgSQL_Connection::connect_start() {
conninfo << "sslmode='disable' "; // not supporting SSL
}
if (myds && myds->sess) {
const char* charset = NULL;
uint32_t charset_hash = 0;
// Take client character set and use it to connect to backend
charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING);
if (charset_hash != 0)
charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING);
//if (!charset) {
// charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING];
//}
if (myds && myds->sess && myds->sess->client_myds) {
// Client Encoding should be always set
assert(charset);
connect_start_SetCharset(charset, charset_hash);
escaped_str = escape_string_single_quotes_and_backslashes((char*)charset, false);
const char* client_charset = pgsql_variables.client_get_value(myds->sess, PGSQL_CLIENT_ENCODING);
assert(client_charset);
uint32_t client_charset_hash = pgsql_variables.client_get_hash(myds->sess, PGSQL_CLIENT_ENCODING);
assert(client_charset_hash);
const char* escaped_str = escape_string_backslash_spaces(client_charset);
conninfo << "client_encoding='" << escaped_str << "' ";
if (escaped_str != charset)
free(escaped_str);
if (escaped_str != client_charset)
free((char*)escaped_str);
// charset validation is already done
pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, client_charset, client_charset_hash);
std::vector<unsigned int> client_options;
client_options.reserve(PGSQL_NAME_LAST_LOW_WM + dynamic_variables_idx.size());
client_options.reserve(PGSQL_NAME_LAST_LOW_WM + myds->sess->client_myds->myconn->dynamic_variables_idx.size());
// excluding PGSQL_CLIENT_ENCODING
for (unsigned int idx = 1; idx < PGSQL_NAME_LAST_LOW_WM; idx++) {
@ -873,9 +767,9 @@ void PgSQL_Connection::connect_start() {
client_options.push_back(idx);
}
for (std::vector<unsigned int>::const_iterator it_c = dynamic_variables_idx.begin(); it_c != dynamic_variables_idx.end(); it_c++) {
assert(pgsql_variables.client_get_hash(myds->sess, *it_c));
client_options.push_back(*it_c);
for (uint32_t idx : myds->sess->client_myds->myconn->dynamic_variables_idx) {
assert(pgsql_variables.client_get_hash(myds->sess, idx));
client_options.push_back(idx);
}
if (client_options.empty() == false ||
@ -1882,31 +1776,6 @@ void PgSQL_Connection::ProcessQueryAndSetStatusFlags(char* query_digest_text) {
}
}
void PgSQL_Connection::set_charset(const char* charset) {
proxy_debug(PROXY_DEBUG_MYSQL_CONNPOOL, 4, "Setting client encoding %s\n", charset);
pgsql_variables.client_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset);
}
void PgSQL_Connection::connect_start_SetCharset(const char* charset, uint32_t hash) {
assert(charset);
int charset_encoding = PgSQL_Connection::char_to_encoding(charset);
if (charset_encoding == -1) {
proxy_error("Cannot find character set [%s]\n", charset);
assert(0);
}
/* We are connecting to backend setting charset in connection parameters.
* Client already has sent us a character set and client connection variables have been already set.
* Now we store this charset in server connection variables to avoid updating this variables on backend.
*/
if (hash == 0)
pgsql_variables.server_set_value(myds->sess, PGSQL_CLIENT_ENCODING, charset);
else
pgsql_variables.server_set_hash_and_value(myds->sess, PGSQL_CLIENT_ENCODING, charset, hash);
}
// this function is identical to async_query() , with the only exception that MyRS should never be set
int PgSQL_Connection::async_send_simple_command(short event, char* stmt, unsigned long length) {
PROXY_TRACE();

@ -1789,7 +1789,6 @@ void PgSQL_HostGroups_Manager::push_MyConn_to_pool(PgSQL_Connection *c, bool _lo
mysrvc->ConnectionsUsed->add(c); // Add the connection back to the list of used connections
destroy_MyConn_from_pool(c, false); // Destroy the connection from the pool
} else {*/
c->optimize();
mysrvc->ConnectionsFree->add(c);
//}
} else {

@ -403,7 +403,7 @@ bool PgSQL_Protocol::generate_pkt_initial_handshake(bool send, void** _ptr, unsi
if (RAND_bytes((*myds)->tmp_login_salt, sizeof((*myds)->tmp_login_salt)) != 1) {
// Fallback method: using a basic pseudo-random generator
srand((unsigned int)time(NULL));
for (int i = 0; i < sizeof((*myds)->tmp_login_salt); i++) {
for (size_t i = 0; i < sizeof((*myds)->tmp_login_salt); i++) {
(*myds)->tmp_login_salt[i] = rand() % 256;
}
}
@ -630,26 +630,34 @@ unsigned int get_string(const char* data, unsigned int len, const char** dst_p)
return (nul + 1 - data);
}
void PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt, bool startup)
bool PgSQL_Protocol::load_conn_parameters(pgsql_hdr* pkt)
{
const char* key, * val;
unsigned int read_pos = 0;
while (1) {
uint32_t offset = 0;
int pos = get_string(((const char*)pkt->data.ptr) + read_pos, pkt->data.size - read_pos, &key);
if (pos == 0) return;
while (offset < pkt->data.size) {
char* nameptr = (char*)pkt->data.ptr + offset;
uint32_t valoffset;
char* valptr;
read_pos += pos;
if (*nameptr == '\0')
break; /* found packet terminator */
valoffset = offset + strlen(nameptr) + 1;
if (valoffset >= pkt->data.size)
break; /* missing value, will complain below */
valptr = (char*)pkt->data.ptr + valoffset;
pos = get_string(((const char*)pkt->data.ptr) + read_pos, pkt->data.size - read_pos, &val);
if (pos == 0) return;
(*myds)->myconn->conn_params.set_value(nameptr, valptr);
read_pos += pos;
offset = valoffset + strlen(valptr) + 1;
}
//slog_debug(server, "S: param: %s = %s", key, val);
(*myds)->myconn->conn_params.set_value(key, val);
if (offset != pkt->data.size - 1) {
proxy_error("Malformed startup packet was received from client %s:%d\n", (*myds)->addr.addr, (*myds)->addr.port);
return false;
}
return true;
}
bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len, bool& ssl_request) {
@ -677,13 +685,19 @@ bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len
return false;
}
load_conn_parameters(&hdr, true);
if (!load_conn_parameters(&hdr)) {
proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. malformed startup packet.\n", (*myds)->sess, (*myds));
generate_error_packet(true, false, "invalid startup packet layout: expected terminator as last byte",
PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true);
return false;
}
const unsigned char* user = (unsigned char*)(*myds)->myconn->conn_params.get_value(PG_USER);
if (!user || *user == '\0') {
proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. no username supplied.\n", (*myds), (*myds)->sess);
generate_error_packet(true, false, "no username supplied", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true);
proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p. no username supplied.\n", (*myds)->sess, (*myds));
generate_error_packet(true, false, "no PostgreSQL user name specified in startup packet",
PGSQL_ERROR_CODES::ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION, true);
return false;
}
@ -769,6 +783,7 @@ std::vector<std::pair<std::string, std::string>> PgSQL_Protocol::parse_options(c
// Add key-value pair to the list
if (!key.empty()) {
std::transform(key.begin(), key.end(), key.begin(), ::tolower);
options_list.emplace_back(std::move(key), std::move(value));
}
}
@ -808,26 +823,7 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char*
return EXECUTION_STATE::FAILED;
}
// set charset but first verify
const char* charset = (*myds)->myconn->conn_params.get_value(PG_CLIENT_ENCODING);
// if client does not provide client_encoding, PostgreSQL uses the default client encoding.
// We need to save the default client encoding to send it to the client in case client doesn't provide one.
if (charset == NULL) charset = pgsql_thread___default_variables[PGSQL_CLIENT_ENCODING];
assert(charset);
int charset_encoding = (*myds)->myconn->char_to_encoding(charset);
if (charset_encoding == -1) {
proxy_error("Cannot find charset [%s]\n", charset);
proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "DS=%p , Session=%p , charset='%s'. Client charset not supported.\n", (*myds), (*myds)->sess, charset);
generate_error_packet(true, false, "Client charset not supported", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true);
return EXECUTION_STATE::FAILED;
}
(*myds)->sess->default_charset = charset_encoding;
user = (char*)(*myds)->myconn->conn_params.get_value(PG_USER);
if (!user || *user == '\0') {
@ -1058,92 +1054,190 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char*
userinfo->username = strdup((const char*)user);
userinfo->password = strdup((const char*)password);
const char* db = (*myds)->myconn->conn_params.get_value(PG_DATABASE);
userinfo->set_dbname(db ? db : userinfo->username);
std::vector<std::pair<std::string, std::string>> parameters;
std::vector<std::pair<std::string, std::string>> options_list;
assert(sess);
assert(sess->client_myds);
parameters.reserve((*myds)->myconn->conn_params.connection_parameters.size());
PgSQL_Connection* myconn = sess->client_myds->myconn;
assert(myconn);
myconn->set_charset(charset);
sess->set_default_session_variable(PGSQL_CLIENT_ENCODING, charset);
/* Note: Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior.
PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds,
but the error is encountered when executing a query.
This is behaviour is intentional, as newer PostgreSQL versions may introduce parameters that ProxySQL is not yet aware of.
*/
// New implementation
for (const auto& [param_name, param_val] : (*myds)->myconn->conn_params.connection_parameters) {
std::string param_name_lowercase(param_name);
std::transform(param_name_lowercase.cbegin(), param_name_lowercase.cend(), param_name_lowercase.begin(), ::tolower);
// get datestyle from connection parameters
const char* datestyle_tmp = (*myds)->myconn->conn_params.get_value(PG_DATESTYLE);
std::string datestyle = datestyle_tmp ? datestyle_tmp : "";
// check if parameter is part of connection-level parameters
auto itr = param_name_map.find(param_name_lowercase.c_str());
if (itr != param_name_map.end()) {
if (datestyle.empty()) {
// No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable.
datestyle = pgsql_thread___default_variables[PGSQL_DATESTYLE];
} else {
PgSQL_DateStyle_t datestyle_parsed = PgSQL_DateStyle_Util::parse_datestyle(datestyle);
if (param_name_lowercase.compare("user") == 0 || param_name_lowercase.compare("password") == 0) {
continue;
}
// If DateStyle provided in the connection parameters is incomplete, the missing parts will be taken from the default DateStyle.
if (datestyle_parsed.format == DATESTYLE_FORMAT_NONE || datestyle_parsed.order == DATESTYLE_ORDER_NONE) {
PgSQL_DateStyle_t datestyle_default = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]);
datestyle = PgSQL_DateStyle_Util::datestyle_to_string(datestyle_parsed, datestyle_default);
}
}
bool is_validation_success = false;
const Param_Name_Validation* validation = itr->second;
if (validation != nullptr && validation->accepted_values) {
const char** accepted_value = validation->accepted_values;
while (*accepted_value) {
if (strcmp(param_val.c_str(), *accepted_value) == 0) {
is_validation_success = true;
break;
}
accepted_value++;
}
} else {
is_validation_success = true;
}
assert(datestyle.empty() == false);
if (is_validation_success == false) {
char* m = NULL;
char* errmsg = NULL;
proxy_error("invalid value for parameter \"%s\": \"%s\"\n", param_name.c_str(), param_val.c_str());
m = (char*)"invalid value for parameter \"%s\": \"%s\"";
errmsg = (char*)malloc(param_val.length() + param_name.length() + strlen(m));
sprintf(errmsg, m, param_name.c_str(), param_val.c_str());
generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true);
free(errmsg);
ret = EXECUTION_STATE::FAILED;
// freeing userinfo->username and userinfo->password to prevent invalid password error generation.
free(userinfo->username);
free(userinfo->password);
userinfo->username = strdup("");
userinfo->password = strdup("");
//
goto __exit_process_pkt_handshake_response;
}
if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str())) {
// change current datestyle
sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle);
sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str());
if (param_name_lowercase.compare("database") == 0) {
userinfo->set_dbname(param_val.empty() ? user : param_val.c_str());
} else if (param_name_lowercase.compare("options") == 0) {
options_list = parse_options(param_val.c_str());
}
} else {
// session parameters/variables?
parameters.push_back(std::make_pair(param_name_lowercase, param_val));
}
}
// get timezone from connection parameters
const char* timezone = (*myds)->myconn->conn_params.get_value(PG_TIMEZONE);
if (timezone == NULL)
timezone = pgsql_thread___default_variables[PGSQL_TIMEZONE];
pgsql_variables.client_set_value(sess, PGSQL_TIMEZONE, timezone);
sess->set_default_session_variable(PGSQL_TIMEZONE, timezone);
const char* intervalstyle = pgsql_thread___default_variables[PGSQL_INTERVALSTYLE];
if (intervalstyle) {
pgsql_variables.client_set_value(sess, PGSQL_INTERVALSTYLE, intervalstyle);
sess->set_default_session_variable(PGSQL_INTERVALSTYLE, intervalstyle);
if (userinfo->dbname == nullptr) {
userinfo->set_dbname(user);
}
// get standard_conforming_strings from connection parameters
const char* standard_conforming_strings = pgsql_thread___default_variables[PGSQL_STANDARD_CONFORMING_STRINGS];
if (standard_conforming_strings) {
pgsql_variables.client_set_value(sess, PGSQL_STANDARD_CONFORMING_STRINGS, standard_conforming_strings);
sess->set_default_session_variable(PGSQL_STANDARD_CONFORMING_STRINGS, standard_conforming_strings);
// Merge options with parameters.
// Options are processed first, followed by connection parameters.
// If a parameter is specified in both, the connection parameter takes precedence
// and overwrites the previosly set value.
if (options_list.empty() == false) {
options_list.reserve(parameters.size() + options_list.size());
options_list.insert(options_list.end(), std::make_move_iterator(parameters.begin()), std::make_move_iterator(parameters.end()));
parameters = std::move(options_list);
}
const char* options = (*myds)->myconn->conn_params.get_value(PG_OPTIONS);
// assign default datestyle to current datestyle.
// This is needed by PgSQL_DateStyle_Util::parse_datestyle
sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]);
auto options_list = parse_options(options);
for (const auto&[param_key, param_val] : parameters) {
for (auto& option : options_list) {
int idx = PGSQL_NAME_LAST_HIGH_WM;
for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) {
if (i == PGSQL_NAME_LAST_LOW_WM)
if (i == PGSQL_NAME_LAST_LOW_WM)
continue;
if (variable_name_exists(pgsql_tracked_variables[i], option.first.c_str()) == true) {
if (strncmp(param_key.c_str(), pgsql_tracked_variables[i].set_variable_name,
strlen(pgsql_tracked_variables[i].set_variable_name)) == 0) {
idx = i;
break;
}
}
}
if (idx != PGSQL_NAME_LAST_HIGH_WM) {
pgsql_variables.client_set_value(sess, idx, option.second.c_str());
sess->set_default_session_variable((enum pgsql_variable_name)idx, option.second.c_str());
std::string value_copy = param_val;
char* transformed_value = nullptr;
if (pgsql_tracked_variables[idx].validator && pgsql_tracked_variables[idx].validator->validate &&
(
*pgsql_tracked_variables[idx].validator->validate)(
value_copy.c_str(), &pgsql_tracked_variables[idx].validator->params, sess, &transformed_value) == false
) {
char* m = NULL;
char* errmsg = NULL;
proxy_error("invalid value for parameter \"%s\": \"%s\"\n", pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str());
m = (char*)"invalid value for parameter \"%s\": \"%s\"";
errmsg = (char*)malloc(value_copy.length() + strlen(pgsql_tracked_variables[idx].set_variable_name) + strlen(m));
sprintf(errmsg, m, pgsql_tracked_variables[idx].set_variable_name, value_copy.c_str());
generate_error_packet(true, false, errmsg, PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, true);
free(errmsg);
ret = EXECUTION_STATE::FAILED;
// freeing userinfo->username and userinfo->password to prevent invalid password error generation.
free(userinfo->username);
free(userinfo->password);
userinfo->username = strdup("");
userinfo->password = strdup("");
//
goto __exit_process_pkt_handshake_response;
}
if (transformed_value) {
value_copy = transformed_value;
free(transformed_value);
}
if (idx == PGSQL_DATESTYLE) {
// get datestyle from connection parameters
std::string datestyle = value_copy.empty() == false ? value_copy : "";
if (datestyle.empty()) {
// No need to validate default DateStyle again; it is already verified in PgSQL_Threads_Handler::set_variable.
datestyle = pgsql_thread___default_variables[PGSQL_DATESTYLE];
}
else {
PgSQL_DateStyle_t datestyle_parsed = PgSQL_DateStyle_Util::parse_datestyle(datestyle);
// If DateStyle provided in the connection parameters is incomplete, the missing parts will be taken from the default DateStyle.
if (datestyle_parsed.format == DATESTYLE_FORMAT_NONE || datestyle_parsed.order == DATESTYLE_ORDER_NONE) {
PgSQL_DateStyle_t datestyle_default = PgSQL_DateStyle_Util::parse_datestyle(pgsql_thread___default_variables[PGSQL_DATESTYLE]);
datestyle = PgSQL_DateStyle_Util::datestyle_to_string(datestyle_parsed, datestyle_default);
}
}
assert(datestyle.empty() == false);
if (pgsql_variables.client_set_value(sess, PGSQL_DATESTYLE, datestyle.c_str(), false)) {
// change current datestyle
sess->current_datestyle = PgSQL_DateStyle_Util::parse_datestyle(datestyle);
sess->set_default_session_variable(PGSQL_DATESTYLE, datestyle.c_str());
}
} else {
pgsql_variables.client_set_value(sess, idx, value_copy.c_str(), false);
sess->set_default_session_variable((enum pgsql_variable_name)idx, value_copy.c_str());
}
} else {
const char* val = option.second.c_str();
const char* escaped_str = escape_string_backslash_spaces(val);
sess->untracked_option_parameters = "-c " + option.first + "=" + escaped_str + " ";
if (escaped_str != val)
// parameter provided is not part of the tracked variables. Will lock on hostgroup on next query.
const char* val_cstr = param_val.c_str();
proxy_warning("Unrecognized connection parameter. Please report this as a bug for future enhancements:%s:%s\n", param_key.c_str(), val_cstr);
const char* escaped_str = escape_string_backslash_spaces(val_cstr);
sess->untracked_option_parameters = "-c " + param_key + "=" + escaped_str + " ";
if (escaped_str != val_cstr)
free((char*)escaped_str);
}
}
//if (charset)
// (*myds)->sess->default_charset = charset;
// fill all crtical variables with default values, if not set by client
for (int i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) {
if (pgsql_variables.client_get_hash(sess, i) != 0)
continue;
const char* val = pgsql_thread___default_variables[i];
pgsql_variables.client_set_value(sess, i, val, false);
sess->set_default_session_variable((pgsql_variable_name)i, val);
}
sess->client_myds->myconn->reorder_dynamic_variables_idx();
}
else {
// we always duplicate username and password, or crashes happen

@ -81,7 +81,8 @@ static inline char is_normal_char(char c) {
}
*/
static const std::array<std::string,5> pgsql_critical_variables = {
static const std::array<std::string,6> pgsql_critical_variables = {
"client_encoding",
"datestyle",
"intervalstyle",
"standard_conforming_strings",
@ -825,10 +826,9 @@ void PgSQL_Session::generate_proxysql_internal_session_json(json& j) {
//j["conn"]["ps"]["client_stmt_to_global_ids"] = client_myds->myconn->local_stmts->client_stmt_to_global_ids;
const PgSQL_Conn_Param& conn_params = client_myds->myconn->conn_params;
for (size_t i = 0; i < conn_params.param_set.size(); i++) {
if (conn_params.param_value[conn_params.param_set[i]] != NULL) {
j["client"]["conn"]["connection_options"][PgSQL_Param_Name_Str[conn_params.param_set[i]]] = conn_params.param_value[conn_params.param_set[i]];
}
for (const auto& [key, val] : conn_params.connection_parameters) {
j["client"]["conn"]["connection_options"][key.c_str()] = val.c_str();
}
}
}
@ -1508,98 +1508,6 @@ bool PgSQL_Session::handler_again___status_SETTING_INIT_CONNECT(int* _rc) {
return ret;
}
bool PgSQL_Session::handler_again___status_CHANGING_CHARSET(int* _rc) {
assert(mybe->server_myds->myconn);
PgSQL_Data_Stream* myds = mybe->server_myds;
PgSQL_Connection* myconn = myds->myconn;
/* Validate that server can support client's charset */
if (!validate_charset(this, PGSQL_CLIENT_ENCODING, *_rc)) {
return false;
}
myds->DSS = STATE_MARIADB_QUERY;
enum session_status st = status;
if (myds->mypolls == NULL) {
thread->mypolls.add(POLLIN | POLLOUT, mybe->server_myds->fd, mybe->server_myds, thread->curtime);
}
std::string charset = pgsql_variables.client_get_value(this, PGSQL_CLIENT_ENCODING);
std::string query = "SET CLIENT_ENCODING TO '" + charset + "'";
int rc = myconn->async_send_simple_command(myds->revents, (char*)query.c_str(),query.length());
if (rc == 0) {
__sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1);
myds->DSS = STATE_MARIADB_GENERIC;
st = previous_status.top();
previous_status.pop();
NEXT_IMMEDIATE_NEW(st);
} else {
if (rc == -1) {
// the command failed
const bool error_present = myconn->is_error_present();
PgHGM->p_update_pgsql_error_counter(
p_pgsql_error_type::pgsql,
myconn->parent->myhgc->hid,
myconn->parent->address,
myconn->parent->port,
(error_present ? 9999 : ER_PROXYSQL_OFFLINE_SRV) // TOFIX: 9999 is a placeholder for the actual error code
);
if (error_present == false || (error_present == true && myconn->is_connection_in_reusable_state() == false)) {
bool retry_conn = false;
// client error, serious
proxy_error(
"Client trying to set a charset (%s) not supported by backend (%s:%d). Changing it to %s\n",
charset.c_str(), myconn->parent->address, myconn->parent->port, pgsql_tracked_variables[PGSQL_CLIENT_ENCODING].default_value
);
detected_broken_connection(__FILE__, __LINE__, __func__, "during SET CLIENT_ENCODING", myconn);
if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) {
retry_conn = true;
}
myds->destroy_MySQL_Connection_From_Pool(false);
myds->fd = 0;
if (retry_conn) {
myds->DSS = STATE_NOT_INITIALIZED;
NEXT_IMMEDIATE_NEW(CONNECTING_SERVER);
}
*_rc = -1;
return false;
} else {
proxy_warning("Error during SET CLIENT_ENCODING: %s\n", myconn->get_error_code_with_message().c_str());
// we won't go back to PROCESSING_QUERY
st = previous_status.top();
previous_status.pop();
client_myds->myprot.generate_error_packet(true, true, myconn->get_error_message().c_str(), myconn->get_error_code(), false);
myds->destroy_MySQL_Connection_From_Pool(true);
myds->fd = 0;
RequestEnd(myds); //fix bug #682
}
} else {
if (rc == -2) {
bool retry_conn = false;
proxy_error("Timeout during SET CLIENT_ENCODING on %s , %d\n", myconn->parent->address, myconn->parent->port);
PgHGM->p_update_pgsql_error_counter(p_pgsql_error_type::pgsql, myconn->parent->myhgc->hid, myconn->parent->address, myconn->parent->port, ER_PROXYSQL_CHANGE_USER_TIMEOUT);
if ((myds->myconn->reusable == true) && myds->myconn->IsActiveTransaction() == false && myds->myconn->MultiplexDisabled() == false) {
retry_conn = true;
}
myds->destroy_MySQL_Connection_From_Pool(false);
myds->fd = 0;
if (retry_conn) {
myds->DSS = STATE_NOT_INITIALIZED;
NEXT_IMMEDIATE_NEW(CONNECTING_SERVER);
}
*_rc = -1;
return false;
}
else {
// rc==1 , nothing to do for now
}
}
}
return false;
}
bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, const char* var_name, const char* var_value,
bool no_quote, bool set_transaction) {
bool ret = false;
@ -1650,6 +1558,9 @@ bool PgSQL_Session::handler_again___status_SETTING_GENERIC_VARIABLE(int* _rc, co
query = NULL;
}
if (rc == 0) {
if (strncasecmp(var_name, "client_encoding", sizeof("client_encoding")-1) == 0) {
__sync_fetch_and_add(&PgHGM->status.backend_set_client_encoding, 1);
}
myds->revents |= POLLOUT; // we also set again POLLOUT to send a query immediately!
myds->DSS = STATE_MARIADB_GENERIC;
st = previous_status.top();
@ -3306,11 +3217,6 @@ handler_again:
}
if (locked_on_hostgroup == -1 || locked_on_hostgroup_and_all_variables_set == false) {
// verify charset
if (verify_set_names(this)) {
goto handler_again;
}
for (auto i = 0; i < PGSQL_NAME_LAST_LOW_WM; i++) {
auto client_hash = client_myds->myconn->var_hash[i];
#ifdef DEBUG
@ -3596,9 +3502,6 @@ bool PgSQL_Session::handler_again___multiple_statuses(int* rc) {
//case SETTING_INIT_CONNECT:
// ret = handler_again___status_SETTING_INIT_CONNECT(rc);
break;
case SETTING_CHARSET:
ret = handler_again___status_CHANGING_CHARSET(rc);
break;
default:
break;
}
@ -3926,7 +3829,7 @@ void PgSQL_Session::handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE(
else {
client_addr = strdup((char*)"");
}
if (client_myds->myconn->userinfo->username) {
if (client_myds->myconn->userinfo->username && client_myds->myconn->userinfo->username[0] != '\0') {
char* _s = (char*)malloc(strlen(client_myds->myconn->userinfo->username) + 100 + strlen(client_addr));
uint8_t _pid = 2;
if (client_myds->switching_auth_stage) _pid += 2;
@ -4296,7 +4199,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
int idx = PGSQL_NAME_LAST_HIGH_WM;
for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) {
if (variable_name_exists(pgsql_tracked_variables[i], var.c_str()) == true) {
idx = pgsql_tracked_variables[i].idx;
idx = i;
break;
}
}
@ -4347,8 +4250,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", var.c_str(), value1.c_str());
uint32_t var_hash_int = SpookyHash::Hash32(value1.c_str(), value1.length(), 10);
if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) {
if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value1.c_str())) {
if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) {
if (!pgsql_variables.client_set_value(this, idx, value1.c_str(), true)) {
return false;
}
if (idx == PGSQL_DATESTYLE) {
@ -4628,7 +4531,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
l_free(pkt->size, pkt->ptr);
return true;
}*/ else if (match_regexes && match_regexes[3]->match(dig)) {
} else if (match_regexes && match_regexes[3]->match(dig)) {
std::vector<std::pair<std::string, std::string>> param_status;
PgSQL_Set_Stmt_Parser parser(nq);
std::string charset = parser.parse_character_set();
@ -4680,7 +4583,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
RequestEnd(NULL);
l_free(pkt->size, pkt->ptr);
return true;
} else {
}*/ else {
unable_to_parse_set_statement(lock_hostgroup);
return false;
}
@ -4719,26 +4622,41 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
if (strncasecmp(nq.c_str(), "ALL", 3) == 0) {
for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) {
for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) {
if (i == PGSQL_NAME_LAST_LOW_WM)
continue;
const char* name = pgsql_tracked_variables[idx].set_variable_name;
const char* value = get_default_session_variable((enum pgsql_variable_name)idx);
const char* name = pgsql_tracked_variables[i].set_variable_name;
const char* value = get_default_session_variable((enum pgsql_variable_name)i);
proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value);
uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10);
if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[i].idx) != var_hash_int) {
if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[i].idx, value)) {
if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) {
if (!pgsql_variables.client_set_value(this, idx, value, false)) {
return false;
}
if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) {
param_status.push_back(std::make_pair(name, value));
}
}
}
for (int idx : client_myds->myconn->dynamic_variables_idx) {
assert(idx < PGSQL_NAME_LAST_HIGH_WM);
const char* name = pgsql_tracked_variables[idx].set_variable_name;
const char* value = get_default_session_variable((enum pgsql_variable_name)idx);
proxy_debug(PROXY_DEBUG_MYSQL_COM, 8, "Changing connection %s to %s\n", name, value);
uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10);
if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) {
if (!pgsql_variables.client_set_value(this, idx, value, false)) {
return false;
}
if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[i])) {
if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) {
param_status.push_back(std::make_pair(name, value));
}
}
}
client_myds->myconn->reorder_dynamic_variables_idx();
} else if (std::find(pgsql_variables.ignore_vars.begin(), pgsql_variables.ignore_vars.end(), nq) != pgsql_variables.ignore_vars.end()) {
// this is a variable we parse but ignore
#ifdef DEBUG
@ -4763,8 +4681,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
const char* value = get_default_session_variable((enum pgsql_variable_name)idx);
uint32_t var_hash_int = SpookyHash::Hash32(value, strlen(value), 10);
if (pgsql_variables.client_get_hash(this, pgsql_tracked_variables[idx].idx) != var_hash_int) {
if (!pgsql_variables.client_set_value(this, pgsql_tracked_variables[idx].idx, value)) {
if (pgsql_variables.client_get_hash(this, idx) != var_hash_int) {
if (!pgsql_variables.client_set_value(this, idx, value, true)) {
return false;
}
if (IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx])) {

@ -202,6 +202,7 @@ std::map<std::string,std::vector<std::string>> PgSQL_Set_Stmt_Parser::parse2() {
return result;
}
#if 0
std::string PgSQL_Set_Stmt_Parser::parse_character_set() {
#ifdef DEBUG
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 4, "Parsing query %s\n", query.c_str());
@ -224,7 +225,7 @@ std::string PgSQL_Set_Stmt_Parser::parse_character_set() {
delete opt2;
return value3;
}
#endif
std::string PgSQL_Set_Stmt_Parser::remove_comments(const std::string& q) {
std::string result = "";
bool in_multiline_comment = false;

@ -10,58 +10,40 @@
#include <sstream>
/*static inline char is_digit(char c) {
if(c >= '0' && c <= '9')
return 1;
return 0;
}*/
pgsql_verify_var PgSQL_Variables::verifiers[PGSQL_NAME_LAST_HIGH_WM];
pgsql_update_var PgSQL_Variables::updaters[PGSQL_NAME_LAST_HIGH_WM];
PgSQL_Variables::PgSQL_Variables() {
// add here all the variables we want proxysql to recognize, but ignore
ignore_vars.push_back("application_name");
//ignore_vars.push_back("interactive_timeout");
//ignore_vars.push_back("wait_timeout");
//ignore_vars.push_back("net_read_timeout");
//ignore_vars.push_back("net_write_timeout");
//ignore_vars.push_back("net_buffer_length");
//ignore_vars.push_back("read_buffer_size");
//ignore_vars.push_back("read_rnd_buffer_size");
// NOTE: This variable has been temporarily ignored. Check issues #3442 and #3441.
//ignore_vars.push_back("session_track_schema");
variables_regexp = "";
/*
NOTE:
make special ATTENTION that the order in pgsql_variable_name
and pgsqll_tracked_variables[] is THE SAME
NOTE:
PgSQL_Variables::PgSQL_Variables() has a built-in check to make sure that the order is correct,
and that variables are in alphabetical order
*/
for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) {
//Array index and enum value (idx) should be same
assert(i == pgsql_tracked_variables[i].idx);
if (i > PGSQL_NAME_LAST_LOW_WM + 1) {
assert(strcmp(pgsql_tracked_variables[i].set_variable_name, pgsql_tracked_variables[i - 1].set_variable_name) > 0);
}
// we initialized all the internal_variable_name if set to NULL
if (pgsql_tracked_variables[i].internal_variable_name == NULL) {
pgsql_tracked_variables[i].internal_variable_name = pgsql_tracked_variables[i].set_variable_name;
}
}
/*
NOTE:
make special ATTENTION that the order in pgsql_variable_name
and pgsqll_tracked_variables[] is THE SAME
NOTE:
PgSQL_Variables::PgSQL_Variables() has a built-in check to make sure that the order is correct,
and that variables are in alphabetical order
*/
for (int i = PGSQL_NAME_LAST_LOW_WM; i < PGSQL_NAME_LAST_HIGH_WM; i++) {
assert(i == pgsql_tracked_variables[i].idx);
if (i > PGSQL_NAME_LAST_LOW_WM+1) {
assert(strcmp(pgsql_tracked_variables[i].set_variable_name, pgsql_tracked_variables[i-1].set_variable_name) > 0);
}
}
for (auto i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) {
if (i == PGSQL_CLIENT_ENCODING) {
PgSQL_Variables::updaters[i] = NULL;
PgSQL_Variables::verifiers[i] = NULL;
} else {
PgSQL_Variables::verifiers[i] = verify_server_variable;
PgSQL_Variables::updaters[i] = update_server_variable;
}
PgSQL_Variables::verifiers[i] = verify_server_variable;
PgSQL_Variables::updaters[i] = update_server_variable;
if (pgsql_tracked_variables[i].status == SETTING_VARIABLE) {
variables_regexp += pgsql_tracked_variables[i].set_variable_name;
variables_regexp += "|";
@ -74,6 +56,7 @@ PgSQL_Variables::PgSQL_Variables() {
}
}
}
for (std::vector<std::string>::iterator it=ignore_vars.begin(); it != ignore_vars.end(); it++) {
variables_regexp += *it;
variables_regexp += "|";
@ -135,7 +118,7 @@ void PgSQL_Variables::server_set_hash_and_value(PgSQL_Session* session, int idx,
session->mybe->server_myds->myconn->variables[idx].value = strdup(value);
}
bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value) {
bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const std::string& value, bool reorder_dynamic_variables_idx) {
if (!session || !session->client_myds || !session->client_myds->myconn) {
proxy_warning("Session validation failed\n");
return false;
@ -147,7 +130,7 @@ bool PgSQL_Variables::client_set_value(PgSQL_Session* session, int idx, const st
}
session->client_myds->myconn->variables[idx].value = strdup(value.c_str());
if (idx > PGSQL_NAME_LAST_LOW_WM) {
if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) {
// we now regererate dynamic_variables_idx
session->client_myds->myconn->reorder_dynamic_variables_idx();
}
@ -168,7 +151,7 @@ uint32_t PgSQL_Variables::client_get_hash(PgSQL_Session* session, int idx) const
return session->client_myds->myconn->var_hash[idx];
}
void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value) {
void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const char* value, bool reorder_dynamic_variables_idx) {
assert(session);
assert(session->mybe);
assert(session->mybe->server_myds);
@ -181,7 +164,7 @@ void PgSQL_Variables::server_set_value(PgSQL_Session* session, int idx, const ch
}
session->mybe->server_myds->myconn->variables[idx].value = strdup(value);
if (idx > PGSQL_NAME_LAST_LOW_WM) {
if (reorder_dynamic_variables_idx && idx > PGSQL_NAME_LAST_LOW_WM) {
// we now regererate dynamic_variables_idx
session->mybe->server_myds->myconn->reorder_dynamic_variables_idx();
}
@ -253,98 +236,6 @@ bool PgSQL_Variables::verify_variable(PgSQL_Session* session, int idx) const {
return ret;
}
bool validate_charset(PgSQL_Session* session, int idx, int &_rc) {
/*if (idx == PGSQL_CLIENT_ENCODING || idx == PGSQL_SET_NAMES) {
PgSQL_Data_Stream *myds = session->mybe->server_myds;
PgSQL_Connection *myconn = myds->myconn;
char msg[128];
const MARIADB_CHARSET_INFO *ci = NULL;
const char* replace_collation = NULL;
const char* not_supported_collation = NULL;
unsigned int replace_collation_nr = 0;
std::stringstream ss;
int charset = atoi(pgsql_variables.client_get_value(session, idx));
if (charset >= 255 && myconn->pgsql->server_version[0] != '8') {
switch(pgsql_thread___handle_unknown_charset) {
case HANDLE_UNKNOWN_CHARSET__DISCONNECT_CLIENT:
snprintf(msg,sizeof(msg),"Can't initialize character set %s", pgsql_variables.client_get_value(session, idx));
proxy_error("Can't initialize character set on %s, %d: Error %d (%s). Closing client connection %s:%d.\n",
myconn->parent->address, myconn->parent->port, 2019, msg, session->client_myds->addr.addr, session->client_myds->addr.port);
myds->destroy_MySQL_Connection_From_Pool(false);
myds->fd=0;
_rc=-1;
return false;
case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT_VERBOSE:
ci = proxysql_find_charset_nr(charset);
if (!ci) {
// LCOV_EXCL_START
proxy_error("Cannot find character set [%s]\n", pgsql_variables.client_get_value(session, idx));
assert(0);
// LCOV_EXCL_STOP
}
not_supported_collation = ci->name;
if (idx == SQL_COLLATION_CONNECTION) {
ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]);
} else {
if (pgsql_thread___default_variables[idx]) {
ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]);
} else {
ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]);
}
}
if (!ci) {
// LCOV_EXCL_START
proxy_error("Cannot find character set [%s]\n", pgsql_thread___default_variables[idx]);
assert(0);
// LCOV_EXCL_STOP
}
replace_collation = ci->name;
replace_collation_nr = ci->nr;
proxy_warning("Server doesn't support collation (%s) %s. Replacing it with the configured default (%d) %s. Client %s:%d\n",
pgsql_variables.client_get_value(session, idx), not_supported_collation,
replace_collation_nr, replace_collation, session->client_myds->addr.addr, session->client_myds->addr.port);
ss << replace_collation_nr;
pgsql_variables.client_set_value(session, idx, ss.str());
_rc=0;
return true;
case HANDLE_UNKNOWN_CHARSET__REPLACE_WITH_DEFAULT:
if (idx == SQL_COLLATION_CONNECTION) {
ci = proxysql_find_charset_collate(pgsql_thread___default_variables[idx]);
} else {
if (pgsql_thread___default_variables[idx]) {
ci = proxysql_find_charset_name(pgsql_thread___default_variables[idx]);
} else {
ci = proxysql_find_charset_name(pgsql_thread___default_variables[SQL_CHARACTER_SET]);
}
}
if (!ci) {
// LCOV_EXCL_START
proxy_error("Cannot filnd charset [%s]\n", pgsql_thread___default_variables[idx]);
assert(0);
// LCOV_EXCL_STOP
}
replace_collation_nr = ci->nr;
ss << replace_collation_nr;
pgsql_variables.client_set_value(session, idx, ss.str());
_rc=0;
return true;
default:
proxy_error("Wrong configuration of the handle_unknown_charset\n");
_rc=-1;
return false;
}
}
}*/
_rc=0;
return true;
}
bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) {
bool no_quote = true;
if (IS_PGTRACKED_VAR_OPTION_SET_QUOTE(pgsql_tracked_variables[idx])) no_quote = false;
@ -352,49 +243,12 @@ bool update_server_variable(PgSQL_Session* session, int idx, int &_rc) {
const char *set_var_name = pgsql_tracked_variables[idx].set_variable_name;
bool ret = false;
if (!validate_charset(session, idx, _rc)) {
return false;
}
const char* value = pgsql_variables.client_get_value(session, idx);
pgsql_variables.server_set_value(session, idx, value);
pgsql_variables.server_set_value(session, idx, value, true);
ret = session->handler_again___status_SETTING_GENERIC_VARIABLE(&_rc, set_var_name, value, no_quote, st);
return ret;
}
bool verify_set_names(PgSQL_Session* session) {
uint32_t client_charset_hash = pgsql_variables.client_get_hash(session, PGSQL_CLIENT_ENCODING);
if (client_charset_hash == 0)
return false;
if (client_charset_hash != pgsql_variables.server_get_hash(session, PGSQL_CLIENT_ENCODING)) {
switch(session->status) { // this switch can be replaced with a simple previous_status.push(status), but it is here for readibility
case PROCESSING_QUERY:
session->previous_status.push(PROCESSING_QUERY);
break;
/*
case PROCESSING_STMT_PREPARE:
session->previous_status.push(PROCESSING_STMT_PREPARE);
break;
case PROCESSING_STMT_EXECUTE:
session->previous_status.push(PROCESSING_STMT_EXECUTE);
break;
*/
default:
// LCOV_EXCL_START
proxy_error("Wrong status %d\n", session->status);
assert(0);
break;
// LCOV_EXCL_STOP
}
session->set_status(SETTING_CHARSET);
const char* client_charset_value = pgsql_variables.client_get_value(session, PGSQL_CLIENT_ENCODING);
pgsql_variables.server_set_hash_and_value(session, PGSQL_CLIENT_ENCODING, client_charset_value, client_charset_hash);
return true;
}
return false;
}
inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t client_hash, uint32_t server_hash) {
if (client_hash && client_hash != server_hash) {
// Edge case for set charset command, because we do not know database character set
@ -427,7 +281,7 @@ inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t cli
// LCOV_EXCL_STOP
}
session->set_status(pgsql_tracked_variables[idx].status);
pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx));
pgsql_variables.server_set_value(session, idx, pgsql_variables.client_get_value(session, idx), true);
return true;
}
return false;

@ -422,6 +422,31 @@ bool pgsql_variable_validate_maintenance_work_mem_v3(const char* value, const pa
return true;
}
bool pgsql_variable_validate_client_encoding(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) {
(void)params;
if (transformed_value) *transformed_value = nullptr;
int charset_encoding = PgSQL_Connection::char_to_encoding(value);
if (charset_encoding == -1) {
return false;
}
if (transformed_value) {
*transformed_value = strdup(value);
if (*transformed_value) { // Ensure strdup succeeded
char* tmp_val = *transformed_value;
while (*tmp_val) {
*tmp_val = toupper((unsigned char)*tmp_val);
tmp_val++;
}
}
}
return true;
}
const pgsql_variable_validator pgsql_variable_validator_bool = {
.type = VARIABLE_TYPE_BOOL,
@ -482,3 +507,9 @@ const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem = {
.uint_range = {.min = 1024, .max = 2147483647 } // this range is in kB
}
};
const pgsql_variable_validator pgsql_variable_validator_client_encoding = {
.type = VARIABLE_TYPE_CLIENT_ENCODING,
.validate = &pgsql_variable_validate_client_encoding,
.params = {}
};

@ -0,0 +1,909 @@
/**
* @file pgsql-connection_parameters_test-t-t.cpp
* @brief This TAP test validates if connection parameters are correctly handled by ProxySQL.
*
* Note:
* - Using libpq to test ProxySQL's handling of undocumented parameters isn't possible, as libpq enforces a strict subset of PostgreSQL
* connection parameters as per the official documentation, rejecting any undocumented parameters. However, actual
* PostgreSQL servers accept additional parameters (e.g., extra_float_digits) and apply them at the connection/session level.
* To test this behavior, a raw socket is used to connect to a ProxySQL server and send custom built messages to communicate
* with ProxySQL. It currently works with plain text password authentication, without ssl support.
*
* - Failure due to an invalid parameter returned by the PostgreSQL server, differs from ProxySQL's behavior.
* PostgreSQL returns an error during the connection handshake phase, whereas in ProxySQL, the connection succeeds,
* but the error is encountered when executing a query.
* This is behaviour is intentional, as newer PostgreSQL versions may introduce new parameters that ProxySQL is not yet aware of.
*
*
*/
#include <unistd.h>
#include <arpa/inet.h>
#include <string>
#include <sstream>
#include <chrono>
#include <thread>
#include "libpq-fe.h"
#include "command_line.h"
#include "tap.h"
#include "utils.h"
CommandLine cl;
using PGConnPtr = std::unique_ptr<PGconn, decltype(&PQfinish)>;
enum ConnType {
ADMIN,
BACKEND
};
PGConnPtr createNewConnection(ConnType conn_type, const std::string& parameters = "", bool with_ssl = false) {
const char* host = (conn_type == BACKEND) ? cl.pgsql_host : cl.pgsql_admin_host;
int port = (conn_type == BACKEND) ? cl.pgsql_port : cl.pgsql_admin_port;
const char* username = (conn_type == BACKEND) ? cl.pgsql_username : cl.admin_username;
const char* password = (conn_type == BACKEND) ? cl.pgsql_password : cl.admin_password;
std::stringstream ss;
ss << "host=" << host << " port=" << port;
ss << " user=" << username << " password=" << password;
ss << (with_ssl ? " sslmode=require" : " sslmode=disable");
if (parameters.empty() == false) {
ss << " " << parameters;
}
PGconn* conn = PQconnectdb(ss.str().c_str());
if (PQstatus(conn) != CONNECTION_OK) {
fprintf(stderr, "Connection failed to '%s': %s\n", (conn_type == BACKEND ? "Backend" : "Admin"), PQerrorMessage(conn));
PQfinish(conn);
return PGConnPtr(nullptr, &PQfinish);
}
return PGConnPtr(conn, &PQfinish);
}
bool executeQueries(PGconn* conn, const std::vector<std::string>& queries) {
auto fnResultType = [](const char* query) -> int {
const char* fs = strchr(query, ' ');
size_t qtlen = strlen(query);
if (fs != NULL) {
qtlen = (fs - query) + 1;
}
char buf[qtlen];
memcpy(buf, query, qtlen - 1);
buf[qtlen - 1] = 0;
if (strncasecmp(buf, "SELECT", sizeof("SELECT") - 1) == 0) {
return PGRES_TUPLES_OK;
}
else if (strncasecmp(buf, "COPY", sizeof("COPY") - 1) == 0) {
return PGRES_COPY_OUT;
}
return PGRES_COMMAND_OK;
};
for (const auto& query : queries) {
diag("Running: %s", query.c_str());
PGresult* res = PQexec(conn, query.c_str());
bool success = PQresultStatus(res) == fnResultType(query.c_str());
if (!success) {
fprintf(stderr, "Failed to execute query '%s': %s\n",
query.c_str(), PQerrorMessage(conn));
PQclear(res);
return false;
}
PQclear(res);
}
return true;
}
struct parameter {
std::string name;
std::string value;
};
struct parameter_test {
std::vector<parameter> set_admin_vars; // Admin variables to set
std::vector<parameter> conn_params; // Parameters in startup message
std::vector<std::string> conn_options; // Options (-c flags) in startup
std::vector<std::string> set_commands; // SET commands after connection
std::vector<std::string> expected; // Expected SHOW values
bool reset_after; // Whether to RESET parameters
bool expect_failure; // If connection/query should fail
};
/**
* @struct MyPGresult
* @brief Represents the result of a PostgreSQL query.
*
* This structure holds the columns, rows, status, and error message of a PostgreSQL query result.
*/
struct MyPGresult {
std::vector<std::string> columns; ///< Column names of the result set.
std::vector<std::vector<std::string>> rows; ///< Rows of the result set.
std::string status; ///< Status of the query execution.
std::string error; ///< Error message if the query failed.
};
struct PgSQLResponse {
char type; ///< Type of the response message.
int32_t length; ///< Length of the response message.
std::vector<char> data; ///< Data of the response message.
};
void add_param(std::vector<char>& msg, std::string_view key, std::string_view value) {
msg.insert(msg.end(), key.begin(), key.end());
msg.push_back('\0');
msg.insert(msg.end(), value.begin(), value.end());
msg.push_back('\0');
}
// Function to receive data from socket
bool recv_data(int sock, void* buffer, size_t len) {
if (recv(sock, buffer, len, 0) <= 0) {
fprintf(stderr, "Error receiving data\n");
return false;
}
return true;
}
// Function to send data over socket
bool send_data(int sock, const void* data, size_t len) {
if (send(sock, data, len, 0) != len) {
fprintf(stderr, "Error sending data\n");
return false;
}
return true;
}
/**
* @brief Builds a startup message for PostgreSQL connection.
*
* This function constructs a startup message for PostgreSQL connection using the provided
* connection parameters.
*
* @param parameters A vector of key-value pairs representing the connection parameters.
* @return A vector of characters representing the constructed startup message.
*/
std::vector<char> build_startup_message(const std::vector<parameter>& parameters) {
// Build startup message
std::vector<char> startup_body;
int32_t protocol = htonl(0x00030000); // Protocol 3.0
startup_body.insert(startup_body.end(), (char*)&protocol, (char*)&protocol + 4);
// Add connection parameters
for (const auto& param : parameters) {
add_param(startup_body, param.name, param.value);
}
startup_body.push_back('\0');
// Prepend message length
std::vector<char> startup_msg;
int32_t len = htonl(startup_body.size() + 4);
startup_msg.insert(startup_msg.end(), (char*)&len, (char*)&len + 4);
startup_msg.insert(startup_msg.end(), startup_body.begin(), startup_body.end());
return startup_msg;
}
/**
* @brief Builds a password message for PostgreSQL authentication.
*
* This function constructs a password message to be sent to the PostgreSQL server
* during the authentication process.
*
* @param password The password to be included in the message.
* @return A vector of characters representing the constructed password message.
*/
std::vector<char> build_password_message(std::string_view password) {
std::vector<char> password_message;
int pass_msg_len = htonl(password.size() + 1 + 4);
password_message.push_back('p');
password_message.insert(password_message.end(), (char*)&pass_msg_len, (char*)&pass_msg_len + 4);
password_message.insert(password_message.end(), password.begin(), password.end());
password_message.push_back('\0');
return password_message;
}
/**
* @brief Connects to a PostgreSQL server.
*
* This function establishes a connection to a PostgreSQL server using the provided
* host and port number.
*
* @param host The hostname or IP address of the PostgreSQL server.
* @param port The port number of the PostgreSQL server.
* @return The socket file descriptor for the connection.
*/
int connect_server(const std::string& host, int port) {
int sock;
struct sockaddr_in server;
// Create socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == -1) {
perror("Socket creation failed");
return 1;
}
// Configure server address
server.sin_family = AF_INET;
server.sin_port = htons(port);
server.sin_addr.s_addr = inet_addr(host.c_str());
// Connect to PostgreSQL server
if (connect(sock, (struct sockaddr*)&server, sizeof(server)) < 0) {
fprintf(stderr, "Connection failed\n");
return -1;
}
return sock;
}
/**
* @brief Sends a startup message to the PostgreSQL server.
*
* This function constructs and sends a startup message to the PostgreSQL server
* using the provided socket and connection parameters.
*
* @param sock The socket file descriptor for the connection.
* @param params A vector of key-value pairs representing the connection parameters.
*/
void send_startup_message(int sock, const std::vector<std::pair<std::string, std::string>>& params) {
int param_count = params.size();
char msg[512];
int offset = 0;
uint32_t length = 4 + 4;
for (int i = 0; i < param_count; i++) {
length += params[i].first.size() + 1 + params[i].second.size() + 1;
}
length += 1;
uint32_t length_nbo = htonl(length);
uint32_t protocol = htonl(196608);
memcpy(msg + offset, &length_nbo, 4);
offset += 4;
memcpy(msg + offset, &protocol, 4);
offset += 4;
for (int i = 0; i < param_count; i++) {
strcpy(msg + offset, params[i].first.c_str());
offset += params[i].first.size() + 1;
strcpy(msg + offset, params[i].second.c_str());
offset += params[i].second.size() + 1;
}
msg[offset++] = '\0';
send(sock, msg, offset, 0);
}
/**
* @brief Parses an error message from a PostgreSQL response payload.
*
* This function extracts and returns the error message from a given PostgreSQL response payload.
*
* @param payload A vector of characters representing the PostgreSQL response payload.
* @return A string containing the extracted error message.
*/
std::string parse_error(const std::vector<char>& payload) {
const char* data = payload.data();
size_t pos = 0;
std::string message;
while (pos < payload.size()) {
if (data[pos] == '\0') break;
char field_type = data[pos++];
std::string field_value;
while (pos < payload.size() && data[pos] != '\0') {
field_value += data[pos++];
}
pos++;
if (field_type == 'M') message = field_value;
}
return message;
}
/**
* @brief Handles cleartext password authentication for PostgreSQL.
*
* This function processes the cleartext password authentication request from the PostgreSQL server.
*
* @param sock The socket file descriptor for the connection.
* @param password The password to be sent for authentication.
* @return True if authentication is successful, false otherwise.
*/
bool handle_cleartext_auth(int sock, std::string_view password) {
char response[1024];
if (recv_data(sock, response, 9) == false) { // Read first 8 bytes (message type + length + auth type)
fprintf(stderr, "Error: failed to receive authentication message in file %s, line %d\n", __FILE__, __LINE__);
return false;
}
if (response[0] != 'R') {
fprintf(stderr, "Unexpected authentication response\n");
return false;
}
uint32_t auth_type;
memcpy(&auth_type, response + 5, 4);
auth_type = ntohl(auth_type);
if (auth_type == 3) { // AuthenticationCleartextPassword
diag("Server requests cleartext password authentication\n");
std::vector<char> password_msg = build_password_message(password);
// Send PasswordMessage
if (send_data(sock, password_msg.data(), password_msg.size()) == false) {
fprintf(stderr, "Error: failed to send password message in file %s, line %d\n", __FILE__, __LINE__);
return false;
}
// Receive AuthenticationOK or Failure
if (recv_data(sock, response, 5) == false) {
fprintf(stderr, "Error: failed to receive authentication response in file %s, line %d\n", __FILE__, __LINE__);
return false;
}
if (response[0] == 'E') {
uint32_t err_msg_len = ntohl(*reinterpret_cast<uint32_t*>(response + 1));
if (recv_data(sock, response, err_msg_len - 4) == false) {
fprintf(stderr, "Error: failed to receive error message in file %s, line %d\n", __FILE__, __LINE__);
return false;
}
const std::vector<char> payload (response, response + err_msg_len - 4);
std::string error_message = parse_error(payload);
fprintf(stderr, "%s\n", error_message.c_str());
return false;
}
else if (response[0] != 'R') {
fprintf(stderr, "Unexpected authentication response\n");
return false;
}
if (recv_data(sock, response, 4) == false) {
fprintf(stderr, "Error: failed to receive error message in file %s, line %d\n", __FILE__, __LINE__);
return false;
}
memcpy(&auth_type, response, 4);
auth_type = ntohl(auth_type);
if (auth_type != 0) {
diag("Authentication failed!\n");
return false;
}
std::vector<PgSQLResponse> msg_list;
while (true) {
int bytes_received = recv(sock, response, sizeof(response), MSG_DONTWAIT);
if (bytes_received == 0) {
fprintf(stderr, "Error: Connection closed in file %s, line %d\n", __FILE__, __LINE__);
return false;
}
if (bytes_received == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
break;
}
fprintf(stderr, "Error: Connection closed or error occurred in file %s, line %d\n", __FILE__, __LINE__);
return false;
}
size_t offset = 0;
while (offset < static_cast<size_t>(bytes_received)) {
char messageType = response[offset];
int32_t messageLength;
if (offset + 5 > static_cast<size_t>(bytes_received)) {
fprintf(stderr, "Incomplete message header received\n");
return false;
}
memcpy(&messageLength, response + offset + 1, 4);
messageLength = ntohl(messageLength);
if (offset + messageLength > static_cast<size_t>(bytes_received)) {
fprintf(stderr, "Incomplete message body received.\n");
return false;
}
PgSQLResponse msg;
msg.type = messageType;
msg.length = messageLength;
msg.data.assign(response + offset + 5, response + offset + messageLength + 1);
msg_list.emplace_back(std::move(msg));
offset += messageLength + 1;
}
}
if (msg_list.back().type == 'E') {
const std::vector<char> payload = msg_list.back().data;
std::string error_message = parse_error(payload);
fprintf(stderr, "%s\n", error_message.c_str());
return false;
} else if (msg_list.back().type != 'Z') {
fprintf(stderr, "Unexpected message type %c\n", msg_list.back().type);
return false;
}
} else {
diag("Unexpected authentication method: %d\n", auth_type);
return false;
}
diag("Authentication successful!\n");
return true;
}
/**
* @brief Executes a query on the PostgreSQL server.
*
* This function sends a query to the PostgreSQL server and processes the response.
*
* @param sock The socket file descriptor for the connection.
* @param query The SQL query to be executed.
* @return A unique pointer to a MyPGresult structure containing the query result.
*/
std::unique_ptr<MyPGresult> execute_query(int sock, const std::string& query) {
std::vector<char> msg;
// Build query message
msg.push_back('Q');
uint32_t length = htonl(query.size() + 5); // 1 byte for Q + 4 for length + query + null
msg.insert(msg.end(), reinterpret_cast<char*>(&length), reinterpret_cast<char*>(&length) + 4);
msg.insert(msg.end(), query.begin(), query.end());
msg.push_back('\0');
// Send Query message
if (send_data(sock, msg.data(), msg.size()) == false) {
fprintf(stderr, "Error: failed to send query in file %s, line %d\n", __FILE__, __LINE__);
return nullptr;
}
std::unique_ptr<MyPGresult> result(new MyPGresult);
while (true) {
char msg_type;
if (recv_data(sock, reinterpret_cast<char*>(&msg_type), 1) == false) {
fprintf(stderr, "Error: failed to receive message in file %s, line %d\n", __FILE__, __LINE__);
return nullptr;
}
uint32_t msg_length;
if (recv_data(sock, reinterpret_cast<char*>(&msg_length), 4) == false) {
fprintf(stderr, "Error: failed to receive message length in file %s, line %d\n", __FILE__, __LINE__);
return nullptr;
}
msg_length = ntohl(msg_length);
std::vector<char> payload(msg_length - 4);
if (msg_length > 4) {
if (recv_data(sock, payload.data(), msg_length - 4) == false) {
fprintf(stderr, "Error: failed to receive message payload in file %s, line %d\n", __FILE__, __LINE__);
return nullptr;
}
}
switch (msg_type) {
case 'T': { // Row Description
const char* data = payload.data();
uint16_t num_columns;
memcpy(&num_columns, data, 2);
num_columns = ntohs(num_columns);
size_t pos = 2;
result->columns.reserve(num_columns);
for (int i = 0; i < num_columns; i++) {
std::string col_name(data + pos);
pos += col_name.length() + 1 + 22; // Skip field metadata
result->columns.push_back(col_name);
}
break;
}
case 'D': { // Data Row
const char* data = payload.data();
uint16_t num_fields;
memcpy(&num_fields, data, 2);
num_fields = ntohs(num_fields);
std::vector<std::string> row;
size_t pos = 2;
for (int i = 0; i < num_fields; i++) {
int32_t field_len;
memcpy(&field_len, data + pos, 4);
pos += 4;
field_len = ntohl(field_len);
if (field_len == -1) {
row.push_back("NULL");
}
else {
row.emplace_back(data + pos, field_len);
pos += field_len;
}
}
result->rows.push_back(row);
break;
}
case 'C': // Command Complete
result->status = std::string(payload.data(), payload.size());
break;
case 'E': // Error
result->error = parse_error(payload);
// Consume remaining messages
while (msg_type != 'Z') {
if (recv_data(sock, &msg_type, 1) == false) {
fprintf(stderr, "Error: failed to receive message in file %s, line %d\n", __FILE__, __LINE__);
return nullptr;
}
if (recv_data(sock, reinterpret_cast<char*>(&msg_length), 4) == false) {
fprintf(stderr, "Error: failed to receive message length in file %s, line %d\n", __FILE__, __LINE__);
return nullptr;
}
msg_length = ntohl(msg_length);
if (msg_length > 4) {
std::vector<char> temp(msg_length - 4);
if (recv_data(sock, temp.data(), msg_length - 4) == false) {
fprintf(stderr, "Error: failed to receive message payload in file %s, line %d\n", __FILE__, __LINE__);
return nullptr;
}
}
}
return std::move(result);
case 'Z': // Ready for Query
return std::move(result);
default:
break;
}
}
return nullptr;
}
const char* escape_string_backslash_spaces(const char* input) {
const char* c;
int input_len = 0;
int escape_count = 0;
for (c = input; *c != '\0'; c++) {
if ((*c == ' ')) {
escape_count += 3;
}
else if ((*c == '\\')) {
escape_count += 2;
}
input_len++;
}
if (escape_count == 0)
return input;
char* output = (char*)malloc(input_len + escape_count + 1);
char* p = output;
for (c = input; *c != '\0'; c++) {
if ((*c == ' ')) {
memcpy(p, "\\\\", 2);
p += 2;
}
else if (*c == '\\') {
*(p++) = '\\';
}
*(p++) = *c;
}
*(p++) = '\0';
return output;
}
bool test_parameters(PGconn* admin_conn, const parameter_test& test) {
char buffer[512];
bool ret = false;
for (const auto& parameter : test.set_admin_vars) {
snprintf(buffer, sizeof(buffer), "SET %s='%s'", parameter.name.c_str(), parameter.value.c_str());
if (executeQueries(admin_conn, { buffer, "LOAD PGSQL VARIABLES TO RUNTIME" }) == false) {
BAIL_OUT("Error: failed to set admin variable in file %s, line %d", __FILE__, __LINE__);
return false;
}
}
int sock = connect_server(cl.pgsql_host, cl.pgsql_port);
if (sock == -1) {
BAIL_OUT("Error: failed to connect to the server in file %s, line %d", __FILE__, __LINE__);
return false;
}
// Build startup message
std::vector<parameter> parameters = {
{ "user", cl.pgsql_username },
};
parameters.reserve(test.conn_params.size() + test.conn_options.size() + 1);
for (const auto& param : test.conn_params) {
if (param.value.empty() == true) continue;
parameters.push_back(param);
}
if (test.conn_options.empty() == false) {
std::string options_value;
for (size_t i = 0; i < test.conn_options.size(); i++) {
options_value += " -c " + test.conn_params[i].name;
options_value += "=";
const char* value = test.conn_options[i].c_str();
const char* escaped_value = escape_string_backslash_spaces(value);
options_value += escaped_value;
if (value != escaped_value)
free((void*)escaped_value);
}
parameters.push_back({ "options", options_value });
}
// print parameters
for (const auto& param : parameters) {
diag("Parameter: %s = %s\n", param.name.c_str(), param.value.c_str());
}
std::vector<char> startup_msg = build_startup_message(parameters);
// Send StartupMessage
if (send_data(sock, startup_msg.data(), startup_msg.size()) == false) {
BAIL_OUT("Error: failed to send startup message in file %s, line %d", __FILE__, __LINE__);
goto cleanup;
}
// Receive AuthenticationRequest
if (handle_cleartext_auth(sock, cl.pgsql_password) == false) {
if (test.expect_failure) {
ok(true, "Authentication should fail");
ret = true;
} else {
diag("Error: failed to handle cleartext authentication in file %s, line %d", __FILE__, __LINE__);
}
goto cleanup;
}
for (size_t i = 0; i < test.set_commands.size(); i++) {
snprintf(buffer, sizeof(buffer), "SET %s='%s'", test.conn_params[i].name.c_str(), test.set_commands[i].c_str());
diag("Executing: %s\n", buffer);
auto result = execute_query(sock, buffer);
if (result == nullptr) {
BAIL_OUT("Error: failed to execute query in file %s, line %d", __FILE__, __LINE__);
goto cleanup;
}
if (result->error.empty() == false) {
ok(test.expect_failure, "Query '%s' should fail. %s", buffer, result->error.c_str());
}
}
if (test.reset_after) {
for (const auto& param : test.conn_params) {
std::string reset_cmd = "RESET " + param.name;
diag("Executing: %s\n", reset_cmd.c_str());
auto result = execute_query(sock, reset_cmd);
if (result == nullptr) {
BAIL_OUT("Error: failed to reset parameter in file %s, line %d", __FILE__, __LINE__);
goto cleanup;
}
if (result->error.empty() == false) {
ok(test.expect_failure, "Query '%s' should fail. %s", reset_cmd.c_str(), result->error.c_str());
}
}
}
for (int i = 0; i < test.conn_params.size(); i++) {
const auto& param = test.conn_params[i];
std::string show_cmd = "SHOW " + param.name;
diag("Executing: %s\n", show_cmd.c_str());
auto result = execute_query(sock, show_cmd);
if (result == nullptr) {
BAIL_OUT("Error: failed to execute query in file %s, line %d", __FILE__, __LINE__);
goto cleanup;
}
if (test.expect_failure == false && result->error.empty()) {
ok(result->rows.size() == 1, "Number of rows should be 1");
ok(result->rows[0][0] == test.expected[i], "Parameter '%s' value should be '%s'. Actual: '%s'",
param.name.c_str(), test.expected[i].c_str(), result->rows[0][0].c_str());
} else {
ok(test.expect_failure, "Query '%s' should fail. %s", show_cmd.c_str(), result->error.c_str());
}
}
ret = true;
cleanup:
close(sock);
return ret;
}
std::vector<parameter_test> test_cases = {
// check if connection parameters validation is working correctly
{ {},
{{"sslmode", "test"}},
{},
{},
{""},
false,
true
},
// check if session parameter validation is working correctly
{ {},
{{"extra_float_digits", "20"}},
{},
{},
{""},
false,
true
},
// check if options parameters validation is working correctly
{ {},
{{"extra_float_digits", "1"}},
{"19"},
{},
{""},
false,
true
},
{ {},
{{"ENABLE_HASHJOIN", "off"}, {"enable_seqscan", "on"}},
{"on", "off"},
{},
{"off", "on"},
false,
false
},
{
{},
{{"extra_float_digits", "1"}},
{"2"},
{},
{"1"},
false,
false
},
{
{},
{{"enable_hashjoin", "off"}, {"enable_seqscan", "on"}},
{"on", "off"},
{"on", "off"},
{"off", "on"},
true,
false
},
{
{{"pgsql-default_datestyle", "ISO, MDY"}},
{{"datestyle", ""}},
{},
{"Postgres"},
{"Postgres, MDY"},
false, // Reset both
false
},
{
{{"pgsql-default_datestyle", "ISO, MDY"}},
{{"datestyle", ""}},
{},
{"Postgres"},
{"ISO, MDY"},
true, // Reset both
false
},
{
{},
{{"escape_string_warning", "on"}, {"standard_conforming_strings", "on"}},
{},
{"off", "off"},
{"off", "off"},
false,
false
},
{
{},
{{"client_encoding", "UTF8"}},
{"LATIN1"},
{},
{"UTF8"},
false,
false
},
{
{},
{{"client_encoding", "UTF8"}},
{"LATIN1"},
{"LATIN1"},
{"LATIN1"},
false,
false
},
{
{{"pgsql-default_client_encoding", "utf8"}},
{{"client_encoding", "UTF8"}},
{"LATIN1"},
{"LATIN1"},
{"UTF8"},
true,
false
},
{
{},
{{"invalid_param", "invalid"}},
{},
{},
{"invalid"},
false,
true
}
};
int main(int argc, char** argv) {
int test_count = 0;
for (const auto& test_case : test_cases) {
if (test_case.expect_failure) {
int case_count = 1;
if (test_case.set_commands.empty() == false)
case_count++;
if (test_case.reset_after)
case_count++;
test_count += test_case.conn_params.size() * case_count;
} else
test_count += test_case.conn_params.size() * 2;
}
plan(test_count);
if (cl.getEnv())
return exit_status();
auto admin_conn = createNewConnection(ConnType::ADMIN, "", false);
if (!admin_conn || PQstatus(admin_conn.get()) != CONNECTION_OK) {
BAIL_OUT("Error: failed to connect to the database in file %s, line %d", __FILE__, __LINE__);
return exit_status();
}
if (executeQueries(admin_conn.get(), { "SET pgsql-authentication_method=1",
"LOAD PGSQL VARIABLES TO RUNTIME" }) == false) {
BAIL_OUT("Error: failed to set pgsql-authentication_method=1 in file %s, line %d", __FILE__, __LINE__);
return exit_status();
}
for (const auto& test_case : test_cases) {
if (test_parameters(admin_conn.get(), test_case) == false) {
BAIL_OUT("Error: failed to test parameters in file %s, line %d", __FILE__, __LINE__);
return exit_status();
}
}
return exit_status();
}
Loading…
Cancel
Save