From 24fecc1f6e09bec173ff2a801d315f6112c1e374 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Sun, 22 Jun 2025 19:30:31 +0500 Subject: [PATCH] Add PostgreSQL extended query (prepared statement) support in ProxySQL MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit lays the groundwork for handling PostgreSQL’s extended query protocol (prepared statements) by introducing parsing, it's caching and execution framework stubs. - Implement Parse handling - Implement caching of parse (reuse) - Introduce Execute framework stub for future binding and execution logic. --- include/PgSQL_Connection.h | 15 +- include/PgSQL_Data_Stream.h | 5 +- include/PgSQL_Logger.hpp | 7 +- include/PgSQL_PreparedStatement.h | 234 ++++++++ include/PgSQL_Protocol.h | 24 +- include/PgSQL_Session.h | 26 +- lib/Base_Thread.cpp | 1 + lib/Makefile | 3 +- lib/PgSQL_Connection.cpp | 138 ++++- lib/PgSQL_Data_Stream.cpp | 18 +- lib/PgSQL_HostGroups_Manager.cpp | 2 +- lib/PgSQL_Logger.cpp | 24 +- lib/PgSQL_PreparedStatement.cpp | 908 ++++++++++++++++++++++++++++++ lib/PgSQL_Protocol.cpp | 167 ++++++ lib/PgSQL_Query_Processor.cpp | 12 +- lib/PgSQL_Session.cpp | 446 +++++++++++---- lib/PgSQL_Thread.cpp | 4 +- lib/PgSQL_Variables.cpp | 2 - src/main.cpp | 10 +- 19 files changed, 1889 insertions(+), 157 deletions(-) create mode 100644 include/PgSQL_PreparedStatement.h create mode 100644 lib/PgSQL_PreparedStatement.cpp diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index 2d47c4075..39f610582 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -12,6 +12,7 @@ class PgSQL_SrvC; class PgSQL_Query_Result; +class PgSQL_STMTs_local_v14; //#define STATUS_MYSQL_CONNECTION_TRANSACTION 0x00000001 // DEPRECATED #define STATUS_MYSQL_CONNECTION_COMPRESSION 0x00000002 #define STATUS_MYSQL_CONNECTION_USER_VARIABLE 0x00000004 @@ -254,12 +255,16 @@ public: void query_cont(short event); void fetch_result_start(); void fetch_result_cont(short event); + void stmt_prepare_start(); + void stmt_prepare_cont(short event); + //void stmt_execute_start(); + //void stmt_execute_cont(short event); void reset_session_start(); void reset_session_cont(short event); int async_connect(short event); - int async_query(short event, char* stmt, unsigned long length); + int async_query(short event, const char* stmt, unsigned long length, const char* backend_stmt_name = nullptr, void* execute_data = nullptr); int async_ping(short event); int async_reset_session(short event); int async_send_simple_command(short event, char* stmt, unsigned long length); // no result set expected @@ -410,6 +415,7 @@ public: const char* get_pg_connection_status_str(); const char* get_pg_transaction_status_str(); unsigned int get_memory_usage() const; + char get_transaction_status_char(); inline int get_backend_pid() { return (pgsql_conn) ? get_pg_backend_pid() : -1; } @@ -437,14 +443,15 @@ public: unsigned int reorder_dynamic_variables_idx(); unsigned int number_of_matching_session_variables(const PgSQL_Connection* client_conn, unsigned int& not_matching); - void set_query(char* stmt, unsigned long length); + void set_query(const char* stmt, unsigned long length, const char* backend_stmt_name = nullptr); void reset(); bool IsKeepMultiplexEnabledVariables(char* query_digest_text); struct { unsigned long length; - char* ptr; + const char* ptr; + const char* backend_stmt_name; } query; struct { @@ -500,7 +507,7 @@ public: bool processing_multi_statement; bool multiplex_delayed; - + PgSQL_STMTs_local_v14* local_stmts; PgSQL_SrvC *parent; PgSQL_Connection_userinfo* userinfo; PgSQL_Data_Stream* myds; diff --git a/include/PgSQL_Data_Stream.h b/include/PgSQL_Data_Stream.h index efecaa4b8..ef27812bc 100644 --- a/include/PgSQL_Data_Stream.h +++ b/include/PgSQL_Data_Stream.h @@ -92,7 +92,7 @@ public: } CompPktOUT; PgSQL_Protocol myprot; - PgSQL_MyDS_real_query mysql_real_query; + PgSQL_MyDS_real_query pgsql_real_query; bytes_stats_t bytes_info; // bytes statistics PtrSize_t multi_pkt; @@ -189,7 +189,6 @@ public: bool available_data_out(); void remove_pollout(); void set_pollout(); - void mysql_free(); void set_net_failure(); void setDSS_STATE_QUERY_SENT_NET(); @@ -265,7 +264,7 @@ public: void return_MySQL_Connection_To_Pool(); void destroy_MySQL_Connection_From_Pool(bool sq); - void free_mysql_real_query(); + void free_pgsql_real_query(); void reinit_queues(); void destroy_queues(); diff --git a/include/PgSQL_Logger.hpp b/include/PgSQL_Logger.hpp index 3888d671a..28492d237 100644 --- a/include/PgSQL_Logger.hpp +++ b/include/PgSQL_Logger.hpp @@ -12,6 +12,7 @@ class PgSQL_Event { char *schemaname; size_t username_len; size_t schemaname_len; + size_t client_stmt_name_len; uint64_t start_time; uint64_t end_time; uint64_t query_digest; @@ -26,20 +27,20 @@ class PgSQL_Event { enum log_event_type et; uint64_t hid; char *extra_info; + char *client_stmt_name; bool have_affected_rows; bool have_rows_sent; uint64_t affected_rows; uint64_t rows_sent; - uint32_t client_stmt_id; - + public: PgSQL_Event(log_event_type _et, uint32_t _thread_id, char * _username, char * _schemaname , uint64_t _start_time , uint64_t _end_time , uint64_t _query_digest, char *_client, size_t _client_len); uint64_t write(std::fstream *f, PgSQL_Session *sess); uint64_t write_query_format_1(std::fstream *f); uint64_t write_query_format_2_json(std::fstream *f); void write_auth(std::fstream *f, PgSQL_Session *sess); - void set_client_stmt_id(uint32_t client_stmt_id); + void set_client_stmt_name(char* client_stmt_name); void set_query(const char *ptr, int len); void set_server(int _hid, const char *ptr, int len); void set_extra_info(char *); diff --git a/include/PgSQL_PreparedStatement.h b/include/PgSQL_PreparedStatement.h new file mode 100644 index 000000000..e35dee074 --- /dev/null +++ b/include/PgSQL_PreparedStatement.h @@ -0,0 +1,234 @@ +#ifndef CLASS_PGSQL_PREPARED_STATEMENT_H +#define CLASS_PGSQL_PREPARED_STATEMENT_H + +#include "proxysql.h" +#include "cpp.h" + +/* +One of the main challenge in handling prepared statement (PS) is that a single +PS could be executed on multiple backends, and on each backend it could have a +different stmt_id. +For this reason ProxySQL returns to the client a stmt_id generated by the proxy +itself, and internally maps client's stmt_id with the backend stmt_id. + +The implementation in ProxySQL is, simplified, the follow: +* when a client sends a MYSQL_COM_STMT_PREPARE, ProxySQL executes it to one of + the backend +* the backend returns a stmt_id. This stmt_id is NOT returned to the client. The + stmt_id returned from the backend is stored in MySQL_STMTs_local(), and + MySQL_STMTs_local() is responsible for mapping the connection's MYSQL_STMT + and a global_stmt_id +* the global_stmt_id is the stmt_id returned to the client +* the global_stmt_id is used to locate the relevant MySQL_STMT_Global_info() in + MySQL_STMT_Manager() +* MySQL_STMT_Global_info() stores all metadata associated with a PS +* MySQL_STMT_Manager() is responsible for storing all MySQL_STMT_Global_info() + in global structures accessible and shareble by all threads. + +To summarie the most important classes: +* MySQL_STMT_Global_info() stores all metadata associated with a PS +* MySQL_STMT_Manager() stores all the MySQL_STMT_Global_info(), indexes using + a global_stmt_id that iis the stmt_id generated by ProxySQL and returned to + the client +* MySQL_STMTs_local() associate PS located in a backend connection to a + global_stmt_id +*/ + +// class MySQL_STMT_Global_info represents information about a MySQL Prepared Statement +// it is an internal representation of prepared statement +// it include all metadata associated with it + +class PgSQL_STMT_Global_info { + private: + void compute_hash(); + public: + pthread_rwlock_t rwlock_; + uint64_t digest; + PGSQL_QUERY_command PgQueryCmd; + char * digest_text; + uint64_t hash; + char *username; + char *schemaname; + char *query; + unsigned int query_length; + int ref_count_client; + int ref_count_server; + uint64_t statement_id; + char* first_comment; + uint64_t total_mem_usage; + struct describe { + uint16_t num_fields; + uint16_t num_params; + uint8_t flags; + } describe; + + bool is_select_NOT_for_update; + PgSQL_STMT_Global_info(uint64_t id, char *u, char *s, char *q, unsigned int ql, char *fc, uint64_t _h); + void update_metadata(MYSQL_STMT *stmt); + ~PgSQL_STMT_Global_info(); + void calculate_mem_usage(); +}; + +#if 0 +// stmt_execute_metadata_t represent metadata required to run STMT_EXECUTE +class pgsql_stmt_execute_metadata_t { + public: + uint32_t size; + uint32_t stmt_id; + uint8_t flags; + uint16_t num_params; + MYSQL_BIND *binds; + my_bool *is_nulls; + unsigned long *lengths; + void *pkt; + pgsql_stmt_execute_metadata_t() { + size = 0; + stmt_id = 0; + binds=NULL; + is_nulls=NULL; + lengths=NULL; + pkt=NULL; + } + ~pgsql_stmt_execute_metadata_t() { + if (binds) + free(binds); + binds = NULL; + if (is_nulls) + free(is_nulls); + is_nulls = NULL; + if (lengths) + free(lengths); + lengths = NULL; + size = 0; + stmt_id = 0; + if (pkt) { + free(pkt); + pkt = NULL; + } + } +}; + +// server side, metadata related to STMT_EXECUTE are stored in MYSQL_STMT itself +// client side, they are stored in stmt_execute_metadata_t +// MySQL_STMTs_meta maps stmt_execute_metadata_t with stmt_id +class PgSQL_STMTs_meta { + private: + unsigned int num_entries; + std::map m; + public: + PgSQL_STMTs_meta() { + num_entries=0; + } + ~PgSQL_STMTs_meta() { + for (std::map::iterator it=m.begin(); it!=m.end(); ++it) { + pgsql_stmt_execute_metadata_t*sem=it->second; + delete sem; + } + } + // we declare it here to be inline + void insert(uint32_t global_statement_id, pgsql_stmt_execute_metadata_t*stmt_meta) { + std::pair::iterator,bool> ret; + ret=m.insert(std::make_pair(global_statement_id, stmt_meta)); + if (ret.second==true) { + num_entries++; + } + } + // we declare it here to be inline + pgsql_stmt_execute_metadata_t* find(uint32_t global_statement_id) { + auto s=m.find(global_statement_id); + if (s!=m.end()) { // found + return s->second; + } + return NULL; // not found + } + + void erase(uint32_t global_statement_id) { + auto s=m.find(global_statement_id); + if (s!=m.end()) { // found + pgsql_stmt_execute_metadata_t*sem=s->second; + delete sem; + num_entries--; + m.erase(s); + } + } +}; +#endif + +// class MySQL_STMTs_local associates a global statement ID with a local statement ID for a specific connection + +class PgSQL_STMTs_local_v14 { +private: + bool is_client_; + std::stack free_backend_ids; + uint32_t local_max_stmt_id = 0; + +public: + // this map associate client_stmt_id to global_stmt_id : this is used only for client connections + std::map stmt_name_to_global_ids; + // this multimap associate global_stmt_id to client_stmt_id : this is used only for client connections + std::multimap global_id_to_stmt_names; + + // this map associate backend_stmt_id to global_stmt_id : this is used only for backend connections + std::map backend_stmt_to_global_ids; + // this map associate global_stmt_id to backend_stmt_id : this is used only for backend connections + std::map global_stmt_to_backend_ids; + + PgSQL_Session *sess; + PgSQL_STMTs_local_v14(bool _ic) : is_client_(_ic), sess(NULL) { } + ~PgSQL_STMTs_local_v14(); + + inline + void set_is_client(PgSQL_Session *_s) { + sess=_s; + is_client_ = true; + } + + inline + bool is_client() { return is_client_; } + inline + unsigned int get_num_backend_stmts() { return backend_stmt_to_global_ids.size(); } + + void backend_insert(uint64_t global_stmt_id, uint32_t backend_stmt_id); + void client_insert(uint64_t global_stmt_id, const std::string& client_stmt_name); + uint64_t compute_hash(char *user, char *schema, char *query, unsigned int query_length); + uint32_t generate_new_backend_stmt_id(); + uint64_t find_global_id_from_stmt_name(const std::string& client_stmt_name); + bool client_close(const std::string& stmt_name); +}; + + +class PgSQL_STMT_Manager_v14 { +private: + uint64_t next_statement_id; + uint64_t num_stmt_with_ref_client_count_zero; + uint64_t num_stmt_with_ref_server_count_zero; + pthread_rwlock_t rwlock_; + std::map map_stmt_id_to_info; // map using statement id + std::map map_stmt_hash_to_info; // map using hashes + std::stack free_stmt_ids; + struct { + uint64_t c_unique; + uint64_t c_total; + uint64_t stmt_max_stmt_id; + uint64_t cached; + uint64_t s_unique; + uint64_t s_total; + } statuses; + time_t last_purge_time; +public: + PgSQL_STMT_Manager_v14(); + ~PgSQL_STMT_Manager_v14(); + PgSQL_STMT_Global_info * find_prepared_statement_by_hash(uint64_t hash); + PgSQL_STMT_Global_info * find_prepared_statement_by_stmt_id(uint64_t id, bool lock=true); + void rdlock() { pthread_rwlock_rdlock(&rwlock_); } + void wrlock() { pthread_rwlock_wrlock(&rwlock_); } + void unlock() { pthread_rwlock_unlock(&rwlock_); } + void ref_count_client(uint64_t _stmt, int _v, bool lock=true); + void ref_count_server(uint64_t _stmt, int _v, bool lock=true); + PgSQL_STMT_Global_info * add_prepared_statement(char *u, char *s, char *q, unsigned int ql, char *fc, bool lock=true); + void get_metrics(uint64_t *c_unique, uint64_t *c_total, uint64_t *stmt_max_stmt_id, uint64_t *cached, uint64_t *s_unique, uint64_t *s_total); + SQLite3_result* get_prepared_statements_global_infos(); + void get_memory_usage(uint64_t& prep_stmt_metadata_mem_usage, uint64_t& prep_stmt_backend_mem_usage); +}; + +#endif /* CLASS_PGSQL_PREPARED_STATEMENT_H */ diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index 55a4a9e5b..90f4cd7b4 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -245,7 +245,10 @@ public: void write_PasswordMessage(const char* psw) { write_generic('p', "s", psw); } - + void write_ParseCompletion() { + put_char('1'); + put_uint32(4); + } void write_RowDescription(const char* tupdesc, ...); void write_DataRow(const char* tupdesc, ...); @@ -282,6 +285,22 @@ private: friend void SQLite3_to_Postgres(PtrSizeArray* psa, SQLite3_result* result, char* error, int affected_rows, const char* query_type, char txn_state); }; +class PgSQL_Parse_Message { +public: + PgSQL_Parse_Message(); + ~PgSQL_Parse_Message(); + const char* stmt_name = NULL; // The name of the prepared statement + const char* query_string = NULL; // The query string to be prepared + uint16_t num_param_types = 0; // Number of parameter types specified + const uint32_t* param_types = NULL; // Array of parameter types (can be nullptr if none) + + bool parse(PtrSize_t& pkt); + PtrSize_t detach(); + +private: + PtrSize_t _pkt = {}; +}; + class PgSQL_Protocol; #define PGSQL_QUERY_RESULT_NO_DATA 0x00 @@ -737,6 +756,9 @@ public: bool generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, char trx_state = 'I', PtrSize_t* _ptr = NULL, const std::vector>& param_status = std::vector>()); + bool generate_parse_completion_packet(bool send, bool ready, char trx_state, PtrSize_t* _ptr = NULL); + bool generate_ready_for_query_packet(bool send, char trx_state, PtrSize_t* _ptr = NULL); + // temporary overriding generate_pkt_OK to avoid crash. FIXME remove this bool generate_pkt_OK(bool send, void** ptr, unsigned int* len, uint8_t sequence_id, unsigned int affected_rows, uint64_t last_insert_id, uint16_t status, uint16_t warnings, char* msg, bool eof_identifier = false) { diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 00010071f..c680e20a2 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -5,7 +5,7 @@ #include #include - +#include #include "proxysql.h" #include "Base_Session.h" #include "cpp.h" @@ -16,6 +16,7 @@ class PgSQL_Query_Result; class PgSQL_ExplicitTxnStateMgr; +class PgSQL_Parse_Message; //#include "../deps/json/json.hpp" //using json = nlohmann::json; @@ -142,6 +143,7 @@ public: bool match(char* m); }; +class PgSQL_STMT_Global_info; class PgSQL_Query_Info { public: @@ -151,11 +153,11 @@ public: unsigned long long start_time; unsigned long long end_time; - MYSQL_STMT* mysql_stmt; - stmt_execute_metadata_t* stmt_meta; + char* stmt_client_name; uint64_t stmt_global_id; - uint64_t stmt_client_id; - MySQL_STMT_Global_info* stmt_info; + uint64_t stmt_backend_id; + + PgSQL_STMT_Global_info* stmt_info; int QueryLength; enum PGSQL_QUERY_command PgQueryCmd; @@ -181,6 +183,10 @@ public: class PgSQL_Session : public Base_Session { private: + using PktType = std::variant>; + + std::queue pending_packets; + //int handler_ret; void handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE(PtrSize_t*, bool*); @@ -228,6 +234,11 @@ private: bool handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo(PtrSize_t*, bool* lock_hostgroup, PgSQL_ps_type prepare_stmt_type = PgSQL_ps_type_not_set); void handler___client_DSS_QUERY_SENT___server_DSS_NOT_INITIALIZED__get_connection(); + bool handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_PARSE(PtrSize_t& pkt); + int handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_SYNC(PtrSize_t& pkt); + bool handler___rc0_PROCESSING_STMT_PREPARE(enum session_status& st, PgSQL_Data_Stream* myds, bool& prepared_stmt_with_no_params); + + int handle_post_sync_parse_message(PgSQL_Parse_Message* parsse_msg); //void return_proxysql_internal(PtrSize_t*); bool handler_special_queries(PtrSize_t*, bool* lock_hostgroup); @@ -302,10 +313,6 @@ private: void housekeeping_before_pkts(); #endif // 0 int get_pkts_from_client(bool&, PtrSize_t&); - - //void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_STMT_PREPARE(PtrSize_t& pkt); - //void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_STMT_EXECUTE(PtrSize_t& pkt); - // these functions have code that used to be inline, and split into functions for readibility int handler_ProcessingQueryError_CheckBackendConnectionStatus(PgSQL_Data_Stream* myds); void SetQueryTimeout(); @@ -316,7 +323,6 @@ private: void handler_minus1_HandleBackendConnection(PgSQL_Data_Stream* myds); int RunQuery(PgSQL_Data_Stream* myds, PgSQL_Connection* myconn); void handler___status_WAITING_CLIENT_DATA(); - void handler_rc0_Process_GTID(PgSQL_Connection* myconn); void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_INIT_DB_replace_CLICKHOUSE(PtrSize_t& pkt); void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___not_mysql(PtrSize_t& pkt); bool handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_detect_SQLi(); diff --git a/lib/Base_Thread.cpp b/lib/Base_Thread.cpp index d6ee97fe7..d6d2b0830 100644 --- a/lib/Base_Thread.cpp +++ b/lib/Base_Thread.cpp @@ -134,6 +134,7 @@ S Base_Thread::create_new_session_and_client_data_stream(int _fd) { if constexpr (std::is_same_v) { PgSQL_Connection* myconn = new PgSQL_Connection(); sess->client_myds->attach_connection(myconn); + sess->client_myds->myconn->set_is_client(); // this is used for prepared statements } else if constexpr (std::is_same_v) { MySQL_Connection* myconn = new MySQL_Connection(); sess->client_myds->attach_connection(myconn); diff --git a/lib/Makefile b/lib/Makefile index 33fc0fa5c..83dd70406 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -164,7 +164,8 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo PgSQL_Protocol.oo PgSQL_Thread.oo PgSQL_Data_Stream.oo PgSQL_Session.oo PgSQL_Variables.oo PgSQL_HostGroups_Manager.oo PgSQL_Connection.oo PgSQL_Backend.oo PgSQL_Logger.oo PgSQL_Authentication.oo PgSQL_Error_Helper.oo \ MySQL_Query_Cache.oo PgSQL_Query_Cache.oo PgSQL_Monitor.oo \ MySQL_Set_Stmt_Parser.oo PgSQL_Set_Stmt_Parser.oo \ - PgSQL_Variables_Validator.oo PgSQL_ExplicitTxnStateMgr.oo + PgSQL_Variables_Validator.oo PgSQL_ExplicitTxnStateMgr.oo \ + PgSQL_PreparedStatement.oo OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX)) HEADERS := ../include/*.h ../include/*.hpp diff --git a/lib/PgSQL_Connection.cpp b/lib/PgSQL_Connection.cpp index b5e668948..3a9923530 100644 --- a/lib/PgSQL_Connection.cpp +++ b/lib/PgSQL_Connection.cpp @@ -9,7 +9,7 @@ using json = nlohmann::json; #include "PgSQL_HostGroups_Manager.h" #include "proxysql.h" #include "cpp.h" -#include "MySQL_PreparedStatement.h" +#include "PgSQL_PreparedStatement.h" #include "PgSQL_Data_Stream.h" #include "PgSQL_Query_Processor.h" #include "MySQL_Variables.h" @@ -178,6 +178,7 @@ PgSQL_Connection::PgSQL_Connection() { options.init_connect = NULL; options.init_connect_sent = false; userinfo = new PgSQL_Connection_userinfo(); + local_stmts = new PgSQL_STMTs_local_v14(false); // false by default, it is a backend for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { variables[i].value = NULL; @@ -200,6 +201,10 @@ PgSQL_Connection::~PgSQL_Connection() { PQclear(pgsql_result); pgsql_result = NULL; } + if (local_stmts) { + delete local_stmts; + local_stmts = NULL; + } if (pgsql_conn) { if (is_connected()) __sync_fetch_and_sub(&PgHGM->status.server_connections_connected, 1); @@ -670,6 +675,55 @@ handler_again: case ASYNC_RESET_SESSION_SUCCESSFUL: case ASYNC_RESET_SESSION_TIMEOUT: break; + case ASYNC_STMT_PREPARE_START: + stmt_prepare_start(); + if (async_exit_status) { + next_event(ASYNC_STMT_PREPARE_CONT); + } + else { + NEXT_IMMEDIATE(ASYNC_STMT_PREPARE_END); + } + break; + case ASYNC_STMT_PREPARE_CONT: + { + if (event) { + stmt_prepare_cont(event); + } + if (async_exit_status) { + next_event(ASYNC_STMT_PREPARE_END); + break; + } + if (is_error_present()) { + NEXT_IMMEDIATE(ASYNC_STMT_PREPARE_END); + } + PGresult* result = get_result(); + if (result) { + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + set_error_from_result(result, PGSQL_ERROR_FIELD_ALL); + assert(is_error_present()); + } + PQclear(result); + NEXT_IMMEDIATE(ASYNC_STMT_PREPARE_CONT); + } + NEXT_IMMEDIATE(ASYNC_STMT_PREPARE_END); + } + break; + case ASYNC_STMT_PREPARE_END: + if (is_error_present()) { + proxy_error("Failed to prepare statement: %s\n", get_error_code_with_message().c_str()); + NEXT_IMMEDIATE(ASYNC_STMT_PREPARE_FAILED); + } + else { + NEXT_IMMEDIATE(ASYNC_STMT_PREPARE_SUCCESSFUL); + } + break; + case ASYNC_STMT_PREPARE_FAILED: + + break; + case ASYNC_STMT_PREPARE_SUCCESSFUL: + + break; + default: // not implemented yet assert(0); @@ -1149,7 +1203,7 @@ bool PgSQL_Connection::IsAutoCommit() { // 0 when the query is completed // 1 when the query is not completed // the calling function should check pgsql error in pgsql struct -int PgSQL_Connection::async_query(short event, char* stmt, unsigned long length) { +int PgSQL_Connection::async_query(short event, const char* stmt, unsigned long length, const char* backend_stmt_name, void* execute_data) { PROXY_TRACE(); PROXY_TRACE2(); assert(pgsql_conn); @@ -1178,13 +1232,26 @@ int PgSQL_Connection::async_query(short event, char* stmt, unsigned long length) } } - set_query(stmt, length); - async_state_machine = ASYNC_QUERY_START; + if (!backend_stmt_name) { + async_state_machine = ASYNC_QUERY_START; + } else { + if (execute_data) { + async_state_machine = ASYNC_STMT_EXECUTE_START; + } else { + async_state_machine = ASYNC_STMT_PREPARE_START; + } + } + set_query(stmt, length, backend_stmt_name); default: handler(event); break; } + if (async_state_machine == ASYNC_STMT_EXECUTE_END) { + PROXY_TRACE2(); + async_state_machine = ASYNC_QUERY_END; + } + if (async_state_machine == ASYNC_QUERY_END) { PROXY_TRACE2(); compute_unknown_transaction_status(); @@ -1195,9 +1262,20 @@ int PgSQL_Connection::async_query(short event, char* stmt, unsigned long length) return 0; } } + + if (async_state_machine == ASYNC_STMT_PREPARE_SUCCESSFUL || + async_state_machine == ASYNC_STMT_PREPARE_FAILED) { + compute_unknown_transaction_status(); + if (async_state_machine == ASYNC_STMT_PREPARE_FAILED) { + return -1; + } else { + return 0; + } + } + if (async_state_machine == ASYNC_USE_RESULT_START) { // if we reached this point it measn we are processing a multi-statement - // and we need to exit to give control to MySQL_Session + // and we need to exit to give control to PgSQL_Session processing_multi_statement = true; return 2; } @@ -1429,6 +1507,51 @@ void PgSQL_Connection::next_multi_statement_result(PGresult* result) { query_result->buffer_to_PSarrayOut(); } +void PgSQL_Connection::stmt_prepare_start() { + PROXY_TRACE(); + reset_error(); + processing_multi_statement = false; + async_exit_status = PG_EVENT_NONE; + + PQsetNoticeReceiver(pgsql_conn, &PgSQL_Connection::notice_handler_cb, this); + + if (PQsendPrepare(pgsql_conn, query.backend_stmt_name, query.ptr, 0, NULL) == 0) { + set_error_from_PQerrorMessage(); + proxy_error("Failed to send prepare. %s\n", get_error_code_with_message().c_str()); + return; + } + flush(); +} + +void PgSQL_Connection::stmt_prepare_cont(short event) { + PROXY_TRACE(); + proxy_debug(PROXY_DEBUG_MYSQL_PROTOCOL, 6, "event=%d\n", event); + async_exit_status = PG_EVENT_NONE; + if (event & POLLOUT) { + flush(); + return; + } + + if (PQconsumeInput(pgsql_conn) == 0) { + /* We will only set the error if we didn't capture error in last call. If is_error_present is true, + * it indicates that an error was already captured during a previous PQconsumeInput call, + * and we do not want to overwrite that information. + */ + if (is_error_present() == false) { + set_error_from_PQerrorMessage(); + proxy_error("Failed to consume input. %s\n", get_error_code_with_message().c_str()); + } + return; + } + + if (PQisBusy(pgsql_conn)) { + async_exit_status = PG_EVENT_READ; + return; + } + + pgsql_result = PQgetResult(pgsql_conn); +} + void PgSQL_Connection::reset_session_start() { PROXY_TRACE(); assert(pgsql_conn); @@ -1880,6 +2003,8 @@ void PgSQL_Connection::reset() { set_status(old_compress, STATUS_MYSQL_CONNECTION_COMPRESSION); reusable = true; creation_time = monotonic_time(); + delete local_stmts; + local_stmts = new PgSQL_STMTs_local_v14(false); for (int i = 0; i < PGSQL_NAME_LAST_HIGH_WM; i++) { var_hash[i] = 0; @@ -1970,12 +2095,13 @@ bool PgSQL_Connection::AutocommitFalse_AndSavepoint() { return ret; } -void PgSQL_Connection::set_query(char* stmt, unsigned long length) { +void PgSQL_Connection::set_query(const char* stmt, unsigned long length, const char* backend_stmt_name) { query.length = length; query.ptr = stmt; if (length > largest_query_length) { largest_query_length = length; } + query.backend_stmt_name = backend_stmt_name; } bool PgSQL_Connection::IsKeepMultiplexEnabledVariables(char* query_digest_text) { diff --git a/lib/PgSQL_Data_Stream.cpp b/lib/PgSQL_Data_Stream.cpp index ef952d1c9..d220dead4 100644 --- a/lib/PgSQL_Data_Stream.cpp +++ b/lib/PgSQL_Data_Stream.cpp @@ -6,7 +6,7 @@ #define UNIX_PATH_MAX 108 #endif -#include "MySQL_PreparedStatement.h" +#include "PgSQL_PreparedStatement.h" #include "PgSQL_Data_Stream.h" #include "openssl/x509v3.h" @@ -224,10 +224,10 @@ PgSQL_Data_Stream::PgSQL_Data_Stream() { proxy_addr.port = 0; sess = NULL; - mysql_real_query.pkt.ptr = NULL; - mysql_real_query.pkt.size = 0; - mysql_real_query.QueryPtr = NULL; - mysql_real_query.QuerySize = 0; + pgsql_real_query.pkt.ptr = NULL; + pgsql_real_query.pkt.size = 0; + pgsql_real_query.QueryPtr = NULL; + pgsql_real_query.QuerySize = 0; query_retries_on_failure = 0; connect_retries_on_failure = 0; @@ -296,7 +296,7 @@ PgSQL_Data_Stream::~PgSQL_Data_Stream() { proxy_addr.addr = NULL; } - free_mysql_real_query(); + free_pgsql_real_query(); if (com_field_wild) { free(com_field_wild); @@ -1169,9 +1169,9 @@ void PgSQL_Data_Stream::return_MySQL_Connection_To_Pool() { } } -void PgSQL_Data_Stream::free_mysql_real_query() { - if (mysql_real_query.QueryPtr) { - mysql_real_query.end(); +void PgSQL_Data_Stream::free_pgsql_real_query() { + if (pgsql_real_query.QueryPtr) { + pgsql_real_query.end(); } } diff --git a/lib/PgSQL_HostGroups_Manager.cpp b/lib/PgSQL_HostGroups_Manager.cpp index 5646170a4..7304db9b8 100644 --- a/lib/PgSQL_HostGroups_Manager.cpp +++ b/lib/PgSQL_HostGroups_Manager.cpp @@ -6,7 +6,7 @@ using json = nlohmann::json; #include "proxysql.h" #include "cpp.h" -#include "MySQL_PreparedStatement.h" +#include "PgSQL_PreparedStatement.h" #include "PgSQL_Data_Stream.h" #include diff --git a/lib/PgSQL_Logger.cpp b/lib/PgSQL_Logger.cpp index 16b5d1e69..2aecca79f 100644 --- a/lib/PgSQL_Logger.cpp +++ b/lib/PgSQL_Logger.cpp @@ -8,7 +8,7 @@ using json = nlohmann::json; #include "PgSQL_Data_Stream.h" #include "PgSQL_Query_Processor.h" -#include "MySQL_PreparedStatement.h" +#include "PgSQL_PreparedStatement.h" #include "PgSQL_Logger.hpp" #include @@ -60,11 +60,11 @@ PgSQL_Event::PgSQL_Event (log_event_type _et, uint32_t _thread_id, char * _usern affected_rows=0; have_rows_sent=false; rows_sent=0; - client_stmt_id=0; + client_stmt_name=NULL; } -void PgSQL_Event::set_client_stmt_id(uint32_t client_stmt_id) { - this->client_stmt_id = client_stmt_id; +void PgSQL_Event::set_client_stmt_name(char* client_stmt_name) { + this->client_stmt_name = client_stmt_name; } // if affected rows is set, last_insert_id is set too. @@ -266,7 +266,8 @@ uint64_t PgSQL_Event::write_query_format_1(std::fstream *f) { total_bytes+=mysql_encode_length(start_time,NULL); total_bytes+=mysql_encode_length(end_time,NULL); - total_bytes+=mysql_encode_length(client_stmt_id,NULL); + client_stmt_name_len=strlen(client_stmt_name); + total_bytes+=mysql_encode_length(client_stmt_name_len,NULL)+client_stmt_name_len; total_bytes+=mysql_encode_length(affected_rows,NULL); total_bytes+=mysql_encode_length(rows_sent,NULL); @@ -325,9 +326,10 @@ uint64_t PgSQL_Event::write_query_format_1(std::fstream *f) { f->write((char *)buf,len); if (et == PROXYSQL_COM_STMT_PREPARE || et == PROXYSQL_COM_STMT_EXECUTE) { - len=mysql_encode_length(client_stmt_id,buf); - write_encoded_length(buf,client_stmt_id,len,buf[0]); - f->write((char *)buf,len); + len = mysql_encode_length(client_stmt_name_len, buf); + write_encoded_length(buf, client_stmt_name_len, len, buf[0]); + f->write((char*)buf, len); + f->write(client_stmt_name, client_stmt_name_len); } len=mysql_encode_length(affected_rows,buf); @@ -431,7 +433,7 @@ uint64_t PgSQL_Event::write_query_format_2_json(std::fstream *f) { j["digest"] = digest_hex; if (et == PROXYSQL_COM_STMT_PREPARE || et == PROXYSQL_COM_STMT_EXECUTE) { - j["client_stmt_id"] = client_stmt_id; + j["client_stmt_name"] = client_stmt_name; } // for performance reason, we are moving the write lock @@ -720,7 +722,7 @@ void PgSQL_Logger::log_request(PgSQL_Session *sess, PgSQL_Data_Stream *myds) { case PROCESSING_STMT_EXECUTE: c = (char *)sess->CurrentQuery.stmt_info->query; ql = sess->CurrentQuery.stmt_info->query_length; - me.set_client_stmt_id(sess->CurrentQuery.stmt_client_id); + me.set_client_stmt_name(sess->CurrentQuery.stmt_client_name); break; case PROCESSING_STMT_PREPARE: default: @@ -731,7 +733,7 @@ void PgSQL_Logger::log_request(PgSQL_Session *sess, PgSQL_Data_Stream *myds) { // global cache and due to that we immediately reply to the client and session doesn't reach // 'PROCESSING_STMT_PREPARE' state. 'stmt_client_id' is expected to be '0' for anything that isn't // a prepared statement, still, logging should rely 'log_event_type' instead of this value. - me.set_client_stmt_id(sess->CurrentQuery.stmt_client_id); + me.set_client_stmt_name(sess->CurrentQuery.stmt_client_name); break; } if (c) { diff --git a/lib/PgSQL_PreparedStatement.cpp b/lib/PgSQL_PreparedStatement.cpp new file mode 100644 index 000000000..0b1f2aa10 --- /dev/null +++ b/lib/PgSQL_PreparedStatement.cpp @@ -0,0 +1,908 @@ +#include "proxysql.h" +#include "cpp.h" + +#ifndef SPOOKYV2 +#include "SpookyV2.h" +#define SPOOKYV2 +#endif + +#include "PgSQL_PreparedStatement.h" +#include "MySQL_Protocol.h" + +//extern MySQL_STMT_Manager *GloMyStmt; +//static uint32_t add_prepared_statement_calls = 0; +//static uint32_t find_prepared_statement_by_hash_calls = 0; +//#else +extern PgSQL_STMT_Manager_v14 *GloPgStmt; +//#endif + +const int PS_GLOBAL_STATUS_FIELD_NUM = 9; + +static uint64_t stmt_compute_hash(char *user, + char *schema, char *query, + unsigned int query_length) { + int l = 0; + l += strlen(user); + l += strlen(schema); +// two random seperators +#define _COMPUTE_HASH_DEL1_ "-ujhtgf76y576574fhYTRDFwdt-" +#define _COMPUTE_HASH_DEL2_ "-8k7jrhtrgJHRgrefgreRFewg6-" + l += strlen(_COMPUTE_HASH_DEL1_); + l += strlen(_COMPUTE_HASH_DEL2_); + l += query_length; + char *buf = (char *)malloc(l); + l = 0; + + // write user + strcpy(buf + l, user); + l += strlen(user); + + // write delimiter1 + strcpy(buf + l, _COMPUTE_HASH_DEL1_); + l += strlen(_COMPUTE_HASH_DEL1_); + + // write schema + strcpy(buf + l, schema); + l += strlen(schema); + + // write delimiter2 + strcpy(buf + l, _COMPUTE_HASH_DEL2_); + l += strlen(_COMPUTE_HASH_DEL2_); + + // write query + memcpy(buf + l, query, query_length); + l += query_length; + + uint64_t hash = SpookyHash::Hash64(buf, l, 0); + free(buf); + return hash; +} + +void PgSQL_STMT_Global_info::compute_hash() { + hash = stmt_compute_hash(username, schemaname, query, + query_length); +} + +PgSQL_STMT_Global_info::PgSQL_STMT_Global_info(uint64_t id, + char *u, char *s, char *q, + unsigned int ql, + char *fc, + uint64_t _h) { + pthread_rwlock_init(&rwlock_, NULL); + total_mem_usage = 0; + statement_id = id; + ref_count_client = 0; + ref_count_server = 0; + digest_text = NULL; + username = strdup(u); + schemaname = strdup(s); + query = (char *)malloc(ql + 1); + memcpy(query, q, ql); + query[ql] = '\0'; // add NULL byte + query_length = ql; + if (fc) { + first_comment = strdup(fc); + } else { + first_comment = NULL; + } + PgQueryCmd = PGSQL_QUERY__UNINITIALIZED; + + if (_h) { + hash = _h; + } else { + compute_hash(); + } + + is_select_NOT_for_update = false; + { // see bug #899 . Most of the code is borrowed from + // Query_Info::is_select_NOT_for_update() + if (ql >= 7) { + if (strncasecmp(q, (char *)"SELECT ", 7) == 0) { // is a SELECT + if (ql >= 17) { + char *p = q; + p += ql - 11; + if (strncasecmp(p, " FOR UPDATE", 11) == 0) { // is a SELECT FOR UPDATE + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } + p = q; + p += ql-10; + if (strncasecmp(p, " FOR SHARE", 10) == 0) { // is a SELECT FOR SHARE + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } + if (ql >= 25) { + p = q; + p += ql-19; + if (strncasecmp(p, " LOCK IN SHARE MODE", 19) == 0) { // is a SELECT LOCK IN SHARE MODE + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } + p = q; + p += ql-7; + if (strncasecmp(p," NOWAIT",7)==0) { + // let simplify. If NOWAIT is used, we assume FOR UPDATE|SHARE is used + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; +/* + if (strcasestr(q," FOR UPDATE ")) { + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } + if (strcasestr(q," FOR SHARE ")) { + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } +*/ + } + p = q; + p += ql-12; + if (strncasecmp(p," SKIP LOCKED",12)==0) { + // let simplify. If SKIP LOCKED is used, we assume FOR UPDATE|SHARE is used + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; +/* + if (strcasestr(q," FOR UPDATE ")==NULL) { + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } + if (strcasestr(q," FOR SHARE ")==NULL) { + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } +*/ + } + p=q; + char buf[129]; + if (ql>=128) { // for long query, just check the last 128 bytes + p+=ql-128; + memcpy(buf,p,128); + buf[128]=0; + } else { + memcpy(buf,p,ql); + buf[ql]=0; + } + if (strcasestr(buf," FOR ")) { + if (strcasestr(buf," FOR UPDATE ")) { + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } + if (strcasestr(buf," FOR SHARE ")) { + __sync_fetch_and_add(&MyHGM->status.select_for_update_or_equivalent, 1); + goto __exit_PgSQL_STMT_Global_info___search_select; + } + } + } + } + is_select_NOT_for_update = true; + } + } + } +__exit_PgSQL_STMT_Global_info___search_select: + calculate_mem_usage(); +} + +void PgSQL_STMT_Global_info::calculate_mem_usage() { + total_mem_usage = sizeof(PgSQL_STMT_Global_info) + + query_length + 1;// + + //(ref_count_client * 24) + + //(ref_count_server * 24); + + if (username) total_mem_usage += strlen(username) + 1; + if (schemaname) total_mem_usage += strlen(schemaname) + 1; + if (first_comment) total_mem_usage += strlen(first_comment) + 1; + if (digest_text) total_mem_usage += strlen(digest_text) + 1; +} + +void PgSQL_STMT_Global_info::update_metadata(MYSQL_STMT *stmt) { + + bool need_refresh = false; + pthread_rwlock_wrlock(&rwlock_); + /* if ( + (num_params != stmt->param_count) + || + (num_columns != stmt->field_count) + ) { + need_refresh = true; + } + for (i = 0; i < num_columns; i++) { + if (need_refresh == false) { // don't bother to check if need_refresh == true + bool ok = true; + MYSQL_FIELD *fs = &(stmt->fields[i]); + MYSQL_FIELD *fd = fields[i]; + if (ok) { + ok = false; + if (fd->name == NULL && fs->name == NULL) { + ok = true; + } else { + if (fd->name && fs->name && strcmp(fd->name,fs->name)==0) { + ok = true; + } + } + } + if (ok) { + ok = false; + if (fd->org_name == NULL && fs->org_name == NULL) { + ok = true; + } else { + if (fd->org_name && fs->org_name && strcmp(fd->org_name,fs->org_name)==0) { + ok = true; + } + } + } + if (ok) { + ok = false; + if (fd->table == NULL && fs->table == NULL) { + ok = true; + } else { + if (fd->table && fs->table && strcmp(fd->table,fs->table)==0) { + ok = true; + } + } + } + if (ok) { + ok = false; + if (fd->org_table == NULL && fs->org_table == NULL) { + ok = true; + } else { + if (fd->org_table && fs->org_table && strcmp(fd->org_table,fs->org_table)==0) { + ok = true; + } + } + } + if (ok) { + ok = false; + if (fd->db == NULL && fs->db == NULL) { + ok = true; + } else { + if (fd->db && fs->db && strcmp(fd->db,fs->db)==0) { + ok = true; + } + } + } + if (ok) { + ok = false; + if (fd->catalog == NULL && fs->catalog == NULL) { + ok = true; + } else { + if (fd->catalog && fs->catalog && strcmp(fd->catalog,fs->catalog)==0) { + ok = true; + } + } + } + if (ok) { + ok = false; + if (fd->def == NULL && fs->def == NULL) { + ok = true; + } else { + if (fd->def && fs->def && strcmp(fd->def,fs->def)==0) { + ok = true; + } + } + } + if (ok == false) { + need_refresh = true; + } + } + } + if (need_refresh) { + if (digest_text && strncasecmp(digest_text, "EXPLAIN", strlen("EXPLAIN"))==0) { + // do not print any message in case of EXPLAIN + } else { + proxy_warning("Updating metadata for stmt %lu , user %s, query %s\n", statement_id, username, query); + } +// from here is copied from destructor + if (num_columns) { + uint16_t i; + for (i = 0; i < num_columns; i++) { + MYSQL_FIELD *f = fields[i]; + if (f->name) { + free(f->name); + f->name = NULL; + } + if (f->org_name) { + free(f->org_name); + f->org_name = NULL; + } + if (f->table) { + free(f->table); + f->table = NULL; + } + if (f->org_table) { + free(f->org_table); + f->org_table = NULL; + } + if (f->db) { + free(f->db); + f->db = NULL; + } + if (f->catalog) { + free(f->catalog); + f->catalog = NULL; + } + if (f->def) { + free(f->def); + f->def = NULL; + } + free(fields[i]); + } + free(fields); + fields = NULL; + } + if (num_params) { + uint16_t i; + for (i = 0; i < num_params; i++) { + free(params[i]); + } + free(params); + params = NULL; + } +// till here is copied from destructor + +// from here is copied from constructor + num_params = stmt->param_count; + num_columns = stmt->field_count; + fields = NULL; + if (num_columns) { + fields = (MYSQL_FIELD **)malloc(num_columns * sizeof(MYSQL_FIELD *)); + uint16_t i; + for (i = 0; i < num_columns; i++) { + fields[i] = (MYSQL_FIELD *)malloc(sizeof(MYSQL_FIELD)); + MYSQL_FIELD *fs = &(stmt->fields[i]); + MYSQL_FIELD *fd = fields[i]; + // first copy all fields + memcpy(fd, fs, sizeof(MYSQL_FIELD)); + // then duplicate strings + fd->name = (fs->name ? strdup(fs->name) : NULL); + fd->org_name = (fs->org_name ? strdup(fs->org_name) : NULL); + fd->table = (fs->table ? strdup(fs->table) : NULL); + fd->org_table = (fs->org_table ? strdup(fs->org_table) : NULL); + fd->db = (fs->db ? strdup(fs->db) : NULL); + fd->catalog = (fs->catalog ? strdup(fs->catalog) : NULL); + fd->def = (fs->def ? strdup(fs->def) : NULL); + } + } + + params = NULL; + if (num_params == 2) { + PROXY_TRACE(); + } + if (num_params) { + params = (MYSQL_BIND **)malloc(num_params * sizeof(MYSQL_BIND *)); + uint16_t i; + for (i = 0; i < num_params; i++) { + params[i] = (MYSQL_BIND *)malloc(sizeof(MYSQL_BIND)); + // MYSQL_BIND *ps=&(stmt->params[i]); + // MYSQL_BIND *pd=params[i]; + // copy all params + // memcpy(pd,ps,sizeof(MYSQL_BIND)); + memset(params[i], 0, sizeof(MYSQL_BIND)); + } + } + +// till here is copied from constructor + calculate_mem_usage(); + }*/ + pthread_rwlock_unlock(&rwlock_); +} + +PgSQL_STMT_Global_info::~PgSQL_STMT_Global_info() { + free(username); + free(schemaname); + free(query); + if (first_comment) { + free(first_comment); + } + /*if (num_columns) { + uint16_t i; + for (i = 0; i < num_columns; i++) { + MYSQL_FIELD *f = fields[i]; + if (f->name) { + free(f->name); + f->name = NULL; + } + if (f->org_name) { + free(f->org_name); + f->org_name = NULL; + } + if (f->table) { + free(f->table); + f->table = NULL; + } + if (f->org_table) { + free(f->org_table); + f->org_table = NULL; + } + if (f->db) { + free(f->db); + f->db = NULL; + } + if (f->catalog) { + free(f->catalog); + f->catalog = NULL; + } + if (f->def) { + free(f->def); + f->def = NULL; + } + free(fields[i]); + } + free(fields); + fields = NULL; + } + + if (num_params) { + uint16_t i; + for (i = 0; i < num_params; i++) { + free(params[i]); + } + free(params); + params = NULL; + } + */ + if (digest_text) { + free(digest_text); + digest_text = NULL; + } +} + +void PgSQL_STMTs_local_v14::backend_insert(uint64_t global_stmt_id, uint32_t backend_stmt_id) { + //std::pair::iterator, bool> ret; + //ret = global_stmt_to_backend_stmt.insert(std::make_pair(global_statement_id, stmt)); + global_stmt_to_backend_ids.insert(std::make_pair(global_stmt_id, backend_stmt_id)); + backend_stmt_to_global_ids.insert(std::make_pair(backend_stmt_id,global_stmt_id)); + // note: backend_insert() is always called after add_prepared_statement() + // for this reason, we will the ref count increase in add_prepared_statement() + // GloPgStmt->ref_count_client(global_statement_id, 1); +} + +void PgSQL_STMTs_local_v14::client_insert(uint64_t global_stmt_id, const std::string& client_stmt_name) { + stmt_name_to_global_ids.insert(std::make_pair(client_stmt_name, global_stmt_id)); + global_id_to_stmt_names.insert(std::make_pair(global_stmt_id, client_stmt_name)); +} + +uint64_t PgSQL_STMTs_local_v14::compute_hash(char *user, + char *schema, char *query, + unsigned int query_length) { + uint64_t hash; + hash = stmt_compute_hash(user, schema, query, query_length); + return hash; +} + +PgSQL_STMT_Manager_v14::PgSQL_STMT_Manager_v14() { + last_purge_time = time(NULL); + pthread_rwlock_init(&rwlock_, NULL); + map_stmt_id_to_info= std::map(); // map using statement id + map_stmt_hash_to_info = std::map(); // map using hashes + free_stmt_ids = std::stack (); + + next_statement_id = + 1; // we initialize this as 1 because we 0 is not allowed + num_stmt_with_ref_client_count_zero = 0; + num_stmt_with_ref_server_count_zero = 0; + statuses.c_unique = 0; + statuses.c_total = 0; + statuses.stmt_max_stmt_id = 0; + statuses.cached = 0; + statuses.s_unique = 0; + statuses.s_total = 0; +} + +PgSQL_STMT_Manager_v14::~PgSQL_STMT_Manager_v14() { + for (auto it = map_stmt_id_to_info.begin(); it != map_stmt_id_to_info.end(); ++it) { + PgSQL_STMT_Global_info * a = it->second; + delete a; + } +} + +void PgSQL_STMT_Manager_v14::ref_count_client(uint64_t _stmt_id ,int _v, bool lock) { + if (lock) + pthread_rwlock_wrlock(&rwlock_); + auto s = map_stmt_id_to_info.find(_stmt_id); + if (s != map_stmt_id_to_info.end()) { + statuses.c_total += _v; + PgSQL_STMT_Global_info *stmt_info = s->second; + if (stmt_info->ref_count_client == 0 && _v == 1) { + __sync_sub_and_fetch(&num_stmt_with_ref_client_count_zero,1); + } else { + if (stmt_info->ref_count_client == 1 && _v == -1) { + __sync_add_and_fetch(&num_stmt_with_ref_client_count_zero,1); + } + } + stmt_info->ref_count_client += _v; + time_t ct = time(NULL); + uint64_t num_client_count_zero = __sync_add_and_fetch(&num_stmt_with_ref_client_count_zero, 0); + uint64_t num_server_count_zero = __sync_add_and_fetch(&num_stmt_with_ref_server_count_zero, 0); + + size_t map_size = map_stmt_id_to_info.size(); + if ( + (ct > last_purge_time+1) && + (map_size > (unsigned)mysql_thread___max_stmts_cache ) && + (num_client_count_zero > map_size/10) && + (num_server_count_zero > map_size/10) + ) { // purge only if there is at least 10% gain + last_purge_time = ct; + int max_purge = map_size ; + int i = -1; + uint64_t *torem = + (uint64_t *)malloc(max_purge * sizeof(uint64_t)); + for (std::map::iterator it = + map_stmt_id_to_info.begin(); + it != map_stmt_id_to_info.end(); ++it) { + if ( (i == (max_purge - 1)) || (i == ((int)num_client_count_zero - 1)) ) { + break; // nothing left to clean up + } + PgSQL_STMT_Global_info *a = it->second; + if ((__sync_add_and_fetch(&a->ref_count_client, 0) == 0) && + (a->ref_count_server == 0) ) // this to avoid that IDs are incorrectly reused + { + uint64_t hash = a->hash; + auto s2 = map_stmt_hash_to_info.find(hash); + if (s2 != map_stmt_hash_to_info.end()) { + map_stmt_hash_to_info.erase(s2); + } + __sync_sub_and_fetch(&num_stmt_with_ref_client_count_zero,1); + //if (a->ref_count_server == 0) { + //__sync_sub_and_fetch(&num_stmt_with_ref_server_count_zero,1); + //} + // m.erase(it); + // delete a; + i++; + torem[i] = it->first; + } + } + while (i >= 0) { + uint64_t id = torem[i]; + auto s3 = map_stmt_id_to_info.find(id); + PgSQL_STMT_Global_info *a = s3->second; + if (a->ref_count_server == 0) { + __sync_sub_and_fetch(&num_stmt_with_ref_server_count_zero,1); + free_stmt_ids.push(id); + } + map_stmt_id_to_info.erase(s3); + statuses.s_total -= a->ref_count_server; + delete a; + i--; + } + free(torem); + } + } + if (lock) + pthread_rwlock_unlock(&rwlock_); +} + +void PgSQL_STMT_Manager_v14::ref_count_server(uint64_t _stmt_id ,int _v, bool lock) { + if (lock) + pthread_rwlock_wrlock(&rwlock_); + std::map::iterator s; + s = map_stmt_id_to_info.find(_stmt_id); + if (s != map_stmt_id_to_info.end()) { + statuses.s_total += _v; + PgSQL_STMT_Global_info *stmt_info = s->second; + if (stmt_info->ref_count_server == 0 && _v == 1) { + __sync_sub_and_fetch(&num_stmt_with_ref_server_count_zero,1); + } else { + if (stmt_info->ref_count_server == 1 && _v == -1) { + __sync_add_and_fetch(&num_stmt_with_ref_server_count_zero,1); + } + } + stmt_info->ref_count_server += _v; + } + if (lock) + pthread_rwlock_unlock(&rwlock_); +} + +PgSQL_STMTs_local_v14::~PgSQL_STMTs_local_v14() { + // Note: we do not free the prepared statements because we assume that + // if we call this destructor the connection is being destroyed anyway + + if (is_client_) { + for (auto it = stmt_name_to_global_ids.begin(); + it != stmt_name_to_global_ids.end(); ++it) { + uint64_t global_stmt_id = it->second; + GloPgStmt->ref_count_client(global_stmt_id, -1); + } + } else { + /*for (std::map::iterator it = global_stmt_to_backend_stmt.begin(); + it != global_stmt_to_backend_stmt.end(); ++it) { + uint64_t global_stmt_id = it->first; + MYSQL_STMT *stmt = it->second; + proxy_mysql_stmt_close(stmt); + GloPgStmt->ref_count_server(global_stmt_id, -1); + }*/ + for (auto it = backend_stmt_to_global_ids.begin(); + it != backend_stmt_to_global_ids.end(); ++it) { + uint64_t global_stmt_id = it->second; + GloPgStmt->ref_count_server(global_stmt_id, -1); + } + } +} + + +PgSQL_STMT_Global_info *PgSQL_STMT_Manager_v14::find_prepared_statement_by_hash(uint64_t hash) { + PgSQL_STMT_Global_info *ret = NULL; // assume we do not find it + auto s = map_stmt_hash_to_info.find(hash); + if (s != map_stmt_hash_to_info.end()) { + ret = s->second; + } + return ret; +} + +PgSQL_STMT_Global_info* PgSQL_STMT_Manager_v14::find_prepared_statement_by_stmt_id( + uint64_t id, bool lock) { + PgSQL_STMT_Global_info*ret = NULL; // assume we do not find it + if (lock) { + pthread_rwlock_wrlock(&rwlock_); + } + + auto s = map_stmt_id_to_info.find(id); + if (s != map_stmt_id_to_info.end()) { + ret = s->second; + } + + if (lock) { + pthread_rwlock_unlock(&rwlock_); + } + return ret; +} + +uint32_t PgSQL_STMTs_local_v14::generate_new_backend_stmt_id() { + assert(is_client_ == false); + if (free_backend_ids.empty() == false) { + uint32_t backend_stmt_id = free_backend_ids.top(); + free_backend_ids.pop(); + return backend_stmt_id; + } + local_max_stmt_id++; + return local_max_stmt_id; +} + +uint64_t PgSQL_STMTs_local_v14::find_global_id_from_stmt_name(const std::string& client_stmt_name) { + uint64_t ret=0; + auto s = stmt_name_to_global_ids.find(client_stmt_name); + if (s != stmt_name_to_global_ids.end()) { + ret = s->second; + } + return ret; +} + +bool PgSQL_STMTs_local_v14::client_close(const std::string& stmt_name) { + auto s = stmt_name_to_global_ids.find(stmt_name); + if (s != stmt_name_to_global_ids.end()) { // found + uint64_t global_stmt_id = s->second; + stmt_name_to_global_ids.erase(s); + GloPgStmt->ref_count_client(global_stmt_id, -1); + std::pair::iterator, std::multimap::iterator> ret; + ret = global_id_to_stmt_names.equal_range(global_stmt_id); + for (std::multimap::iterator it=ret.first; it!=ret.second; ++it) { + if (it->second == stmt_name) { + global_id_to_stmt_names.erase(it); + break; + } + } + return true; + } + return false; // we don't really remove the prepared statement +} + +PgSQL_STMT_Global_info* PgSQL_STMT_Manager_v14::add_prepared_statement( + char *u, char *s, char *q, unsigned int ql, + char *fc, bool lock) { + PgSQL_STMT_Global_info *ret = NULL; + uint64_t hash = stmt_compute_hash( + u, s, q, ql); // this identifies the prepared statement + if (lock) { + pthread_rwlock_wrlock(&rwlock_); + } + // try to find the statement + auto f = map_stmt_hash_to_info.find(hash); + if (f != map_stmt_hash_to_info.end()) { + // found it! + ret = f->second; + ret->update_metadata(nullptr); + } else { + uint64_t next_id = 0; + if (!free_stmt_ids.empty()) { + next_id = free_stmt_ids.top(); + free_stmt_ids.pop(); + } else { + next_id = next_statement_id; + next_statement_id++; + } + + std::unique_ptr stmt_info (new PgSQL_STMT_Global_info(next_id, u, s, q, ql, fc, hash)); + // insert it in both maps + map_stmt_id_to_info.insert(std::make_pair(stmt_info->statement_id, stmt_info.get())); + map_stmt_hash_to_info.insert(std::make_pair(stmt_info->hash, stmt_info.get())); + ret = stmt_info.release(); + __sync_add_and_fetch(&num_stmt_with_ref_client_count_zero,1); + __sync_add_and_fetch(&num_stmt_with_ref_server_count_zero,1); + } + if (ret->ref_count_server == 0) { + __sync_sub_and_fetch(&num_stmt_with_ref_server_count_zero,1); + } + ret->ref_count_server++; + statuses.s_total++; + if (lock) { + pthread_rwlock_unlock(&rwlock_); + } + return ret; +} + + +void PgSQL_STMT_Manager_v14::get_memory_usage(uint64_t& prep_stmt_metadata_mem_usage, uint64_t& prep_stmt_backend_mem_usage) { + prep_stmt_backend_mem_usage = 0; + prep_stmt_metadata_mem_usage = sizeof(PgSQL_STMT_Manager_v14); + rdlock(); + prep_stmt_metadata_mem_usage += map_stmt_id_to_info.size() * (sizeof(uint64_t) + sizeof(PgSQL_STMT_Global_info*)); + prep_stmt_metadata_mem_usage += map_stmt_hash_to_info.size() * (sizeof(uint64_t) + sizeof(PgSQL_STMT_Global_info*)); + prep_stmt_metadata_mem_usage += free_stmt_ids.size() * (sizeof(uint64_t)); + for (const auto& keyval : map_stmt_id_to_info) { + const PgSQL_STMT_Global_info* stmt_global_info = keyval.second; + prep_stmt_metadata_mem_usage += stmt_global_info->total_mem_usage; + prep_stmt_metadata_mem_usage += stmt_global_info->ref_count_server;// * + //((stmt_global_info->num_params * sizeof(MYSQL_BIND)) + + //(stmt_global_info->num_columns * sizeof(MYSQL_FIELD))) + 16; // ~16 bytes of memory utilized by global_stmt_id and stmt_id mappings + prep_stmt_metadata_mem_usage += stmt_global_info->ref_count_client;// * + //((stmt_global_info->num_params * sizeof(MYSQL_BIND)) + + //(stmt_global_info->num_columns * sizeof(MYSQL_FIELD))) + 16; // ~16 bytes of memory utilized by global_stmt_id and stmt_id mappings + + // backend + prep_stmt_backend_mem_usage += stmt_global_info->ref_count_server;// *(sizeof(MYSQL_STMT) + + //56// + //sizeof(MADB_STMT_EXTENSION) + //(stmt_global_info->num_params * sizeof(MYSQL_BIND)) + + //(stmt_global_info->num_columns * sizeof(MYSQL_FIELD))); + } + unlock(); +} + +void PgSQL_STMT_Manager_v14::get_metrics(uint64_t *c_unique, uint64_t *c_total, + uint64_t *stmt_max_stmt_id, uint64_t *cached, + uint64_t *s_unique, uint64_t *s_total) { +#ifdef DEBUG + uint64_t c_u = 0; + uint64_t c_t = 0; + uint64_t m = 0; + uint64_t c = 0; + uint64_t s_u = 0; + uint64_t s_t = 0; +#endif + pthread_rwlock_wrlock(&rwlock_); + statuses.cached = map_stmt_id_to_info.size(); + statuses.c_unique = statuses.cached - num_stmt_with_ref_client_count_zero; + statuses.s_unique = statuses.cached - num_stmt_with_ref_server_count_zero; +#ifdef DEBUG + for (std::map::iterator it = map_stmt_id_to_info.begin(); + it != map_stmt_id_to_info.end(); ++it) { + PgSQL_STMT_Global_info *a = it->second; + c++; + if (a->ref_count_client) { + c_u++; + c_t += a->ref_count_client; + } + if (a->ref_count_server) { + s_u++; + s_t += a->ref_count_server; + } + if (it->first > m) { + m = it->first; + } + } + assert (c_u == statuses.c_unique); + assert (c_t == statuses.c_total); + assert (c == statuses.cached); + assert (s_t == statuses.s_total); + assert (s_u == statuses.s_unique); + *stmt_max_stmt_id = m; +#endif + *stmt_max_stmt_id = next_statement_id; // this is max stmt_id, no matter if in used or not + *c_unique = statuses.c_unique; + *c_total = statuses.c_total; + *cached = statuses.cached; + *s_total = statuses.s_total; + *s_unique = statuses.s_unique; + pthread_rwlock_unlock(&rwlock_); +} + + +class PS_global_stats { + public: + uint64_t statement_id; + char *username; + char *schemaname; + uint64_t digest; + unsigned long long ref_count_client; + unsigned long long ref_count_server; + char *query; + uint64_t num_columns; + uint64_t num_params; + PS_global_stats(uint64_t stmt_id, char *s, char *u, uint64_t d, char *q, unsigned long long ref_c, unsigned long long ref_s, uint64_t columns, uint64_t params) { + statement_id = stmt_id; + digest=d; + query=strndup(q, mysql_thread___query_digests_max_digest_length); + username=strdup(u); + schemaname=strdup(s); + ref_count_client = ref_c; + ref_count_server = ref_s; + num_columns = columns; + num_params = params; + } + ~PS_global_stats() { + if (query) { + free(query); + query=NULL; + } + if (username) { + free(username); + username=NULL; + } + if (schemaname) { + free(schemaname); + schemaname=NULL; + } + } + char **get_row() { + char buf[128]; + char **pta=(char **)malloc(sizeof(char *)*PS_GLOBAL_STATUS_FIELD_NUM); + sprintf(buf,"%lu",statement_id); + pta[0]=strdup(buf); + assert(schemaname); + pta[1]=strdup(schemaname); + assert(username); + pta[2]=strdup(username); + + sprintf(buf,"0x%016llX", (long long unsigned int)digest); + pta[3]=strdup(buf); + + assert(query); + pta[4]=strdup(query); + sprintf(buf,"%llu",ref_count_client); + pta[5]=strdup(buf); + sprintf(buf,"%llu",ref_count_server); + pta[6]=strdup(buf); + sprintf(buf,"%lu",num_columns); + pta[7]=strdup(buf); + sprintf(buf,"%lu",num_params); + pta[8]=strdup(buf); + + return pta; + } + void free_row(char **pta) { + int i; + for (i=0;iadd_column_definition(SQLITE_TEXT,"stmt_id"); + result->add_column_definition(SQLITE_TEXT,"schemaname"); + result->add_column_definition(SQLITE_TEXT,"username"); + result->add_column_definition(SQLITE_TEXT,"digest"); + result->add_column_definition(SQLITE_TEXT,"query"); + result->add_column_definition(SQLITE_TEXT,"ref_count_client"); + result->add_column_definition(SQLITE_TEXT,"ref_count_server"); + result->add_column_definition(SQLITE_TEXT,"num_columns"); + result->add_column_definition(SQLITE_TEXT,"num_params"); + for (std::map::iterator it = map_stmt_id_to_info.begin(); + it != map_stmt_id_to_info.end(); ++it) { + PgSQL_STMT_Global_info *a = it->second; + PS_global_stats * pgs = new PS_global_stats(a->statement_id, + a->schemaname, a->username, + a->hash, a->query, + a->ref_count_client, a->ref_count_server, 0, 0); + char **pta = pgs->get_row(); + result->add_row(pta); + pgs->free_row(pta); + delete pgs; + } + unlock(); + return result; +} diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 4b606f267..9fa451195 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -487,6 +487,11 @@ static inline bool get_uint16be(unsigned char* pkt, uint16_t* dst_p) return true; } +static inline bool get_int16be(unsigned char* pkt, int16_t* dst_p) +{ + return get_uint16be(pkt, (uint16_t*)dst_p); +} + bool PgSQL_Protocol::get_header(unsigned char* pkt, unsigned int pkt_len, pgsql_hdr* hdr) { unsigned int type; uint32_t len; @@ -1567,6 +1572,50 @@ bool PgSQL_Protocol::generate_ok_packet(bool send, bool ready, const char* msg, return true; } +bool PgSQL_Protocol::generate_ready_for_query_packet(bool send, char trx_state, PtrSize_t* _ptr) { + // to avoid memory leak + assert(send == true || _ptr); + + PG_pkt pgpkt{}; + pgpkt.write_ReadyForQuery(trx_state); + auto buff = pgpkt.detach(); + if (send == true) { + (*myds)->PSarrayOUT->add((void*)buff.first, buff.second); + } else { + _ptr->ptr = buff.first; + _ptr->size = buff.second; + } + return true; +} + +bool PgSQL_Protocol::generate_parse_completion_packet(bool send, bool ready, char trx_state, PtrSize_t* _ptr) { + // to avoid memory leak + assert(send == true || _ptr); + + PG_pkt pgpkt{}; + + if (ready == true) { + pgpkt.set_multi_pkt_mode(true); + } + + // Parse completion message + pgpkt.write_ParseCompletion(); + + if (ready == true) { + pgpkt.write_ReadyForQuery(trx_state); + pgpkt.set_multi_pkt_mode(false); + } + + auto buff = pgpkt.detach(); + if (send == true) { + (*myds)->PSarrayOUT->add((void*)buff.first, buff.second); + } else { + _ptr->ptr = buff.first; + _ptr->size = buff.second; + } + return true; +} + //bool PgSQL_Protocol::generate_row_description(bool send, PgSQL_Query_Result* rs, const PG_Fields& fields, unsigned int size) { // if ((*myds)->sess->mirror == true) { // return true; @@ -2422,3 +2471,121 @@ void PgSQL_Query_Result::clear() { buffer_init(); reset(); } + + +PgSQL_Parse_Message::PgSQL_Parse_Message() { + +} + +PgSQL_Parse_Message::~PgSQL_Parse_Message() { + if (_pkt.ptr) { + free(_pkt.ptr); + _pkt.ptr = nullptr; + _pkt.size = 0; + } +} + +bool PgSQL_Parse_Message::parse(PtrSize_t& pkt) { + + if (pkt.ptr == nullptr || pkt.size == 0) { + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 1, "No packet to parse\n"); + return false; + } + + if (pkt.size < 5) { + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 1, "Packet too short for parsing: %u bytes\n", pkt.size); + return false; + } + + unsigned char* packet = (unsigned char*)pkt.ptr; + uint32_t pkt_len = pkt.size; + uint32_t payload_len = 0; + uint32_t offset = 0; + + if (packet[offset++] != 'P') { // 'P' is the packet type for Parse + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 1, "Invalid packet type: expected 'P'\n"); + return false; + } + + // Read the length of the packet (4 bytes, big-endian) + if (!get_uint32be(packet + offset, &payload_len)) { + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 1, "Failed to read packet size\n"); + return false; + } + offset += sizeof(uint32_t); + + // Check if the reported packet length matches the provided length + if (payload_len != pkt_len - 1) { + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 1, "Packet size too small: %u bytes\n", pkt.size); + return false; + } + + // Validate remaining length for statement name (at least 1 byte for null-terminated string) + if (offset >= pkt_len) { + return false; // Not enough data for statement name + } + + // Read the statement name (null-terminated string) + stmt_name = reinterpret_cast(packet + offset); + size_t stmt_name_len = strnlen(stmt_name, pkt_len - offset); + + // Ensure there is a null-terminator within the packet length + if (offset + stmt_name_len >= pkt_len) { + return false; // No null-terminator found within the packet bounds + } + + offset += stmt_name_len + 1; // Move past the null-terminated statement name + + // Validate remaining length for query string (at least 1 byte for null-terminated string) + if (offset >= pkt_len) { + return false; // Not enough data for query string + } + + // Read the query string (null-terminated string) + query_string = reinterpret_cast(packet + offset); + size_t query_string_len = strnlen(query_string, pkt_len - offset); + + // Ensure there is a null-terminator within the packet length + if (offset + query_string_len >= pkt_len) { + return false; // No null-terminator found within the packet bounds + } + + offset += query_string_len + 1; // Move past the null-terminated query string + + // Validate remaining length for number of parameter types (2 bytes) + if (offset + sizeof(int16_t) > pkt_len) { + return false; // Not enough data for numParameterTypes + } + + // Read the length of the parameter types (2 bytes, big-endian) + if (!get_uint16be(packet + offset, &num_param_types)) { + proxy_debug(PROXY_DEBUG_MYSQL_CONNECTION, 1, "Failed to read packet size\n"); + return false; + } + offset += sizeof(int16_t); + + // If there are parameter types, ensure there's enough data for all of them + if (num_param_types > 0) { + if (offset + num_param_types * sizeof(uint32_t) > pkt_len) { + return false; // Not enough data for all parameter types + } + + // Read the parameter types array (each is 4 bytes, big-endian) + param_types = reinterpret_cast(packet + offset); + + // Move past the parameter types + offset += num_param_types * sizeof(uint32_t); + } + + // take "ownership" + _pkt = pkt; + + // If we reach here, the packet is valid and fully parsed + return true; +} + +PtrSize_t PgSQL_Parse_Message::detach() { + PtrSize_t result = _pkt; + memset(this, 0, sizeof(PgSQL_Parse_Message)); + return result; +} diff --git a/lib/PgSQL_Query_Processor.cpp b/lib/PgSQL_Query_Processor.cpp index 68e373987..246bca18f 100644 --- a/lib/PgSQL_Query_Processor.cpp +++ b/lib/PgSQL_Query_Processor.cpp @@ -283,18 +283,18 @@ PgSQL_Query_Processor_Output* PgSQL_Query_Processor::process_query(PgSQL_Session } #define stackbuffer_size 128 char stackbuffer[stackbuffer_size]; - unsigned int len = 0; + unsigned int len = size; char* query = NULL; // NOTE: if ptr == NULL , we are calling process_mysql_query() on an STMT_EXECUTE if (ptr) { - len = size - sizeof(mysql_hdr) - 1; + //len = size - sizeof(mysql_hdr) - 1; if (len < stackbuffer_size) { query = stackbuffer; } else { - query = (char*)l_alloc(len + 1); + query = (char*)l_alloc(len); } - memcpy(query, (char*)ptr + sizeof(mysql_hdr) + 1, len); - query[len] = 0; + memcpy(query, ptr, len); + query[len-1] = 0; } else { //query = qi->stmt_info->query; @@ -308,7 +308,7 @@ PgSQL_Query_Processor_Output* PgSQL_Query_Processor::process_query(PgSQL_Session // query is in the stack } else { if (ptr) { - l_free(len + 1, query); + l_free(len, query); } } diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index bc6ed8def..0eed13406 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -1,7 +1,7 @@ #include "../deps/json/json.hpp" using json = nlohmann::json; #define PROXYJSON - +#include #include "PgSQL_HostGroups_Manager.h" #include "PgSQL_Thread.h" #include "proxysql.h" @@ -14,7 +14,7 @@ using json = nlohmann::json; #include "PgSQL_Data_Stream.h" #include "MySQL_Data_Stream.h" #include "PgSQL_Query_Processor.h" -#include "MySQL_PreparedStatement.h" +#include "PgSQL_PreparedStatement.h" #include "PgSQL_Logger.hpp" #include "StatCounters.h" #include "PgSQL_Authentication.h" @@ -31,20 +31,6 @@ using json = nlohmann::json; #define SELECT_VERSION_COMMENT "select @@version_comment limit 1" #define SELECT_VERSION_COMMENT_LEN 32 -//#define SELECT_DB_USER "select DATABASE(), USER() limit 1" -#define SELECT_DB_USER_LEN 33 -//#define SELECT_CHARSET_STATUS "select @@character_set_client, @@character_set_connection, @@character_set_server, @@character_set_database limit 1" -#define SELECT_CHARSET_STATUS_LEN 115 -#define PROXYSQL_VERSION_COMMENT "\x01\x00\x00\x01\x01\x27\x00\x00\x02\x03\x64\x65\x66\x00\x00\x00\x11\x40\x40\x76\x65\x72\x73\x69\x6f\x6e\x5f\x63\x6f\x6d\x6d\x65\x6e\x74\x00\x0c\x21\x00\x18\x00\x00\x00\xfd\x00\x00\x1f\x00\x00\x05\x00\x00\x03\xfe\x00\x00\x02\x00\x0b\x00\x00\x04\x0a(ProxySQL)\x05\x00\x00\x05\xfe\x00\x00\x02\x00" -#define PROXYSQL_VERSION_COMMENT_LEN 81 - -// PROXYSQL_VERSION_COMMENT_WITH_OK is sent instead of PROXYSQL_VERSION_COMMENT -// if Client supports CLIENT_DEPRECATE_EOF -#define PROXYSQL_VERSION_COMMENT_WITH_OK "\x01\x00\x00\x01\x01" \ -"\x27\x00\x00\x02\x03\x64\x65\x66\x00\x00\x00\x11\x40\x40\x76\x65\x72\x73\x69\x6f\x6e\x5f\x63\x6f\x6d\x6d\x65\x6e\x74\x00\x0c\x21\x00\x18\x00\x00\x00\xfd\x00\x00\x1f\x00\x00" \ -"\x0b\x00\x00\x03\x0a(ProxySQL)" \ -"\x07\x00\x00\x04\xfe\x00\x00\x02\x00\x00\x00" -#define PROXYSQL_VERSION_COMMENT_WITH_OK_LEN 74 #define SELECT_CONNECTION_ID "SELECT CONNECTION_ID()" #define SELECT_CONNECTION_ID_LEN 22 @@ -59,6 +45,8 @@ using json = nlohmann::json; #define EXPMARIA +const char* PROXYSQL_PS_PREFIX = "proxysql_ps_"; + using std::function; using std::vector; @@ -113,7 +101,7 @@ extern PgSQL_Authentication* GloPgAuth; extern MySQL_LDAP_Authentication* GloMyLdapAuth; extern ProxySQL_Admin* GloAdmin; extern PgSQL_Logger* GloPgSQL_Logger; -extern MySQL_STMT_Manager_v14* GloMyStmt; +extern PgSQL_STMT_Manager_v14* GloPgStmt; extern SQLite3_Server* GloSQLite3Server; @@ -313,7 +301,6 @@ PgSQL_Query_Info::PgSQL_Query_Info() { rows_sent=0; start_time=0; end_time=0; - stmt_client_id=0; } PgSQL_Query_Info::~PgSQL_Query_Info() { @@ -327,8 +314,6 @@ void PgSQL_Query_Info::begin(unsigned char *_p, int len, bool header) { PgQueryCmd=PGSQL_QUERY___NONE; QueryPointer=NULL; QueryLength=0; - mysql_stmt=NULL; - stmt_meta=NULL; QueryParserArgs.digest_text=NULL; QueryParserArgs.first_comment=NULL; start_time=sess->thread->curtime; @@ -344,7 +329,6 @@ void PgSQL_Query_Info::begin(unsigned char *_p, int len, bool header) { waiting_since = 0; affected_rows=0; rows_sent=0; - stmt_client_id=0; } void PgSQL_Query_Info::end() { @@ -353,20 +337,9 @@ void PgSQL_Query_Info::end() { if ((end_time-start_time) > (unsigned int)pgsql_thread___long_query_time *1000) { __sync_add_and_fetch(&sess->thread->status_variables.stvar[st_var_queries_slow],1); } - assert(mysql_stmt==NULL); if (stmt_info) { stmt_info=NULL; } - if (stmt_meta) { // fix bug #796: memory is not freed in case of error during STMT_EXECUTE - if (stmt_meta->pkt) { - uint32_t stmt_global_id=0; - memcpy(&stmt_global_id,(char *)(stmt_meta->pkt)+5,sizeof(uint32_t)); - sess->SLDH->reset(stmt_global_id); - free(stmt_meta->pkt); - stmt_meta->pkt=NULL; - } - stmt_meta = NULL; - } } void PgSQL_Query_Info::init(unsigned char *_p, int len, bool header) { @@ -396,7 +369,7 @@ void PgSQL_Query_Info::query_parser_free() { unsigned long long PgSQL_Query_Info::query_parser_update_counters() { if (stmt_info) { - //PgQueryCmd=stmt_info->MyComQueryCmd; + PgQueryCmd=stmt_info->PgQueryCmd; } if (PgQueryCmd==PGSQL_QUERY___NONE) return 0; // this means that it was never initialized if (PgQueryCmd==PGSQL_QUERY__UNINITIALIZED) return 0; // this means that it was never initialized @@ -550,10 +523,7 @@ PgSQL_Session::PgSQL_Session() { transaction_started_at = 0; CurrentQuery.sess = this; - CurrentQuery.mysql_stmt = NULL; - CurrentQuery.stmt_meta = NULL; CurrentQuery.stmt_global_id = 0; - CurrentQuery.stmt_client_id = 0; CurrentQuery.stmt_info = NULL; current_hostgroup = -1; @@ -590,14 +560,6 @@ void PgSQL_Session::reset() { default_hostgroup = -1; locked_on_hostgroup = -1; locked_on_hostgroup_and_all_variables_set = false; - if (sess_STMTs_meta) { - delete sess_STMTs_meta; - sess_STMTs_meta = NULL; - } - if (SLDH) { - delete SLDH; - SLDH = NULL; - } if (mybes) { reset_all_backends(); delete mybes; @@ -2368,7 +2330,10 @@ __get_pkts_from_client: if (thread->variables.stats_time_query_processor) { clock_gettime(CLOCK_THREAD_CPUTIME_ID, &begint); } - qpo = GloPgQPro->process_query(this, pkt.ptr, pkt.size, &CurrentQuery); + unsigned int query_len = pkt.size - 5; // excluding header + char* query_ptr = (char*)pkt.ptr + 5; + + qpo = GloPgQPro->process_query(this, query_ptr, query_len, &CurrentQuery); if (thread->variables.stats_time_query_processor) { clock_gettime(CLOCK_THREAD_CPUTIME_ID, &endt); thread->status_variables.stvar[st_var_query_processor_time] = thread->status_variables.stvar[st_var_query_processor_time] + @@ -2510,7 +2475,7 @@ __get_pkts_from_client: proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Received query to be processed...\n"); mybe->server_myds->killed_at = 0; mybe->server_myds->kill_type = 0; - mybe->server_myds->mysql_real_query.init(&pkt); + mybe->server_myds->pgsql_real_query.init(&pkt); mybe->server_myds->statuses.questions++; client_myds->setDSS_STATE_QUERY_SENT_NET(); } @@ -2524,6 +2489,12 @@ __get_pkts_from_client: return handler_ret; break; case 'P': + if (handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_PARSE(pkt) == false) { + handler_ret = -1; + return handler_ret; + } + break; + case 'B': case 'D': case 'E': @@ -2531,6 +2502,29 @@ __get_pkts_from_client: l_free(pkt.size, pkt.ptr); continue; case 'S': + { + __run_sync_again: + int rc = handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_SYNC(pkt); + + if (rc == -1) { + handler_ret = -1; + return handler_ret; + } + + if (rc == 0) { + if (pending_packets.empty() == false) { + writeout(); + goto __run_sync_again; + } else { + // we do not need this packet anymore + l_free(pkt.size, pkt.ptr); + pkt.ptr = NULL; + pkt.size = 0; + } + } + + } + break; default: proxy_error("Not implemented yet. Message type:'%c'\n", c); client_myds->setDSS_STATE_QUERY_SENT_NET(); @@ -2594,7 +2588,11 @@ __get_pkts_from_client: if (thread->variables.stats_time_query_processor) { clock_gettime(CLOCK_THREAD_CPUTIME_ID, &begint); } - qpo = GloPgQPro->process_query(this, pkt.ptr, pkt.size, &CurrentQuery); + + unsigned int query_len = pkt.size - 4 - 1; // excluding header + char* query_ptr = (char*)pkt.ptr + 4 + 1; + + qpo = GloPgQPro->process_query(this, query_ptr, query_len, &CurrentQuery); if (thread->variables.stats_time_query_processor) { clock_gettime(CLOCK_THREAD_CPUTIME_ID, &endt); thread->status_variables.stvar[st_var_query_processor_time] = thread->status_variables.stvar[st_var_query_processor_time] + @@ -2736,7 +2734,7 @@ __get_pkts_from_client: proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Received query to be processed with MariaDB Client library\n"); mybe->server_myds->killed_at = 0; mybe->server_myds->kill_type = 0; - mybe->server_myds->mysql_real_query.init(&pkt); + mybe->server_myds->pgsql_real_query.init(&pkt); mybe->server_myds->statuses.questions++; client_myds->setDSS_STATE_QUERY_SENT_NET(); } @@ -2744,13 +2742,6 @@ __get_pkts_from_client: handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___not_mysql(pkt); } break; - /*case _MYSQL_COM_STMT_PREPARE: - handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_STMT_PREPARE(pkt); - break; - case _MYSQL_COM_STMT_EXECUTE: - handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_STMT_EXECUTE(pkt); - break; - */ default: // in this switch we only handle the most common commands. // The not common commands are handled by "default" , that @@ -2980,15 +2971,28 @@ bool PgSQL_Session::handler_minus1_HandleErrorCodes(PgSQL_Data_Stream* myds, int // this function used to be inline. void PgSQL_Session::handler_minus1_GenerateErrorMessage(PgSQL_Data_Stream* myds, bool& wrong_pass) { PgSQL_Connection* myconn = myds->myconn; + + if (myconn == NULL) { + client_myds->myprot.generate_error_packet(true, true, "Lost connection to PostgreSQL server during query", + PGSQL_ERROR_CODES::ERRCODE_CONNECTION_FAILURE, false); + return; + } + switch (status) { case PROCESSING_QUERY: - if (myconn) { - PgSQL_Result_to_PgSQL_wire(myconn, myds); - } - else { - PgSQL_Result_to_PgSQL_wire(NULL, myds); + PgSQL_Result_to_PgSQL_wire(myconn, myds); + break; + case PROCESSING_STMT_PREPARE: + client_myds->myprot.generate_error_packet(true, true, myconn->get_error_message().c_str(), myconn->get_error_code(), false); + if (previous_status.size()) { + // an STMT_PREPARE failed + // we have a previous status, probably STMT_EXECUTE, + // but returning to that status is not safe after STMT_PREPARE failed + // for this reason we exit immediately + wrong_pass = true; } break; + case PROCESSING_STMT_EXECUTE: default: // LCOV_EXCL_START assert(0); @@ -3022,16 +3026,23 @@ int PgSQL_Session::RunQuery(PgSQL_Data_Stream* myds, PgSQL_Connection* myconn) { int rc = 0; switch (status) { case PROCESSING_QUERY: - rc = myconn->async_query(myds->revents, myds->mysql_real_query.QueryPtr, myds->mysql_real_query.QuerySize); + rc = myconn->async_query(myds->revents, myds->pgsql_real_query.QueryPtr, myds->pgsql_real_query.QuerySize); break; - /*case PROCESSING_STMT_PREPARE: - rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, &CurrentQuery.mysql_stmt); + case PROCESSING_STMT_PREPARE: + { + uint32_t backend_stmt_id = myconn->local_stmts->generate_new_backend_stmt_id(); + CurrentQuery.stmt_backend_id = backend_stmt_id; // this is used to generate the name of the prepared statement in the backend + const std::string& backend_stmt_name = std::string(PROXYSQL_PS_PREFIX) + std::to_string(backend_stmt_id); + rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, backend_stmt_name.c_str()); + } break; case PROCESSING_STMT_EXECUTE: - PROXY_TRACE2(); - rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, &CurrentQuery.mysql_stmt, CurrentQuery.stmt_meta); + assert(CurrentQuery.stmt_backend_id); + { + const std::string& backend_stmt_name = std::string(PROXYSQL_PS_PREFIX) + std::to_string(CurrentQuery.stmt_backend_id); + rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, backend_stmt_name.c_str()); + } break; - */ default: // LCOV_EXCL_START assert(0); @@ -3154,8 +3165,8 @@ handler_again: } break; - //case PROCESSING_STMT_PREPARE: - //case PROCESSING_STMT_EXECUTE: + case PROCESSING_STMT_PREPARE: + case PROCESSING_STMT_EXECUTE: case PROCESSING_QUERY: //fprintf(stderr,"PROCESSING_QUERY\n"); if (pause_until > thread->curtime) { @@ -3332,26 +3343,21 @@ handler_again: //autocommit = myconn->pgsql->server_status & SERVER_STATUS_AUTOCOMMIT; } - /*if (mirror == false && myconn->pgsql) { - // Support for LAST_INSERT_ID() - if (myconn->pgsql->insert_id) { - last_insert_id = myconn->pgsql->insert_id; - } - if (myconn->pgsql->affected_rows) { - if (myconn->pgsql->affected_rows != ULLONG_MAX) { - last_HG_affected_rows = current_hostgroup; - if (pgsql_thread___auto_increment_delay_multiplex && myconn->pgsql->insert_id) { - myconn->auto_increment_delay_token = pgsql_thread___auto_increment_delay_multiplex + 1; - __sync_fetch_and_add(&PgHGM->status.auto_increment_delay_multiplex, 1); - } - } - } - }*/ - switch (status) { case PROCESSING_QUERY: PgSQL_Result_to_PgSQL_wire(myconn, myconn->myds); break; + case PROCESSING_STMT_PREPARE: + { + enum session_status st; + if (handler___rc0_PROCESSING_STMT_PREPARE(st, myds, prepared_stmt_with_no_params)) { + NEXT_IMMEDIATE(st); + } + } + break; + case PROCESSING_STMT_EXECUTE: + //handler_rc0_PROCESSING_STMT_EXECUTE(myds); + break; default: // LCOV_EXCL_START assert(0); @@ -3434,6 +3440,23 @@ handler_again: } } } + + /* + // FIXME: Temporary workaround. Update the logic below when pipeline mode is implemented + if (rc != 1 && pkt.size && pkt.ptr && ((char*)pkt.ptr)[0] == 'S') { // it's a sync packet + // sent sync packet again to client queue, to execute sync in next iteration to handle remaining pending packets + if (pending_packets.empty() == false) { + writeout(); + handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_SYNC(pkt); + goto handler_again; + } else { + // we do not need this packet anymore + l_free(pkt.size, pkt.ptr); + pkt.ptr = NULL; + pkt.size = 0; + } + } + */ goto __exit_DSS__STATE_NOT_INITIALIZED; } break; @@ -4057,8 +4080,8 @@ void PgSQL_Session::handler_WCD_SS_MCQ_qpo_OK_msg(PtrSize_t* pkt) { client_myds->DSS = STATE_QUERY_SENT_NET; unsigned int nTrx = NumActiveTransactions(); - const char trx_state = (nTrx ? 'T' : 'I'); - client_myds->myprot.generate_ok_packet(true, true, qpo->OK_msg, 0, (const char*)pkt->ptr + 5, trx_state); + const char txn_state = (nTrx ? 'T' : 'I'); + client_myds->myprot.generate_ok_packet(true, true, qpo->OK_msg, 0, (const char*)pkt->ptr + 5, txn_state); RequestEnd(NULL); l_free(pkt->size, pkt->ptr); } @@ -4255,8 +4278,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C if (value1.empty()) { client_myds->DSS = STATE_QUERY_SENT_NET; unsigned int nTrx = NumActiveTransactions(); - const char trx_state = (nTrx ? 'T' : 'I'); - client_myds->myprot.generate_ok_packet(true, true, NULL, 0, dig, trx_state, NULL, param_status); + const char txn_state = (nTrx ? 'T' : 'I'); + client_myds->myprot.generate_ok_packet(true, true, NULL, 0, dig, txn_state, NULL, param_status); RequestEnd(NULL); l_free(pkt->size, pkt->ptr); return true; @@ -4473,8 +4496,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C client_myds->DSS = STATE_QUERY_SENT_NET; unsigned int nTrx = NumActiveTransactions(); - const char trx_state = (nTrx ? 'T' : 'I'); - client_myds->myprot.generate_ok_packet(true, true, NULL, 0, dig, trx_state, NULL, param_status); + const char txn_state = (nTrx ? 'T' : 'I'); + client_myds->myprot.generate_ok_packet(true, true, NULL, 0, dig, txn_state, NULL, param_status); RequestEnd(NULL); l_free(pkt->size, pkt->ptr); return true; @@ -4712,8 +4735,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C } client_myds->DSS = STATE_QUERY_SENT_NET; unsigned int nTrx = NumActiveTransactions(); - const char trx_state = (nTrx ? 'T' : 'I'); - client_myds->myprot.generate_ok_packet(true, true, NULL, 0, dig, trx_state, NULL, param_status); + const char txn_state = (nTrx ? 'T' : 'I'); + client_myds->myprot.generate_ok_packet(true, true, NULL, 0, dig, txn_state, NULL, param_status); if (mirror == false) { RequestEnd(NULL); @@ -5497,7 +5520,7 @@ void PgSQL_Session::RequestEnd(PgSQL_Data_Stream* myds, const unsigned int myerr myds->myconn->async_free_result(); myds->myconn->compute_unknown_transaction_status(); } - myds->free_mysql_real_query(); + myds->free_pgsql_real_query(); } if (session_fast_forward == SESSION_FORWARD_TYPE_NONE) { // reset status of the session @@ -5862,7 +5885,7 @@ void PgSQL_Session::set_previous_status_mode3(bool allow_execute) { case PROCESSING_QUERY: previous_status.push(PROCESSING_QUERY); break; - /*case PROCESSING_STMT_PREPARE: + case PROCESSING_STMT_PREPARE: previous_status.push(PROCESSING_STMT_PREPARE); break; case PROCESSING_STMT_EXECUTE: @@ -5870,7 +5893,7 @@ void PgSQL_Session::set_previous_status_mode3(bool allow_execute) { previous_status.push(PROCESSING_STMT_EXECUTE); break; } - */ + default: // LCOV_EXCL_START assert(0); // Assert to indicate an unexpected status value @@ -5935,7 +5958,7 @@ void PgSQL_Session::switch_normal_to_fast_forward_mode(PtrSize_t& pkt, std::stri // as we are in FAST_FORWARD mode, we directly send the packet to the backend. // need to reset mysql_real_query - mybe->server_myds->mysql_real_query.reset(); + mybe->server_myds->pgsql_real_query.reset(); } void PgSQL_Session::switch_fast_forward_to_normal_mode() { @@ -5999,6 +6022,237 @@ void PgSQL_Session::reset_default_session_variable(enum pgsql_variable_name idx) } } +int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg) { + + thread->status_variables.stvar[st_var_frontend_stmt_prepare]++; + thread->status_variables.stvar[st_var_queries]++; + + bool lock_hostgroup = false; + bool rc_break = false; + + CurrentQuery.begin((unsigned char*)parse_msg->query_string, strlen(parse_msg->query_string) + 1, false); + CurrentQuery.stmt_client_name = (char*)parse_msg->stmt_name; + + timespec begint; + timespec endt; + if (thread->variables.stats_time_query_processor) { + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &begint); + } + qpo = GloPgQPro->process_query(this, (unsigned char*)parse_msg->query_string, strlen(parse_msg->query_string) + 1, &CurrentQuery); + if (thread->variables.stats_time_query_processor) { + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &endt); + thread->status_variables.stvar[st_var_query_processor_time] = thread->status_variables.stvar[st_var_query_processor_time] + + (endt.tv_sec * 1000000000 + endt.tv_nsec) - + (begint.tv_sec * 1000000000 + begint.tv_nsec); + } + assert(qpo); // GloPgQPro->process_mysql_query() should always return a qpo + // setting 'prepared' to prevent fetching results from the cache if the digest matches + rc_break = handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo(&pkt, &lock_hostgroup, PgSQL_ps_type_prepare_stmt); + if (rc_break == true) + return 0; + + if (pgsql_thread___set_query_lock_on_hostgroup == 1) { + if (locked_on_hostgroup < 0) { + if (lock_hostgroup) { + // we are locking on hostgroup now + locked_on_hostgroup = current_hostgroup; + } + } + if (locked_on_hostgroup >= 0) { + if (current_hostgroup != locked_on_hostgroup) { + client_myds->DSS = STATE_QUERY_SENT_NET; + int l = CurrentQuery.QueryLength; + char* end = (char*)""; + if (l > 256) { + l = 253; + end = (char*)"..."; + } + string nqn = string((char*)CurrentQuery.QueryPointer, l); + char* err_msg = (char*)"Session trying to reach HG %d while locked on HG %d . Rejecting query: %s"; + char* buf = (char*)malloc(strlen(err_msg) + strlen(nqn.c_str()) + strlen(end) + 64); + sprintf(buf, err_msg, current_hostgroup, locked_on_hostgroup, nqn.c_str(), end); + client_myds->myprot.generate_error_packet(true, true, buf, PGSQL_ERROR_CODES::ERRCODE_RAISE_EXCEPTION, + false, true); + thread->status_variables.stvar[st_var_hostgroup_locked_queries]++; + RequestEnd(NULL); + free(buf); + return 0; + } + } + } + mybe = find_or_create_backend(current_hostgroup); + + PgSQL_STMTs_local_v14* local_stmts = client_myds->myconn->local_stmts; + std::string stmt_name = (char*)CurrentQuery.stmt_client_name; // create a string + + // if the same statement name is used, we drop it + //FIXME: Revisit this logic + if (auto search = local_stmts->stmt_name_to_global_ids.find(stmt_name); + search != local_stmts->stmt_name_to_global_ids.end()) { + uint64_t client_global_id = search->second; + auto range = local_stmts->global_id_to_stmt_names.equal_range(client_global_id); + assert(range.first != range.second); + for (auto it = range.first; it != range.second; ++it) { + if (it->second == stmt_name) { + local_stmts->global_id_to_stmt_names.erase(it); + break; + } + } + local_stmts->stmt_name_to_global_ids.erase(search); + client_myds->myconn->local_stmts->client_close(stmt_name); + } + uint64_t hash = client_myds->myconn->local_stmts->compute_hash( + (char*)client_myds->myconn->userinfo->username, + (char*)client_myds->myconn->userinfo->dbname, + (char*)CurrentQuery.QueryPointer, + CurrentQuery.QueryLength + ); + PgSQL_STMT_Global_info* stmt_info = NULL; + // we first lock GloStmt + GloPgStmt->wrlock(); + stmt_info = GloPgStmt->find_prepared_statement_by_hash(hash); + if (stmt_info) { + local_stmts->client_insert(stmt_info->statement_id, stmt_name); + CurrentQuery.stmt_global_id = stmt_info->statement_id; + client_myds->setDSS_STATE_QUERY_SENT_NET(); + bool send_ready_packet = pending_packets.empty(); + unsigned int nTxn = NumActiveTransactions(); + const char txn_state = (nTxn ? 'T' : 'I'); + client_myds->myprot.generate_parse_completion_packet(true, send_ready_packet, txn_state); + LogQuery(NULL); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + CurrentQuery.end_time = thread->curtime; + CurrentQuery.end(); + GloPgStmt->unlock(); + return 0; + } else { + mybe = find_or_create_backend(current_hostgroup); + status = PROCESSING_STMT_PREPARE; + mybe->server_myds->connect_retries_on_failure = pgsql_thread___connect_retries_on_failure; + mybe->server_myds->wait_until = 0; + pause_until = 0; + mybe->server_myds->killed_at = 0; + mybe->server_myds->kill_type = 0; + auto parse_pkt = parse_msg->detach(); // detach the packet from the parse message + mybe->server_myds->pgsql_real_query.init(&parse_pkt); // mem leak fix + mybe->server_myds->statuses.questions++; + client_myds->setDSS_STATE_QUERY_SENT_NET(); + } + GloPgStmt->unlock(); + return 1; +} + +int PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_SYNC(PtrSize_t& pkt) { + if (session_type != PROXYSQL_SESSION_PGSQL) { // only PgSQL module supports prepared statement!! + client_myds->setDSS_STATE_QUERY_SENT_NET(); + client_myds->myprot.generate_error_packet(true, false, "Prepared statements not supported", PGSQL_ERROR_CODES::ERRCODE_FEATURE_NOT_SUPPORTED, + false, true); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return 0; + } + + if (pending_packets.empty()) { + unsigned int nTxn = NumActiveTransactions(); + const char txn_state = (nTxn ? 'T' : 'I'); + client_myds->myprot.generate_ready_for_query_packet(true, txn_state); + return 0; + } + + // we have pending packets, so we will process them now + auto packet = std::move(pending_packets.front()); // get the packet from the queue + pending_packets.pop(); // remove the packet from the queue + + const std::unique_ptr* parse_msg = std::get_if>(&packet); + + int rc = -1; + if (parse_msg && parse_msg->get()) { + rc = handle_post_sync_parse_message(parse_msg->get()); + } + + return rc; // make sure to not return before unlocking GloMyStmt +} + +bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_PARSE(PtrSize_t& pkt) { + if (session_type != PROXYSQL_SESSION_PGSQL) { // only PgSQL module supports prepared statement!! + l_free(pkt.size, pkt.ptr); + client_myds->setDSS_STATE_QUERY_SENT_NET(); + client_myds->myprot.generate_error_packet(true, false, "Prepared statements not supported", PGSQL_ERROR_CODES::ERRCODE_FEATURE_NOT_SUPPORTED, + false, true); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return true; + } + + std::unique_ptr parse_msg(new PgSQL_Parse_Message()); + bool rc = parse_msg->parse(pkt); + if (rc == false) { + l_free(pkt.size, pkt.ptr); + client_myds->setDSS_STATE_QUERY_SENT_NET(); + client_myds->myprot.generate_error_packet(true, false, "invalid string in message", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, + true, true); + writeout(); + return false; + } + pending_packets.push(std::move(parse_msg)); // we will process it later, after sync packet + return true; +} + +bool PgSQL_Session::handler___rc0_PROCESSING_STMT_PREPARE(enum session_status& st, PgSQL_Data_Stream* myds, bool& prepared_stmt_with_no_params) { + thread->status_variables.stvar[st_var_backend_stmt_prepare]++; + GloPgStmt->wrlock(); + uint32_t client_stmtid = 0; + uint64_t global_stmtid; + + PgSQL_STMT_Global_info* stmt_info = NULL; + stmt_info = GloPgStmt->add_prepared_statement( + (char*)client_myds->myconn->userinfo->username, + (char*)client_myds->myconn->userinfo->schemaname, + (char*)CurrentQuery.QueryPointer, + CurrentQuery.QueryLength, + CurrentQuery.QueryParserArgs.first_comment, + //nullptr, //FIXME: add correct parse packet here + false); + if (CurrentQuery.QueryParserArgs.digest_text) { + if (stmt_info->digest_text == NULL) { + stmt_info->digest_text = strdup(CurrentQuery.QueryParserArgs.digest_text); + stmt_info->digest = CurrentQuery.QueryParserArgs.digest; // copy digest + stmt_info->PgQueryCmd = CurrentQuery.PgQueryCmd; // copy MyComQueryCmd + stmt_info->calculate_mem_usage(); + } + } + global_stmtid = stmt_info->statement_id; + myds->myconn->local_stmts->backend_insert(global_stmtid, CurrentQuery.stmt_backend_id); + // We only perform the client_insert when there is no previous status, this + // is, when 'PROCESSING_STMT_PREPARE' is reached directly without transitioning from a previous status + // like 'PROCESSING_STMT_EXECUTE'. + if (previous_status.size() == 0) { + assert(CurrentQuery.stmt_client_name); + client_myds->myconn->local_stmts->client_insert(global_stmtid, CurrentQuery.stmt_client_name); + } + st = status; + size_t sts = previous_status.size(); + if (sts) { + myds->myconn->async_state_machine = ASYNC_IDLE; + myds->DSS = STATE_MARIADB_GENERIC; + st = previous_status.top(); + previous_status.pop(); + GloPgStmt->unlock(); + return true; + } else { + bool send_ready_packet = pending_packets.empty(); + char txn_state = myds->myconn->get_transaction_status_char(); + client_myds->myprot.generate_parse_completion_packet(true, send_ready_packet, txn_state); + //if (stmt_info->num_params == 0) { + // prepared_stmt_with_no_params = true; + //} + LogQuery(myds); + GloPgStmt->unlock(); + } + return false; +} + // Optimized single‐pass parser for PostgreSQL DateStyle strings. // It supports input in one of these forms: // - "ISO, MDY" (two tokens separated by a comma) diff --git a/lib/PgSQL_Thread.cpp b/lib/PgSQL_Thread.cpp index f6f3c6a53..d78c2bd18 100644 --- a/lib/PgSQL_Thread.cpp +++ b/lib/PgSQL_Thread.cpp @@ -21,7 +21,7 @@ using json = nlohmann::json; #include "PgSQL_Data_Stream.h" #include "PgSQL_Query_Processor.h" #include "StatCounters.h" -#include "MySQL_PreparedStatement.h" +#include "PgSQL_PreparedStatement.h" #include "PgSQL_Logger.hpp" #include "PgSQL_Variables_Validator.h" #include @@ -4736,7 +4736,7 @@ SQLite3_result* PgSQL_Threads_Handler::SQL3_Processlist() { } } else { // prepared statement - MySQL_STMT_Global_info* si = sess->CurrentQuery.stmt_info; + PgSQL_STMT_Global_info* si = sess->CurrentQuery.stmt_info; if (si->query_length) { pta[13] = (char*)malloc(si->query_length + 1); strncpy(pta[13], si->query, si->query_length); diff --git a/lib/PgSQL_Variables.cpp b/lib/PgSQL_Variables.cpp index 801add653..25e2d6163 100644 --- a/lib/PgSQL_Variables.cpp +++ b/lib/PgSQL_Variables.cpp @@ -265,14 +265,12 @@ inline bool verify_server_variable(PgSQL_Session* session, int idx, uint32_t cli case PROCESSING_QUERY: session->previous_status.push(PROCESSING_QUERY); break; - /* case PROCESSING_STMT_PREPARE: session->previous_status.push(PROCESSING_STMT_PREPARE); break; case PROCESSING_STMT_EXECUTE: session->previous_status.push(PROCESSING_STMT_EXECUTE); break; - */ default: // LCOV_EXCL_START proxy_error("Wrong status %d\n", session->status); diff --git a/src/main.cpp b/src/main.cpp index 4fcc2ff57..1fc66f81f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -22,6 +22,7 @@ using json = nlohmann::json; #include "ProxySQL_Statistics.hpp" #include "MySQL_PreparedStatement.h" +#include "PgSQL_PreparedStatement.h" #include "ProxySQL_Cluster.hpp" #include "MySQL_Logger.hpp" #include "PgSQL_Logger.hpp" @@ -478,6 +479,7 @@ MySQL_Threads_Handler *GloMTH = NULL; PgSQL_Threads_Handler* GloPTH = NULL; Web_Interface *GloWebInterface; MySQL_STMT_Manager_v14 *GloMyStmt; +PgSQL_STMT_Manager_v14 *GloPgStmt; MySQL_Monitor *GloMyMon; PgSQL_Monitor *GloPgMon; @@ -903,7 +905,7 @@ void ProxySQL_Main_init_main_modules() { GloMyLogger=NULL; GloPgSQL_Logger = NULL; GloMyStmt=NULL; - + GloPgStmt = NULL; // initialize libev if (!ev_default_loop (EVBACKEND_POLL | EVFLAG_NOENV)) { fprintf(stderr,"could not initialise libev"); @@ -921,7 +923,7 @@ void ProxySQL_Main_init_main_modules() { GloPgSQL_Logger = new PgSQL_Logger(); GloPgSQL_Logger->print_version(); GloMyStmt=new MySQL_STMT_Manager_v14(); - + GloPgStmt=new PgSQL_STMT_Manager_v14(); PgHGM = new PgSQL_HostGroups_Manager(); PgHGM->init(); PgSQL_Threads_Handler* _tmp_GloPTH = NULL; @@ -1287,6 +1289,10 @@ void ProxySQL_Main_shutdown_all_modules() { delete GloMyStmt; GloMyStmt=NULL; } + if (GloPgStmt) { + delete GloPgStmt; + GloPgStmt = NULL; + } } void ProxySQL_Main_init() {