diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index 99d9d6e75..db903c205 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -734,7 +734,8 @@ public: * updates the output buffer with the generated packet. If `ready` is * true, it also generates and sends a ready-for-query packet. */ - bool generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, char trx_state = 'I', PtrSize_t* _ptr = NULL); + bool generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, char trx_state = 'I', PtrSize_t* _ptr = NULL, + const std::vector>& param_status = std::vector>()); // temporary overriding generate_pkt_OK to avoid crash. FIXME remove this bool generate_pkt_OK(bool send, void** ptr, unsigned int* len, uint8_t sequence_id, unsigned int affected_rows, diff --git a/include/PgSQL_Variables.h b/include/PgSQL_Variables.h index c374228ca..13fef0570 100644 --- a/include/PgSQL_Variables.h +++ b/include/PgSQL_Variables.h @@ -45,8 +45,8 @@ public: bool verify_variable(PgSQL_Session* session, int idx) const; bool update_variable(PgSQL_Session* session, session_status status, int &_rc); - bool parse_variable_boolean(PgSQL_Session*sess, int idx, std::string &value1, bool* lock_hostgroup); - bool parse_variable_number(PgSQL_Session*sess, int idx, std::string &value1, bool* lock_hostgroup); + bool parse_variable_boolean(PgSQL_Session*sess, int idx, const std::string &value1, bool *lock_hostgroup, bool *send_param_status); + bool parse_variable_number(PgSQL_Session*sess, int idx, std::string &value1, bool *lock_hostgroup, bool *send_param_status); }; #endif // PGSQL_VARIABLES_H diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index b7b4b807d..85df4d706 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -1371,7 +1371,8 @@ char* extract_tag_from_query(const char* query) { } -bool PgSQL_Protocol::generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, char trx_state, PtrSize_t* _ptr) { +bool PgSQL_Protocol::generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, char trx_state, PtrSize_t* _ptr, + const std::vector>& param_status) { // to avoid memory leak assert(send == true || _ptr); @@ -1402,6 +1403,10 @@ bool PgSQL_Protocol::generate_ok_packet(bool send, bool ready, const char* msg, } free(tag); + for (auto& [param_name, param_value] : param_status) { + pgpkt.write_ParameterStatus(param_name.c_str(), param_value.c_str()); + } + if (ready == true) { pgpkt.write_ReadyForQuery(trx_state); pgpkt.set_multi_pkt_mode(false); diff --git a/lib/PgSQL_Variables.cpp b/lib/PgSQL_Variables.cpp index c98fd8200..ab0d9835b 100644 --- a/lib/PgSQL_Variables.cpp +++ b/lib/PgSQL_Variables.cpp @@ -435,8 +435,9 @@ inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t cli return false; } -bool PgSQL_Variables::parse_variable_boolean(PgSQL_Session *sess, int idx, string& value1, bool * lock_hostgroup) { +bool PgSQL_Variables::parse_variable_boolean(PgSQL_Session *sess, int idx, const std::string& value1, bool* lock_hostgroup, bool* send_param_status) { proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Processing SET %s value %s\n", pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); + *send_param_status = false; int __tmp_value = -1; if ( (strcasecmp(value1.c_str(),(char *)"0")==0) || @@ -454,17 +455,17 @@ bool PgSQL_Variables::parse_variable_boolean(PgSQL_Session *sess, int idx, strin } } + if (__tmp_value >= 0) { - proxy_debug(PROXY_DEBUG_MYSQL_COM, 7, "Processing SET %s value %s\n", pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); - uint32_t var_value_int=SpookyHash::Hash32(value1.c_str(),value1.length(),10); + const char* val = __tmp_value ? "ON" : "OFF"; + proxy_debug(PROXY_DEBUG_MYSQL_COM, 7, "Processing SET %s value %s\n", pgsql_tracked_variables[idx].set_variable_name, val); + uint32_t var_value_int=SpookyHash::Hash32(val, strlen(val), 10); if (pgsql_variables.client_get_hash(sess, idx) != var_value_int) { - if (__tmp_value == 0) { - if (!pgsql_variables.client_set_value(sess, idx, "OFF")) - return false; - } else { - if (!pgsql_variables.client_set_value(sess, idx, "ON")) - return false; - } + *send_param_status = IS_PGTRACKED_VAR_OPTION_SET_PARAM_STATUS(pgsql_tracked_variables[idx]); + + if (!pgsql_variables.client_set_value(sess, idx, val)) + return false; + proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Changing connection %s to %s\n", pgsql_tracked_variables[idx].set_variable_name, value1.c_str()); } } else { @@ -476,7 +477,7 @@ bool PgSQL_Variables::parse_variable_boolean(PgSQL_Session *sess, int idx, strin -bool PgSQL_Variables::parse_variable_number(PgSQL_Session *sess, int idx, string& value1, bool * lock_hostgroup) { +bool PgSQL_Variables::parse_variable_number(PgSQL_Session *sess, int idx, string& value1, bool* lock_hostgroup, bool* send_param_status) { int vl = strlen(value1.c_str()); const char *v = value1.c_str(); bool only_digit_chars = true;