diff --git a/include/PgSQL_Connection.h b/include/PgSQL_Connection.h index 653e22831..47390a32a 100644 --- a/include/PgSQL_Connection.h +++ b/include/PgSQL_Connection.h @@ -285,17 +285,17 @@ public: * on the PostgreSQL backend. * */ - void stmt_describe_prepared_start(); + void stmt_describe_start(); /** * @brief Continues the asynchronous description of a prepared SQL statement. * - * This method is called after stmt_describe_prepared_start() to handle the next step in the + * This method is called after stmt_describe_start() to handle the next step in the * asynchronous state machine for describing a prepared SQL statement on the PostgreSQL backend. * * @param event The event flag indicating the current I/O event. */ - void stmt_describe_prepared_cont(short event); + void stmt_describe_cont(short event); /** * @brief Initiates the asynchronous execution of a prepared SQL statement. @@ -321,8 +321,8 @@ public: void reset_session_cont(short event); int async_connect(short event); - int async_query(short event, const char* stmt, unsigned long length, const char* backend_stmt_name = nullptr, - bool is_prepare_stmt = false, PgSQL_Bind_Message* bind_message = nullptr); + int async_query(short event, const char* stmt, unsigned long length, const char* backend_stmt_name = nullptr, + PgSQL_Extended_Query_Type type = PGSQL_EXTENDED_QUERY_TYPE_NOT_SET, const PgSQL_Extended_Query_Info* extended_query_info = nullptr); int async_ping(short event); int async_reset_session(short event); int async_send_simple_command(short event, char* stmt, unsigned long length); // no result set expected @@ -498,8 +498,7 @@ public: unsigned int reorder_dynamic_variables_idx(); unsigned int number_of_matching_session_variables(const PgSQL_Connection* client_conn, unsigned int& not_matching); - void set_query(const char* stmt, unsigned long length, const char* backend_stmt_name = nullptr, - const PgSQL_Bind_Message* bind_msg = nullptr); + void set_query(const char* stmt, unsigned long length, const char* _backend_stmt_name = nullptr, const PgSQL_Extended_Query_Info* extended_query_info = nullptr); void reset(); bool IsKeepMultiplexEnabledVariables(char* query_digest_text); @@ -508,7 +507,7 @@ public: unsigned long length; const char* ptr; const char* backend_stmt_name; - const PgSQL_Bind_Message* bind_message; + const PgSQL_Extended_Query_Info* extended_query_info; } query; struct { diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index 351242b0a..eee51fa93 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -770,7 +770,7 @@ public: bool generate_parse_completion_packet(bool send, bool ready, char trx_state, PtrSize_t* _ptr = NULL); bool generate_ready_for_query_packet(bool send, char trx_state, PtrSize_t* _ptr = NULL); - bool generate_describe_completion_packet(bool send, bool ready, const PgSQL_Describe_Prepared_Info* desc, char trx_state, PtrSize_t* _ptr = NULL); + bool generate_describe_completion_packet(bool send, bool ready, const PgSQL_Describe_Prepared_Info* desc, uint8_t stmt_type, char trx_state, PtrSize_t* _ptr = NULL); bool generate_close_completion_packet(bool send, bool ready, char trx_state, PtrSize_t* _ptr = NULL); bool generate_bind_completion_packet(bool send, bool ready, char trx_state, PtrSize_t* _ptr = NULL); diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 57e3bbebc..0b19c693c 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -43,40 +43,12 @@ enum proxysql_session_type { */ enum PgSQL_Extended_Query_Type : uint8_t { - PGSQL_EXTENDED_QUERY_TYPE_NOT_SET = 0x0, - PGSQL_EXTENDED_QUERY_TYPE_PARSE = 0x1, - PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE = 0x2, - PGSQL_EXTENDED_QUERY_TYPE_EXECUTE = 0x4, + PGSQL_EXTENDED_QUERY_TYPE_NOT_SET = 0x00, + PGSQL_EXTENDED_QUERY_TYPE_PARSE = 0x01, + PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE = 0x02, + PGSQL_EXTENDED_QUERY_TYPE_EXECUTE = 0x04, }; -#if 0 //FIXME: remove after extended query support is fully implemented -class PgSQL_Formatted_Bind_Message { -public: - /*uint16_t num_param_formats = 0; - uint16_t num_param_values = 0; - uint16_t num_result_formats = 0; - - const unsigned int* param_formats = NULL; - const char* param_values = NULL; - const int* param_values_len = NULL; - const int* result_formats = NULL; - - const char* stmt_name = NULL; - const char* portal_name = NULL; - */ - - std::vector param_values; - std::vector param_lengths; - std::vector param_formats; - std::vector result_formats; - std::string stmt_name; - std::string portal_name; - - PgSQL_Formatted_Bind_Message(PgSQL_Bind_Message* bind_msg); - ~PgSQL_Formatted_Bind_Message(); -}; -#endif - /* Enumerated types for output format and date order */ typedef enum { DATESTYLE_FORMAT_NONE = 0, @@ -176,23 +148,28 @@ public: class PgSQL_STMT_Global_info; +struct PgSQL_Extended_Query_Info { + const char* stmt_client_name; + const char* stmt_client_portal_name; + PgSQL_STMT_Global_info* stmt_info; + PgSQL_Bind_Message* bind_msg; + uint64_t stmt_global_id; + uint32_t stmt_backend_id; + uint8_t stmt_type; +}; + class PgSQL_Query_Info { public: unsigned long long start_time; unsigned long long end_time; - uint64_t stmt_global_id; uint64_t affected_rows; uint64_t rows_sent; uint64_t waiting_since; + PgSQL_Extended_Query_Info extended_query_info; PgSQL_Session* sess; unsigned char* QueryPointer; - const char* stmt_client_name; - PgSQL_STMT_Global_info* stmt_info; - PgSQL_Bind_Message* bind_msg; SQP_par_t QueryParserArgs; - - uint32_t stmt_backend_id; int QueryLength; enum PGSQL_QUERY_command PgQueryCmd; @@ -213,6 +190,7 @@ public: bool is_select_NOT_for_update(); private: + void reset_extended_query_info(bool init = false); void init(unsigned char* _p, int len, bool header = false); }; diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index 39c7c16bc..9c7aaed3e 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -125,11 +125,11 @@ enum ASYNC_ST { // MariaDB Async State Machine ASYNC_RESET_SESSION_SUCCESSFUL, ASYNC_RESET_SESSION_FAILED, ASYNC_RESET_SESSION_TIMEOUT, - ASYNC_DESCRIBE_PREPARED_START, - ASYNC_DESCRIBE_PREPARED_CONT, - ASYNC_DESCRIBE_PREPARED_END, - ASYNC_DESCRIBE_PREPARED_SUCCESSFUL, - ASYNC_DESCRIBE_PREPARED_FAILED, + ASYNC_DESCRIBE_START, + ASYNC_DESCRIBE_CONT, + ASYNC_DESCRIBE_END, + ASYNC_DESCRIBE_SUCCESSFUL, + ASYNC_DESCRIBE_FAILED, ASYNC_IDLE }; diff --git a/lib/PgSQL_Connection.cpp b/lib/PgSQL_Connection.cpp index 120433aee..59cfa875d 100644 --- a/lib/PgSQL_Connection.cpp +++ b/lib/PgSQL_Connection.cpp @@ -736,25 +736,25 @@ handler_again: case ASYNC_STMT_PREPARE_SUCCESSFUL: break; - case ASYNC_DESCRIBE_PREPARED_START: - stmt_describe_prepared_start(); + case ASYNC_DESCRIBE_START: + stmt_describe_start(); if (async_exit_status) { - next_event(ASYNC_DESCRIBE_PREPARED_CONT); + next_event(ASYNC_DESCRIBE_CONT); } else { - NEXT_IMMEDIATE(ASYNC_DESCRIBE_PREPARED_END); + NEXT_IMMEDIATE(ASYNC_DESCRIBE_END); } break; - case ASYNC_DESCRIBE_PREPARED_CONT: + case ASYNC_DESCRIBE_CONT: { if (event) { - stmt_describe_prepared_cont(event); + stmt_describe_cont(event); } if (async_exit_status) { - next_event(ASYNC_DESCRIBE_PREPARED_CONT); + next_event(ASYNC_DESCRIBE_CONT); break; } if (is_error_present()) { - NEXT_IMMEDIATE(ASYNC_DESCRIBE_PREPARED_END); + NEXT_IMMEDIATE(ASYNC_DESCRIBE_END); } PGresult* result = get_result(); if (result) { @@ -767,22 +767,22 @@ handler_again: } stmt_metadata_result->populate(result); PQclear(result); - NEXT_IMMEDIATE(ASYNC_DESCRIBE_PREPARED_CONT); + NEXT_IMMEDIATE(ASYNC_DESCRIBE_CONT); } - NEXT_IMMEDIATE(ASYNC_DESCRIBE_PREPARED_END); + NEXT_IMMEDIATE(ASYNC_DESCRIBE_END); } break; - case ASYNC_DESCRIBE_PREPARED_END: + case ASYNC_DESCRIBE_END: PQsetNoticeReceiver(pgsql_conn, &PgSQL_Connection::unhandled_notice_cb, this); if (is_error_present()) { proxy_error("Failed to describe prepared statement: %s\n", get_error_code_with_message().c_str()); - NEXT_IMMEDIATE(ASYNC_DESCRIBE_PREPARED_FAILED); + NEXT_IMMEDIATE(ASYNC_DESCRIBE_FAILED); } else { - NEXT_IMMEDIATE(ASYNC_DESCRIBE_PREPARED_SUCCESSFUL); + NEXT_IMMEDIATE(ASYNC_DESCRIBE_SUCCESSFUL); } break; - case ASYNC_DESCRIBE_PREPARED_SUCCESSFUL: - case ASYNC_DESCRIBE_PREPARED_FAILED: + case ASYNC_DESCRIBE_SUCCESSFUL: + case ASYNC_DESCRIBE_FAILED: break; case ASYNC_STMT_EXECUTE_START: stmt_execute_start(); @@ -1273,8 +1273,8 @@ void PgSQL_Connection::async_free_result() { // 0 when the query is completed // 1 when the query is not completed // the calling function should check pgsql error in pgsql struct -int PgSQL_Connection::async_query(short event, const char* stmt, unsigned long length, - const char* backend_stmt_name, bool is_prepare_stmt, PgSQL_Bind_Message* bind_message) { +int PgSQL_Connection::async_query(short event, const char* stmt, unsigned long length, const char* backend_stmt_name, + PgSQL_Extended_Query_Type type, const PgSQL_Extended_Query_Info* extended_query_info) { PROXY_TRACE(); PROXY_TRACE2(); assert(pgsql_conn); @@ -1303,18 +1303,20 @@ int PgSQL_Connection::async_query(short event, const char* stmt, unsigned long l myds->sess->transaction_started_at = myds->sess->thread->curtime; } } - if (!backend_stmt_name) { + if (!extended_query_info) { async_state_machine = ASYNC_QUERY_START; } else { - if (is_prepare_stmt) { + if (type == PGSQL_EXTENDED_QUERY_TYPE_PARSE) { async_state_machine = ASYNC_STMT_PREPARE_START; - } else if (bind_message) { + } else if (type == PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE) { + async_state_machine = ASYNC_DESCRIBE_START; + } else if (type == PGSQL_EXTENDED_QUERY_TYPE_EXECUTE) { async_state_machine = ASYNC_STMT_EXECUTE_START; } else { - async_state_machine = ASYNC_DESCRIBE_PREPARED_START; + assert(0); // should never reach here } } - set_query(stmt, length, backend_stmt_name, bind_message); + set_query(stmt, length, backend_stmt_name, extended_query_info); default: handler(event); break; @@ -1337,11 +1339,11 @@ int PgSQL_Connection::async_query(short event, const char* stmt, unsigned long l if (async_state_machine == ASYNC_STMT_PREPARE_SUCCESSFUL || async_state_machine == ASYNC_STMT_PREPARE_FAILED || - async_state_machine == ASYNC_DESCRIBE_PREPARED_SUCCESSFUL || - async_state_machine == ASYNC_DESCRIBE_PREPARED_FAILED) { + async_state_machine == ASYNC_DESCRIBE_SUCCESSFUL || + async_state_machine == ASYNC_DESCRIBE_FAILED) { compute_unknown_transaction_status(); if (async_state_machine == ASYNC_STMT_PREPARE_FAILED || - async_state_machine == ASYNC_DESCRIBE_PREPARED_FAILED) { + async_state_machine == ASYNC_DESCRIBE_FAILED) { return -1; } else { return 0; @@ -1624,22 +1626,38 @@ void PgSQL_Connection::stmt_prepare_cont(short event) { pgsql_result = PQgetResult(pgsql_conn); } -void PgSQL_Connection::stmt_describe_prepared_start() { +void PgSQL_Connection::stmt_describe_start() { PROXY_TRACE(); reset_error(); processing_multi_statement = false; async_exit_status = PG_EVENT_NONE; PQsetNoticeReceiver(pgsql_conn, &PgSQL_Connection::notice_handler_cb, this); - // We need to send a describe prepared statement to get the parameter types - if (PQsendDescribePrepared(pgsql_conn, query.backend_stmt_name) == 0) { - set_error_from_PQerrorMessage(); - proxy_error("Failed to send describe prepared. %s\n", get_error_code_with_message().c_str()); + const PgSQL_Extended_Query_Info* extended_query_info = query.extended_query_info; + + switch (extended_query_info->stmt_type) { + case 'P': // Portal + if (PQsendDescribePortal(pgsql_conn, extended_query_info->stmt_client_portal_name) == 0) { + set_error_from_PQerrorMessage(); + proxy_error("Failed to send describe portal message. %s\n", get_error_code_with_message().c_str()); + return; + } + break; + case 'S': // Prepared Statement + if (PQsendDescribePrepared(pgsql_conn, query.backend_stmt_name) == 0) { + set_error_from_PQerrorMessage(); + proxy_error("Failed to send describe prepared statement. %s\n", get_error_code_with_message().c_str()); + return; + } + break; + default: + set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, "Invalid statement type for describe", false); + proxy_error("Failed to send describe message. %s\n", get_error_code_with_message().c_str()); return; } flush(); } -void PgSQL_Connection::stmt_describe_prepared_cont(short event) { +void PgSQL_Connection::stmt_describe_cont(short event) { PROXY_TRACE(); proxy_debug(PROXY_DEBUG_MYSQL_PROTOCOL, 6, "event=%d\n", event); async_exit_status = PG_EVENT_NONE; @@ -1672,10 +1690,7 @@ void PgSQL_Connection::stmt_execute_start() { async_exit_status = PG_EVENT_NONE; PQsetNoticeReceiver(pgsql_conn, &PgSQL_Connection::notice_handler_cb, this); - //const PgSQL_Formatted_Bind_Message* formatted_bind_message = query.formatted_bind_message; - //assert(formatted_bind_message != NULL); - - const PgSQL_Bind_Message* bind_msg = query.bind_message; + const PgSQL_Bind_Message* bind_msg = query.extended_query_info->bind_msg; assert(bind_msg); // should never be null const PgSQL_Bind_Data* bind_data = bind_msg->data(); // will always have valid data @@ -1695,7 +1710,7 @@ void PgSQL_Connection::stmt_execute_start() { PgSQL_Bind_Message::ParamValue_t param; if (!bind_msg->next_param_value(&valCtx, ¶m)) { proxy_error("Failed to read param value at index %u\n", i); - set_error(PGSQL_GET_ERROR_CODE_STR(ERRCODE_INVALID_PARAMETER_VALUE), + set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, "Failed to read param value", false); return; } @@ -1716,7 +1731,7 @@ void PgSQL_Connection::stmt_execute_start() { if (!bind_msg->next_format(&fmtCtx, &format)) { proxy_error("Failed to read param format at index %u\n", i); - set_error(PGSQL_GET_ERROR_CODE_STR(ERRCODE_INVALID_PARAMETER_VALUE), + set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, "Failed to read param format", false); return; return; @@ -1734,7 +1749,7 @@ void PgSQL_Connection::stmt_execute_start() { uint16_t format; if (!bind_msg->next_format(&fmtCtx, &format)) { proxy_error("Failed to read result format at index %u\n", i); - set_error(PGSQL_GET_ERROR_CODE_STR(ERRCODE_INVALID_PARAMETER_VALUE), + set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, "Failed to read result format", false); return; } @@ -2278,15 +2293,14 @@ bool PgSQL_Connection::MultiplexDisabled(bool check_delay_token) { return ret; } -void PgSQL_Connection::set_query(const char* stmt, unsigned long length, const char* backend_stmt_name, - const PgSQL_Bind_Message* bind_msg) { +void PgSQL_Connection::set_query(const char* stmt, unsigned long length, const char* _backend_stmt_name, const PgSQL_Extended_Query_Info* extended_query_info) { query.length = length; query.ptr = stmt; if (length > largest_query_length) { largest_query_length = length; } - query.backend_stmt_name = backend_stmt_name; - query.bind_message = bind_msg; + query.backend_stmt_name = _backend_stmt_name; + query.extended_query_info = extended_query_info; } bool PgSQL_Connection::IsKeepMultiplexEnabledVariables(char* query_digest_text) { diff --git a/lib/PgSQL_Extended_Query_Message.cpp b/lib/PgSQL_Extended_Query_Message.cpp index fd80bb74f..191a8d695 100644 --- a/lib/PgSQL_Extended_Query_Message.cpp +++ b/lib/PgSQL_Extended_Query_Message.cpp @@ -126,6 +126,10 @@ bool PgSQL_Parse_Message::parse(PtrSize_t& pkt) { offset += _data.num_param_types * sizeof(uint32_t); } + if (offset != pkt_len) { + return false; + } + // take "ownership" _pkt = pkt; @@ -395,6 +399,11 @@ bool PgSQL_Bind_Message::parse(PtrSize_t& pkt) { // Move past the result formats offset += _data.num_result_formats * sizeof(uint16_t); } + + if (offset != pkt_len) { + return false; + } + // take "ownership" _pkt = pkt; // If we reach here, the packet is valid and fully parsed diff --git a/lib/PgSQL_Logger.cpp b/lib/PgSQL_Logger.cpp index 8d3d577e4..59a7eae51 100644 --- a/lib/PgSQL_Logger.cpp +++ b/lib/PgSQL_Logger.cpp @@ -706,7 +706,7 @@ void PgSQL_Logger::log_request(PgSQL_Session *sess, PgSQL_Data_Stream *myds) { if (sess->status != PROCESSING_STMT_EXECUTE) { query_digest = GloPgQPro->get_digest(&sess->CurrentQuery.QueryParserArgs); } else { - query_digest = sess->CurrentQuery.stmt_info->digest; + query_digest = sess->CurrentQuery.extended_query_info.stmt_info->digest; } PgSQL_Event me(let, @@ -720,9 +720,9 @@ void PgSQL_Logger::log_request(PgSQL_Session *sess, PgSQL_Data_Stream *myds) { int ql = 0; switch (sess->status) { case PROCESSING_STMT_EXECUTE: - c = (char *)sess->CurrentQuery.stmt_info->query; - ql = sess->CurrentQuery.stmt_info->query_length; - me.set_client_stmt_name((char*)sess->CurrentQuery.stmt_client_name); + c = (char *)sess->CurrentQuery.extended_query_info.stmt_info->query; + ql = sess->CurrentQuery.extended_query_info.stmt_info->query_length; + me.set_client_stmt_name((char*)sess->CurrentQuery.extended_query_info.stmt_client_name); break; case PROCESSING_STMT_PREPARE: default: @@ -733,7 +733,7 @@ void PgSQL_Logger::log_request(PgSQL_Session *sess, PgSQL_Data_Stream *myds) { // global cache and due to that we immediately reply to the client and session doesn't reach // 'PROCESSING_STMT_PREPARE' state. 'stmt_client_id' is expected to be '0' for anything that isn't // a prepared statement, still, logging should rely 'log_event_type' instead of this value. - me.set_client_stmt_name((char*)sess->CurrentQuery.stmt_client_name); + me.set_client_stmt_name((char*)sess->CurrentQuery.extended_query_info.stmt_client_name); break; } if (c) { diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 6ad7355dc..238e164b6 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -1529,46 +1529,55 @@ bool PgSQL_Protocol::generate_ready_for_query_packet(bool send, char trx_state, return true; } -bool PgSQL_Protocol::generate_describe_completion_packet(bool send, bool ready, const PgSQL_Describe_Prepared_Info* desc, char trx_state, PtrSize_t* _ptr) { +bool PgSQL_Protocol::generate_describe_completion_packet(bool send, bool ready, const PgSQL_Describe_Prepared_Info* desc, uint8_t stmt_type, char trx_state, PtrSize_t* _ptr) { // to avoid memory leak assert(send == true || _ptr); PG_pkt pgpkt{}; - uint32_t size = 0; - // Describe completion message - size = desc->parameter_types_count * sizeof(uint32_t) + sizeof(uint16_t) + 4; // size of the packet, including the type byte - pgpkt.put_char('t'); - pgpkt.put_uint32(size); // size of the packet, including the type byte - // If there are no parameters, we still need to write a zero - pgpkt.put_uint16(desc->parameter_types_count); // number of parameters - for (size_t i = 0; i < desc->parameter_types_count; i++) { - pgpkt.put_uint32(desc->parameter_types[i]); // parameter type OID + // ----------- Parameter Description ('t') ----------- + if (stmt_type == 'S') { + uint32_t size = desc->parameter_types_count * sizeof(uint32_t) + sizeof(uint16_t) + 4; // size of the packet, including the type byte + + pgpkt.put_char('t'); + pgpkt.put_uint32(size); // size of the packet, including the type byte + // If there are no parameters, we still need to write a zero + pgpkt.put_uint16(desc->parameter_types_count); // number of parameters + for (size_t i = 0; i < desc->parameter_types_count; i++) { + pgpkt.put_uint32(desc->parameter_types[i]); // parameter type OID + } } - size = desc->columns_count * (sizeof(uint32_t) + // table OID - sizeof(uint16_t) + // column index - sizeof(uint32_t) + // type OID - sizeof(uint16_t) + // column length - sizeof(uint32_t) + // type modifier - sizeof(uint16_t)) + // format code - sizeof(uint16_t) + 4; // Field count + size of the packet + // ----------- Row Description ('T') ----------- + if (desc->columns_count > 0) { + uint32_t size = desc->columns_count * (sizeof(uint32_t) + // table OID + sizeof(uint16_t) + // column index + sizeof(uint32_t) + // type OID + sizeof(uint16_t) + // column length + sizeof(uint32_t) + // type modifier + sizeof(uint16_t)) + // format code + sizeof(uint16_t) + 4; // Field count + size of the packet - for (size_t i = 0; i < desc->columns_count; i++) { - size += strlen(desc->columns[i].name) + 1; // field name + null terminator - } - pgpkt.put_char('T'); - // If there are no result fields, we still need to write a zero - pgpkt.put_uint32(size); // size of the packet, including the type byte - pgpkt.put_uint16(desc->columns_count); // number of result fields - - for (size_t i = 0; i < desc->columns_count; i++) { - pgpkt.put_string(desc->columns[i].name); // field name - pgpkt.put_uint32(desc->columns[i].table_oid); // table OID - pgpkt.put_uint16(desc->columns[i].column_index); // column index - pgpkt.put_uint32(desc->columns[i].type_oid); // type OID - pgpkt.put_uint16(desc->columns[i].length); // column length - pgpkt.put_uint32(desc->columns[i].type_modifier); // type modifier - pgpkt.put_uint16(desc->columns[i].format); // format code + for (size_t i = 0; i < desc->columns_count; i++) { + size += strlen(desc->columns[i].name) + 1; // field name + null terminator + } + pgpkt.put_char('T'); + // If there are no result fields, we still need to write a zero + pgpkt.put_uint32(size); // size of the packet, including the type byte + pgpkt.put_uint16(desc->columns_count); // number of result fields + + for (size_t i = 0; i < desc->columns_count; i++) { + pgpkt.put_string(desc->columns[i].name); // field name + pgpkt.put_uint32(desc->columns[i].table_oid); // table OID + pgpkt.put_uint16(desc->columns[i].column_index); // column index + pgpkt.put_uint32(desc->columns[i].type_oid); // type OID + pgpkt.put_uint16(desc->columns[i].length); // column length + pgpkt.put_uint32(desc->columns[i].type_modifier); // type modifier + pgpkt.put_uint16(desc->columns[i].format); // format code + } + } else { + // return NoData packet if there are no result fields + pgpkt.put_char('n'); + pgpkt.put_uint32(4); // size of the NoData packet (Fixed 4 bytes) } if (ready == true) { diff --git a/lib/PgSQL_Query_Processor.cpp b/lib/PgSQL_Query_Processor.cpp index b77d93a78..03136160d 100644 --- a/lib/PgSQL_Query_Processor.cpp +++ b/lib/PgSQL_Query_Processor.cpp @@ -277,9 +277,9 @@ PgSQL_Query_Processor_Output* PgSQL_Query_Processor::process_query(PgSQL_Session qp = (SQP_par_t*)&qi->QueryParserArgs; } else { qp = &stmt_exec_qp; - qp->digest = qi->stmt_info->digest; - qp->digest_text = qi->stmt_info->digest_text; - qp->first_comment = qi->stmt_info->first_comment; + qp->digest = qi->extended_query_info.stmt_info->digest; + qp->digest_text = qi->extended_query_info.stmt_info->digest_text; + qp->first_comment = qi->extended_query_info.stmt_info->first_comment; } } #define stackbuffer_size 128 diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index a8c1ebd40..49521be46 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -293,7 +293,6 @@ PgSQL_Query_Info::PgSQL_Query_Info() { QueryLength=0; QueryParserArgs.digest_text=NULL; QueryParserArgs.first_comment=NULL; - stmt_info=NULL; bool_is_select_NOT_for_update=false; bool_is_select_NOT_for_update_computed=false; have_affected_rows=false; // if affected rows is set, last_insert_id is set too @@ -302,20 +301,12 @@ PgSQL_Query_Info::PgSQL_Query_Info() { rows_sent=0; start_time=0; end_time=0; - stmt_backend_id = 0; - stmt_client_name = NULL; - bind_msg = NULL; - stmt_global_id = 0; + reset_extended_query_info(true); } PgSQL_Query_Info::~PgSQL_Query_Info() { GloPgQPro->query_parser_free(&QueryParserArgs); - stmt_info=NULL; - stmt_client_name = NULL; - if (bind_msg) { - delete bind_msg; - bind_msg = NULL; - } + reset_extended_query_info(); } void PgSQL_Query_Info::begin(unsigned char *_p, int len, bool header) { @@ -345,14 +336,19 @@ void PgSQL_Query_Info::end() { if ((end_time-start_time) > (unsigned int)pgsql_thread___long_query_time *1000) { __sync_add_and_fetch(&sess->thread->status_variables.stvar[st_var_queries_slow],1); } - stmt_info = NULL; - stmt_backend_id = 0; - stmt_global_id = 0; - stmt_client_name = NULL; - if (bind_msg) { - delete bind_msg; - bind_msg = NULL; - } + reset_extended_query_info(); +} + +void PgSQL_Query_Info::reset_extended_query_info(bool init) { + if (!init && extended_query_info.bind_msg) + delete extended_query_info.bind_msg; + extended_query_info.bind_msg = nullptr; + extended_query_info.stmt_client_name = nullptr; + extended_query_info.stmt_client_portal_name = nullptr; + extended_query_info.stmt_info = nullptr; + extended_query_info.stmt_global_id = 0; + extended_query_info.stmt_backend_id = 0; + extended_query_info.stmt_type = 'S'; } void PgSQL_Query_Info::init(unsigned char *_p, int len, bool header) { @@ -365,14 +361,7 @@ void PgSQL_Query_Info::init(unsigned char *_p, int len, bool header) { waiting_since = 0; affected_rows=0; rows_sent=0; - stmt_backend_id = 0; - stmt_global_id = 0; - stmt_info = NULL; - stmt_client_name = NULL; - if (bind_msg) { - delete bind_msg; - bind_msg = NULL; - } + reset_extended_query_info(); } void PgSQL_Query_Info::query_parser_init() { @@ -389,8 +378,8 @@ void PgSQL_Query_Info::query_parser_free() { } unsigned long long PgSQL_Query_Info::query_parser_update_counters() { - if (stmt_info) { - PgQueryCmd=stmt_info->PgQueryCmd; + if (extended_query_info.stmt_info) { + PgQueryCmd= extended_query_info.stmt_info->PgQueryCmd; } if (PgQueryCmd==PGSQL_QUERY___NONE) return 0; // this means that it was never initialized if (PgQueryCmd==PGSQL_QUERY__UNINITIALIZED) return 0; // this means that it was never initialized @@ -406,8 +395,8 @@ char * PgSQL_Query_Info::get_digest_text() { } bool PgSQL_Query_Info::is_select_NOT_for_update() { - if (stmt_info) { // we are processing a prepared statement. We already have the information - return stmt_info->is_select_NOT_for_update; + if (extended_query_info.stmt_info) { // we are processing a prepared statement. We already have the information + return extended_query_info.stmt_info->is_select_NOT_for_update; } if (QueryPointer==NULL) { return false; @@ -2802,32 +2791,43 @@ int PgSQL_Session::RunQuery(PgSQL_Data_Stream* myds, PgSQL_Connection* myconn) { break; case PROCESSING_STMT_PREPARE: { - if (CurrentQuery.stmt_backend_id == 0) { + if (CurrentQuery.extended_query_info.stmt_backend_id == 0) { uint32_t backend_stmt_id = myconn->local_stmts->generate_new_backend_stmt_id(); - CurrentQuery.stmt_backend_id = backend_stmt_id; + CurrentQuery.extended_query_info.stmt_backend_id = backend_stmt_id; proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Session %p myconn %p pgsql_conn %p Processing STMT_PREPARE with new backend_stmt_id=%u\n", this, myconn, myconn->pgsql_conn, backend_stmt_id); } // this is used to generate the name of the prepared statement in the backend - const std::string& backend_stmt_name = std::string(PROXYSQL_PS_PREFIX) + std::to_string(CurrentQuery.stmt_backend_id); - rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, backend_stmt_name.c_str(), true); + const std::string& backend_stmt_name = std::string(PROXYSQL_PS_PREFIX) + std::to_string(CurrentQuery.extended_query_info.stmt_backend_id); + rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, + backend_stmt_name.c_str(), PGSQL_EXTENDED_QUERY_TYPE_PARSE, &CurrentQuery.extended_query_info); } break; case PROCESSING_STMT_DESCRIBE: - assert(CurrentQuery.stmt_backend_id); + case PROCESSING_STMT_EXECUTE: + assert(CurrentQuery.extended_query_info.stmt_backend_id); { - const std::string& backend_stmt_name = std::string(PROXYSQL_PS_PREFIX) + std::to_string(CurrentQuery.stmt_backend_id); - rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, backend_stmt_name.c_str()); + PgSQL_Extended_Query_Type type = (status == PROCESSING_STMT_DESCRIBE) ? PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE : PGSQL_EXTENDED_QUERY_TYPE_EXECUTE; + const std::string& backend_stmt_name = std::string(PROXYSQL_PS_PREFIX) + std::to_string(CurrentQuery.extended_query_info.stmt_backend_id); + rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, + backend_stmt_name.c_str(), type, &CurrentQuery.extended_query_info); } break; - case PROCESSING_STMT_EXECUTE: +/* case PROCESSING_STMT_EXECUTE: assert(CurrentQuery.stmt_backend_id); { const std::string& backend_stmt_name = std::string(PROXYSQL_PS_PREFIX) + std::to_string(CurrentQuery.stmt_backend_id); - rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, backend_stmt_name.c_str(), false, - CurrentQuery.bind_msg); + const PgSQL_Extended_Query_Info extended_query_info = { + backend_stmt_name.c_str(), // Name of the prepared statement in the backend + CurrentQuery.stmt_portal_name, // Name of the portal on the backend + CurrentQuery.bind_msg, + PGSQL_EXTENDED_QUERY_TYPE_EXECUTE, // Type of extended query message + CurrentQuery.stmt_msg_type + }; + rc = myconn->async_query(myds->revents, (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, &extended_query_info); } break; +*/ default: // LCOV_EXCL_START assert(0); @@ -3017,11 +3017,11 @@ handler_again: if (mybe->server_myds->myconn && (mybe->server_myds->myconn->async_state_machine != ASYNC_IDLE) && mybe->server_myds->wait_until && (thread->curtime >= mybe->server_myds->wait_until)) { std::string query{}; - if (CurrentQuery.stmt_info == NULL) { // text protocol + if (CurrentQuery.extended_query_info.stmt_info == NULL) { // text protocol query = std::string{ mybe->server_myds->myconn->query.ptr, mybe->server_myds->myconn->query.length }; } else { // prepared statement - query = std::string{ CurrentQuery.stmt_info->query, CurrentQuery.stmt_info->query_length }; + query = std::string{ CurrentQuery.extended_query_info.stmt_info->query, CurrentQuery.extended_query_info.stmt_info->query_length }; } std::string client_addr{ "" }; @@ -3108,31 +3108,31 @@ handler_again: } } if (status == PROCESSING_STMT_DESCRIBE || status == PROCESSING_STMT_EXECUTE) { - uint32_t backend_stmt_id = myconn->local_stmts->find_backend_stmt_id_from_global_id(CurrentQuery.stmt_global_id); + uint32_t backend_stmt_id = myconn->local_stmts->find_backend_stmt_id_from_global_id(CurrentQuery.extended_query_info.stmt_global_id); if (backend_stmt_id == 0) { // the connection doesn't have the prepared statements prepared // we try to create it now - if (CurrentQuery.stmt_info == NULL) { + if (CurrentQuery.extended_query_info.stmt_info == NULL) { // this should never happen proxy_error("Session %p, status %d, CurrentQuery.stmt_info is NULL\n", this, status); assert(0); } - CurrentQuery.QueryLength = CurrentQuery.stmt_info->query_length; - CurrentQuery.QueryPointer = (unsigned char*)CurrentQuery.stmt_info->query; + CurrentQuery.QueryLength = CurrentQuery.extended_query_info.stmt_info->query_length; + CurrentQuery.QueryPointer = (unsigned char*)CurrentQuery.extended_query_info.stmt_info->query; // NOTE: Update 'first_comment' with the 'first_comment' from the retrieved // 'stmt_info' from the found prepared statement. 'CurrentQuery' requires its // own copy of 'first_comment' because it will later be free by 'QueryInfo::end'. - if (CurrentQuery.stmt_info->first_comment) { - CurrentQuery.QueryParserArgs.first_comment = strdup(CurrentQuery.stmt_info->first_comment); + if (CurrentQuery.extended_query_info.stmt_info->first_comment) { + CurrentQuery.QueryParserArgs.first_comment = strdup(CurrentQuery.extended_query_info.stmt_info->first_comment); } - if (CurrentQuery.stmt_global_id != CurrentQuery.stmt_info->statement_id) { + if (CurrentQuery.extended_query_info.stmt_global_id != CurrentQuery.extended_query_info.stmt_info->statement_id) { PROXY_TRACE(); assert(0); } previous_status.push(status); NEXT_IMMEDIATE(PROCESSING_STMT_PREPARE); } - CurrentQuery.stmt_backend_id = backend_stmt_id; + CurrentQuery.extended_query_info.stmt_backend_id = backend_stmt_id; } } } @@ -3985,8 +3985,8 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C return true; } - if (stmt_type & PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE || - stmt_type & PGSQL_EXTENDED_QUERY_TYPE_EXECUTE) { // for Describe and Execute we exit here + if (stmt_type == PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE || + stmt_type == PGSQL_EXTENDED_QUERY_TYPE_EXECUTE) { // for Describe and Execute we exit here goto __exit_set_destination_hostgroup; } @@ -5339,7 +5339,7 @@ void PgSQL_Session::RequestEnd(PgSQL_Data_Stream* myds, const unsigned int myerr if (status != PROCESSING_STMT_EXECUTE) { qdt = CurrentQuery.get_digest_text(); } else { - qdt = CurrentQuery.stmt_info->digest_text; + qdt = CurrentQuery.extended_query_info.stmt_info->digest_text; } if (qdt) { @@ -5871,11 +5871,12 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg bool lock_hostgroup = false; const PgSQL_Parse_Data* parse_data = parse_msg->data(); + PgSQL_Extended_Query_Info& extended_query_info = CurrentQuery.extended_query_info; CurrentQuery.begin((unsigned char*)parse_data->query_string, strlen(parse_data->query_string) + 1, false); // parse_msg memory will be freed in pgsql_real_query.end(), if message is sent to backend server // CurrentQuery.stmt_client_name may briefly become a dangling pointer until CurrentQuery.end() is invoked - CurrentQuery.stmt_client_name = parse_data->stmt_name; + extended_query_info.stmt_client_name = parse_data->stmt_name; timespec begint; timespec endt; @@ -5896,6 +5897,7 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg // setting 'prepared' to prevent fetching results from the cache if the digest matches bool handled_in_handler = handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo(&parse_pkt, &lock_hostgroup, PGSQL_EXTENDED_QUERY_TYPE_PARSE); if (handled_in_handler == true) + // no need to release parse_pkt, it has been released in handler return 0; if (pgsql_thread___set_query_lock_on_hostgroup == 1) { @@ -5934,7 +5936,7 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg // if the same statement name is used, we drop it //FIXME: Revisit this logic PgSQL_STMTs_local_v14* local_stmts = client_myds->myconn->local_stmts; - std::string stmt_name(CurrentQuery.stmt_client_name); + std::string stmt_name(extended_query_info.stmt_client_name); if (auto it = local_stmts->stmt_name_to_global_ids.find(stmt_name); it != local_stmts->stmt_name_to_global_ids.end()) { @@ -5966,7 +5968,7 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg PgSQL_STMT_Global_info* stmt_info = GloPgStmt->find_prepared_statement_by_hash(hash, false); if (stmt_info) { local_stmts->client_insert(stmt_info->statement_id, stmt_name); - CurrentQuery.stmt_global_id = stmt_info->statement_id; + extended_query_info.stmt_global_id = stmt_info->statement_id; client_myds->setDSS_STATE_QUERY_SENT_NET(); @@ -6007,10 +6009,46 @@ int PgSQL_Session::handle_post_sync_describe_message(PgSQL_Describe_Message* des //thread->status_variables.stvar[st_var_frontend_stmt_describe]++; // FIXME thread->status_variables.stvar[st_var_queries]++; - bool lock_hostgroup = false; const PgSQL_Describe_Data* describe_data = describe_msg->data(); + const char* stmt_client_name = NULL; + const char* portal_name = NULL; + bool lock_hostgroup = false; + uint8_t stmt_type = describe_data->stmt_type; + + switch (stmt_type) { + case 'P': // Portal + if (describe_data->stmt_name[0] != '\0') { + // we don't support named portals yet + client_myds->setDSS_STATE_QUERY_SENT_NET(); + std::string err_msg = "only unnamed portals are supported"; + client_myds->myprot.generate_error_packet(true, true, err_msg.c_str(), PGSQL_ERROR_CODES::ERRCODE_FEATURE_NOT_SUPPORTED, false, true); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return 2; + } + + // if we are describing a portal, Bind message must exists + if (!bind_waiting_for_execute) { + client_myds->setDSS_STATE_QUERY_SENT_NET(); + std::string err_msg = "portal \"" + std::string(describe_data->stmt_name) + "\" does not exist"; + client_myds->myprot.generate_error_packet(true, true, err_msg.c_str(), PGSQL_ERROR_CODES::ERRCODE_UNDEFINED_CURSOR, false, true); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return 2; + } + + portal_name = describe_data->stmt_name; // currently only supporting unanmed portals + stmt_client_name = bind_waiting_for_execute->data()->stmt_name; // data() will always be a valid pointer + assert(strcmp(portal_name, bind_waiting_for_execute->data()->portal_name) == 0); // portal name should match the one in bind_waiting_for_execute + break; + case 'S': // Statement + stmt_client_name = describe_data->stmt_name; + break; + default: + assert(0); // Invalid statement type, should never happen + } + assert(stmt_client_name); - const char* stmt_client_name = describe_data->stmt_name; uint64_t stmt_global_id = client_myds->myconn->local_stmts->find_global_id_from_stmt_name(stmt_client_name); if (stmt_global_id == 0) { client_myds->setDSS_STATE_QUERY_SENT_NET(); @@ -6034,9 +6072,12 @@ int PgSQL_Session::handle_post_sync_describe_message(PgSQL_Describe_Message* des } // describe_msg memory will be freed in pgsql_real_query.end() // CurrentQuery.stmt_client_name may briefly become a dangling pointer until CurrentQuery.end() is invoked - CurrentQuery.stmt_client_name = stmt_client_name; - CurrentQuery.stmt_global_id = stmt_global_id; - CurrentQuery.stmt_info = stmt_info; + PgSQL_Extended_Query_Info& extended_query_info = CurrentQuery.extended_query_info; + extended_query_info.stmt_client_name = stmt_client_name; + extended_query_info.stmt_client_portal_name = portal_name; + extended_query_info.stmt_global_id = stmt_global_id; + extended_query_info.stmt_info = stmt_info; + extended_query_info.stmt_type = stmt_type; CurrentQuery.start_time = thread->curtime; timespec begint; @@ -6058,27 +6099,32 @@ int PgSQL_Session::handle_post_sync_describe_message(PgSQL_Describe_Message* des (begint.tv_sec * 1000000000 + begint.tv_nsec); } - pthread_rwlock_rdlock(&stmt_info->rwlock_); - if (stmt_info->stmt_metadata) { - // we have the metadata, so we can send it to the client - client_myds->setDSS_STATE_QUERY_SENT_NET(); - bool send_ready_packet = extended_query_frame.empty(); - unsigned int nTxn = NumActiveTransactions(); - const char txn_state = (nTxn ? 'T' : 'I'); - client_myds->myprot.generate_describe_completion_packet(true, send_ready_packet, stmt_info->stmt_metadata, txn_state); + // Use cached stmt_metadata only for statements; for portals, forward the describe request to backend. + if (extended_query_info.stmt_type == 'S') { + pthread_rwlock_rdlock(&stmt_info->rwlock_); + if (stmt_info->stmt_metadata) { + // we have the metadata, so we can send it to the client + client_myds->setDSS_STATE_QUERY_SENT_NET(); + bool send_ready_packet = extended_query_frame.empty(); + unsigned int nTxn = NumActiveTransactions(); + const char txn_state = (nTxn ? 'T' : 'I'); + client_myds->myprot.generate_describe_completion_packet(true, send_ready_packet, stmt_info->stmt_metadata, + extended_query_info.stmt_type, txn_state); + pthread_rwlock_unlock(&stmt_info->rwlock_); + LogQuery(NULL); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + CurrentQuery.end_time = thread->curtime; + CurrentQuery.end(); + return 0; + } pthread_rwlock_unlock(&stmt_info->rwlock_); - LogQuery(NULL); - client_myds->DSS = STATE_SLEEP; - status = WAITING_CLIENT_DATA; - CurrentQuery.end_time = thread->curtime; - CurrentQuery.end(); - return 0; } - pthread_rwlock_unlock(&stmt_info->rwlock_); auto describe_pkt = describe_msg->detach(); // detach the packet from the describe message // setting 'prepared' to prevent fetching results from the cache if the digest matches - bool handled_in_handler = handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo(&describe_pkt, &lock_hostgroup, PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE); + bool handled_in_handler = handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo(&describe_pkt, + &lock_hostgroup, PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE); if (handled_in_handler == true) { // no need to free describe_pkt, it is already freed in the handler return 0; @@ -6133,10 +6179,28 @@ int PgSQL_Session::handle_post_sync_close_message(PgSQL_Close_Message* close_msg thread->status_variables.stvar[st_var_frontend_stmt_close]++; thread->status_variables.stvar[st_var_queries]++; - const PgSQL_Close_Data* close_data = close_msg->data(); + const PgSQL_Close_Data* close_data = close_msg->data(); // this will always be a valid pointer + uint8_t stmt_type = close_data->stmt_type; + + switch (stmt_type) { + case 'P': // Portal + if (close_data->stmt_name[0] == '\0') { + // we don't support unnamed portals yet + client_myds->setDSS_STATE_QUERY_SENT_NET(); + std::string err_msg = "only named portals are supported"; + client_myds->myprot.generate_error_packet(true, true, err_msg.c_str(), PGSQL_ERROR_CODES::ERRCODE_FEATURE_NOT_SUPPORTED, false, true); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return 2; + } + break; + case 'S': // Statement + client_myds->myconn->local_stmts->client_close(close_data->stmt_name); + break; + default: + assert(0); // this should never occur + } - const std::string& stmt_client_name = close_data->stmt_name; - client_myds->myconn->local_stmts->client_close(stmt_client_name); client_myds->setDSS_STATE_QUERY_SENT_NET(); unsigned int nTxn = NumActiveTransactions(); char txn_state = (nTxn ? 'T' : 'I'); @@ -6152,6 +6216,16 @@ int PgSQL_Session::handle_post_sync_bind_message(PgSQL_Bind_Message* bind_msg) { thread->status_variables.stvar[st_var_queries]++; const PgSQL_Bind_Data* bind_data = bind_msg->data(); + + if (bind_data->portal_name[0] != '\0') { + // we don't support portals yet + client_myds->setDSS_STATE_QUERY_SENT_NET(); + std::string err_msg = "only unnamed portals are supported"; + client_myds->myprot.generate_error_packet(true, true, err_msg.c_str(), PGSQL_ERROR_CODES::ERRCODE_FEATURE_NOT_SUPPORTED, false, true); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return 2; + } const char* stmt_client_name = bind_data->stmt_name; @@ -6183,9 +6257,17 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu bool lock_hostgroup = false; const PgSQL_Execute_Data* execute_data = execute_msg->data(); - //CurrentQuery.begin(nullptr, 0, false); - //FIXME: replace strdup with s_strdup - const char* portal_name = execute_data->portal_name; // currently only supporting unanmed prepared statements + if (execute_data->portal_name[0] != '\0') { + // we don't support portals yet + client_myds->setDSS_STATE_QUERY_SENT_NET(); + std::string err_msg = "only unnamed portals are supported"; + client_myds->myprot.generate_error_packet(true, true, err_msg.c_str(), PGSQL_ERROR_CODES::ERRCODE_FEATURE_NOT_SUPPORTED, false, true); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return 2; + } + + const char* portal_name = execute_data->portal_name; if (!bind_waiting_for_execute) { client_myds->setDSS_STATE_QUERY_SENT_NET(); std::string err_msg = "portal \"" + std::string(portal_name) + "\" does not exist"; @@ -6194,6 +6276,8 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu status = WAITING_CLIENT_DATA; return 2; } + assert(strcmp(portal_name, bind_waiting_for_execute->data()->portal_name) == 0); // portal name should match the one in bind_waiting_for_execute + // bind_waiting_for_execute will be released on CurrentQuery.end() call or session destory const char* stmt_client_name = bind_waiting_for_execute->data()->stmt_name; uint64_t stmt_global_id = client_myds->myconn->local_stmts->find_global_id_from_stmt_name(stmt_client_name); @@ -6217,12 +6301,13 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu status = WAITING_CLIENT_DATA; return 2; } - - CurrentQuery.stmt_client_name = (char*)stmt_client_name; - CurrentQuery.stmt_global_id = stmt_global_id; - CurrentQuery.stmt_info = stmt_info; + PgSQL_Extended_Query_Info& extended_query_info = CurrentQuery.extended_query_info; + extended_query_info.stmt_client_portal_name = portal_name; + extended_query_info.stmt_client_name = stmt_client_name; + extended_query_info.stmt_global_id = stmt_global_id; + extended_query_info.stmt_info = stmt_info; + extended_query_info.bind_msg = bind_waiting_for_execute.release(); CurrentQuery.start_time = thread->curtime; - CurrentQuery.bind_msg = bind_waiting_for_execute.release(); timespec begint; timespec endt; @@ -6245,7 +6330,7 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu auto execute_pkt = execute_msg->detach(); // detach the packet from the describe message // setting 'prepared' to prevent fetching results from the cache if the digest matches - bool handled_in_handler = handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo(&execute_pkt, &lock_hostgroup, PGSQL_EXTENDED_QUERY_TYPE_DESCRIBE); + bool handled_in_handler = handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo(&execute_pkt, &lock_hostgroup, PGSQL_EXTENDED_QUERY_TYPE_EXECUTE); if (handled_in_handler == true) { // no need to free execute_pkt, it is already freed in the handler return 0; @@ -6350,10 +6435,7 @@ int PgSQL_Session::handler___status_PROCESSING_EXTENDED_QUERY_SYNC() { rc = handle_post_sync_execute_message(execute_msg->get()); } else { proxy_error("Unknown extended query message\n"); - client_myds->setDSS_STATE_QUERY_SENT_NET(); - client_myds->myprot.generate_error_packet(true, false, "Unknown extended query message", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, - true); - // will terminate current session + assert(0); // this should never happen } if (rc == 2) { @@ -6511,10 +6593,11 @@ bool PgSQL_Session::handler___rc0_PROCESSING_STMT_PREPARE(enum session_status& s stmt_info->calculate_mem_usage(); } } - CurrentQuery.stmt_info = stmt_info; + PgSQL_Extended_Query_Info& extended_query_info = CurrentQuery.extended_query_info; + extended_query_info.stmt_info = stmt_info; global_stmtid = stmt_info->statement_id; - myds->myconn->local_stmts->backend_insert(global_stmtid, CurrentQuery.stmt_backend_id); + myds->myconn->local_stmts->backend_insert(global_stmtid, extended_query_info.stmt_backend_id); st = status; if (previous_status.empty() == false) { @@ -6528,8 +6611,8 @@ bool PgSQL_Session::handler___rc0_PROCESSING_STMT_PREPARE(enum session_status& s // We only perform the client_insert when there is no previous status, this // is, when 'PROCESSING_STMT_PREPARE' is reached directly without transitioning from a previous status // like 'PROCESSING_STMT_EXECUTE'. - assert(CurrentQuery.stmt_client_name); - client_myds->myconn->local_stmts->client_insert(global_stmtid, CurrentQuery.stmt_client_name); + assert(extended_query_info.stmt_client_name); + client_myds->myconn->local_stmts->client_insert(global_stmtid, extended_query_info.stmt_client_name); bool send_ready_packet = extended_query_frame.empty(); char txn_state = myds->myconn->get_transaction_status_char(); @@ -6541,18 +6624,31 @@ bool PgSQL_Session::handler___rc0_PROCESSING_STMT_PREPARE(enum session_status& s void PgSQL_Session::handler___rc0_PROCESSING_STMT_DESCRIBE_PREPARE(PgSQL_Data_Stream* myds) { //thread->status_variables.stvar[st_var_backend_stmt_describe]++; - assert(CurrentQuery.stmt_info); + const PgSQL_Extended_Query_Info& extended_query_info = CurrentQuery.extended_query_info; + assert(extended_query_info.stmt_info); bool send_ready_packet = extended_query_frame.empty(); char txn_state = myds->myconn->get_transaction_status_char(); - - GloPgStmt->wrlock(); - CurrentQuery.stmt_info->update_stmt_metadata(&myds->myconn->stmt_metadata_result); - client_myds->myprot.generate_describe_completion_packet(true, send_ready_packet, CurrentQuery.stmt_info->stmt_metadata, txn_state); - LogQuery(myds); - GloPgStmt->unlock(); - if (myds->myconn->stmt_metadata_result) { - delete myds->myconn->stmt_metadata_result; - myds->myconn->stmt_metadata_result = NULL; + + if (extended_query_info.stmt_type == 'S') { + GloPgStmt->wrlock(); + extended_query_info.stmt_info->update_stmt_metadata(&myds->myconn->stmt_metadata_result); + client_myds->myprot.generate_describe_completion_packet(true, send_ready_packet, extended_query_info.stmt_info->stmt_metadata, + extended_query_info.stmt_type, txn_state); + LogQuery(myds); + GloPgStmt->unlock(); + if (myds->myconn->stmt_metadata_result) { + delete myds->myconn->stmt_metadata_result; + myds->myconn->stmt_metadata_result = NULL; + } + } else { + // For portals, we don't cache metadata, so we just send an empty packet + client_myds->myprot.generate_describe_completion_packet(true, send_ready_packet, myds->myconn->stmt_metadata_result, + extended_query_info.stmt_type, txn_state); + LogQuery(myds); + if (myds->myconn->stmt_metadata_result) { + delete myds->myconn->stmt_metadata_result; + myds->myconn->stmt_metadata_result = NULL; + } } } @@ -6764,175 +6860,3 @@ std::string PgSQL_DateStyle_Util::datestyle_to_string(std::string_view input, co return datestyle_to_string(parse_datestyle(input), default_datestyle); } -#if 0 // FIXME: remove after extended query support is fully implemented -/* -// implement PgSQL_Formatted_Bind_Message, also use itatrator to get aligned values -PgSQL_Formatted_Bind_Message::PgSQL_Formatted_Bind_Message(PgSQL_Bind_Message* bind_msg) { - if (bind_msg == NULL) { - return; - } - num_param_formats = bind_msg->num_param_formats; - num_param_values = bind_msg->num_param_values; - num_result_formats = bind_msg->num_result_formats; - stmt_name = strdup(bind_msg->stmt_name ? bind_msg->stmt_name : ""); - portal_name = strdup(bind_msg->portal_name ? bind_msg->portal_name : ""); - - if (num_param_formats > 0) { - PgSQL_Bind_Message::FormatIterCtx param_format_iter; - bind_msg->init_param_format_iter(¶m_format_iter); - // Allocate memory for param_formats - param_formats = (const unsigned int*)malloc(num_param_formats * sizeof(unsigned int)); - if (param_formats == NULL) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to allocate memory for param_formats\n"); - return; - } - // Fill param_formats using the iterator - for (uint16_t i = 0; i < num_param_formats; ++i) { - if (!bind_msg->next_format(¶m_format_iter, (uint16_t*)¶m_formats[i])) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to read param format at index %u\n", i); - free((void*)param_formats); - param_formats = NULL; - return; - } - } - } - - if (num_param_values > 0) { - PgSQL_Bind_Message::ParamValueIterCtx param_value_iter; - bind_msg->init_param_value_iter(¶m_value_iter); - // Allocate memory for param_values - param_values = (const char*)malloc(num_param_values * sizeof(PgSQL_Bind_Message::ParamValue_t)); - param_values_len = (int*)malloc(num_param_values * sizeof(int)); - if (param_values == NULL) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to allocate memory for param_values\n"); - return; - } - // Fill param_values using the iterator - for (uint16_t i = 0; i < num_param_values; ++i) { - PgSQL_Bind_Message::ParamValue_t value; - if (!bind_msg->next_param_value(¶m_value_iter, &value)) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to read param value at index %u\n", i); - free((void*)param_values); - param_values = NULL; - return; - } - // Copy the value into the allocated memory - memcpy((void*)¶m_values[i], &value, sizeof(PgSQL_Bind_Message::ParamValue_t)); - // Store the length of the value - memcpy((void*)¶m_values_len[i], (uint32_t*)&value.len, sizeof(int)); - } - } - - if (num_result_formats > 0) { - PgSQL_Bind_Message::FormatIterCtx result_format_iter; - bind_msg->init_result_format_iter(&result_format_iter); - // Allocate memory for result_formats - result_formats = (const int*)malloc(num_result_formats * sizeof(int)); - if (result_formats == NULL) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to allocate memory for result_formats\n"); - return; - } - // Fill result_formats using the iterator - for (uint16_t i = 0; i < num_result_formats; ++i) { - if (!bind_msg->next_format(&result_format_iter, (uint16_t*)&result_formats[i])) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to read result format at index %u\n", i); - free((void*)result_formats); - result_formats = NULL; - return; - } - } - } -} - -PgSQL_Formatted_Bind_Message::~PgSQL_Formatted_Bind_Message() { - if (stmt_name) { - free((void*)stmt_name); - stmt_name = NULL; - } - if (portal_name) { - free((void*)portal_name); - portal_name = NULL; - } - if (param_formats) { - free((void*)param_formats); - param_formats = NULL; - } - if (param_values) { - free((void*)param_values); - param_values = NULL; - } - if (param_values_len) { - free((void*)param_values_len); - param_values_len = NULL; - } - if (result_formats) { - free((void*)result_formats); - result_formats = NULL; - } -} -*/ - -PgSQL_Formatted_Bind_Message::PgSQL_Formatted_Bind_Message(PgSQL_Bind_Message* bind_msg) { - if (bind_msg == NULL) { - return; - } - - stmt_name = bind_msg->stmt_name ? bind_msg->stmt_name : ""; - portal_name = bind_msg->portal_name ? bind_msg->portal_name : ""; - - if (bind_msg->num_param_values > 0) { - PgSQL_Bind_Message::ParamValueIterCtx valCtx; - bind_msg->init_param_value_iter(&valCtx); - - param_values.resize(bind_msg->num_param_values); - param_lengths.resize(bind_msg->num_param_values); - - for (int i = 0; i < bind_msg->num_param_values; ++i) { - PgSQL_Bind_Message::ParamValue_t param; - if (!bind_msg->next_param_value(&valCtx, ¶m)) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to read param value at index %u\n", i); - return; - } - - param_values[i] = (reinterpret_cast(param.value)); - param_lengths[i] = param.len; - } - } - - if (bind_msg->num_param_formats > 0) { - PgSQL_Bind_Message::FormatIterCtx fmtCtx; - bind_msg->init_param_format_iter(&fmtCtx); - - param_formats.resize(bind_msg->num_param_formats); - - for (int i = 0; i < bind_msg->num_param_formats; ++i) { - uint16_t format; - - if (!bind_msg->next_format(&fmtCtx, &format)) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to read param format at index %u\n", i); - return; - } - - param_formats[i] = format; - } - } - - if (bind_msg->num_result_formats > 0) { - PgSQL_Bind_Message::FormatIterCtx fmtCtx; - bind_msg->init_result_format_iter(&fmtCtx); - result_formats.resize(bind_msg->num_result_formats); - for (int i = 0; i < bind_msg->num_result_formats; ++i) { - uint16_t format; - if (!bind_msg->next_format(&fmtCtx, &format)) { - proxy_error("PgSQL_Formatted_Bind_Message: Failed to read result format at index %u\n", i); - return; - } - result_formats[i] = format; - } - } -} - -PgSQL_Formatted_Bind_Message::~PgSQL_Formatted_Bind_Message() { - -} -#endif diff --git a/lib/PgSQL_Thread.cpp b/lib/PgSQL_Thread.cpp index 75cb95891..d57931509 100644 --- a/lib/PgSQL_Thread.cpp +++ b/lib/PgSQL_Thread.cpp @@ -4725,7 +4725,7 @@ SQLite3_result* PgSQL_Threads_Handler::SQL3_Processlist() { pta[9] = strdup(buf); sprintf(buf, "%d", mc->parent->port); pta[10] = strdup(buf); - if (sess->CurrentQuery.stmt_info == NULL) { // text protocol + if (sess->CurrentQuery.extended_query_info.stmt_info == NULL) { // text protocol if (mc->query.length) { pta[13] = (char*)malloc(mc->query.length + 1); strncpy(pta[13], mc->query.ptr, mc->query.length); @@ -4736,7 +4736,7 @@ SQLite3_result* PgSQL_Threads_Handler::SQL3_Processlist() { } } else { // prepared statement - PgSQL_STMT_Global_info* si = sess->CurrentQuery.stmt_info; + PgSQL_STMT_Global_info* si = sess->CurrentQuery.extended_query_info.stmt_info; if (si->query_length) { pta[13] = (char*)malloc(si->query_length + 1); strncpy(pta[13], si->query, si->query_length);