From 4ebff4c0cc000b867e1ae4b74be83e9915b2e220 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Mon, 21 Jul 2025 01:22:07 +0500 Subject: [PATCH] Added support for protocol-supplied (out-of-band) parameter typing (argument-based) * This parameter types will be part of hash calculation * Few refactoring and optimizations --- include/PgSQL_Extended_Query_Message.h | 122 ++++++++++++++++++++----- include/PgSQL_PreparedStatement.h | 53 ++++++----- include/PgSQL_Session.h | 2 + include/gen_utils.h | 21 ++--- lib/PgSQL_Connection.cpp | 27 +++--- lib/PgSQL_Extended_Query_Message.cpp | 60 +++--------- lib/PgSQL_PreparedStatement.cpp | 37 +++++--- lib/PgSQL_Session.cpp | 23 ++++- 8 files changed, 207 insertions(+), 138 deletions(-) diff --git a/include/PgSQL_Extended_Query_Message.h b/include/PgSQL_Extended_Query_Message.h index dd14e8d43..a5a66f248 100644 --- a/include/PgSQL_Extended_Query_Message.h +++ b/include/PgSQL_Extended_Query_Message.h @@ -75,15 +75,107 @@ private: PtrSize_t _pkt = {}; ///< Packet data pointer. }; +struct PgSQL_Param_Value { + int32_t len; ///< Length of value (-1 for NULL) + const unsigned char* value; ///< Pointer to value data +}; + + +/** + * @brief Reads fields from a PostgreSQL extended query message. + * + * This template class provides an iterator-like interface for reading a sequence of fields + * from a buffer, converting each field from network byte order (big-endian) to host byte order. + * It supports reading different field types such as uint32_t, uint16_t, and PgSQL_Param_Value. + * + * Note: The buffer pointer passed to this reader may be nullptr if count is zero (valid case). + * If count is non-zero but the buffer is invalid (malformed packet), this is detected and handled + * during parsing before constructing the reader. + * + * @tparam T The type of field to read (e.g., uint32_t, uint16_t, PgSQL_Param_Value). + */ +template +class PgSQL_Field_Reader { +public: + /** + * @brief Constructs a field reader. + * @param start Pointer to the start of the field array. + * @param count Number of fields to read. + */ + PgSQL_Field_Reader(const unsigned char* start, uint16_t count) : current(start), remaining(count) {} + ~PgSQL_Field_Reader() = default; + PgSQL_Field_Reader() = delete; + PgSQL_Field_Reader(const PgSQL_Field_Reader&) = default; + PgSQL_Field_Reader& operator=(const PgSQL_Field_Reader&) = default; + PgSQL_Field_Reader(PgSQL_Field_Reader&&) = default; + PgSQL_Field_Reader& operator=(PgSQL_Field_Reader&&) = default; + + /** + * @brief Checks if there are more fields to read. + * @return True if more fields are available, false otherwise. + */ + bool has_next() const { return remaining > 0; } + + /** + * @brief Reads the next field from the buffer. + * @param out Pointer to the output variable to store the field value. + * @return True if the field was successfully read, false otherwise. + * + * For uint32_t and uint16_t, reads the value in big-endian order. + * For PgSQL_Param_Value, reads the length and value pointer, handling NULL values. + */ + bool next(T* out) { + if (remaining == 0) return false; + + if constexpr (std::is_same_v) { + if (!get_uint32be(current, out)) { + return false; + } + current += sizeof(uint32_t); + } else if constexpr (std::is_same_v) { + if (!get_uint16be(current, out)) { + return false; + } + current += sizeof(uint16_t); + } else if constexpr (std::is_same_v) { + // Read length (big-endian) + uint32_t len; + if (!get_uint32be(current, &len)) { + return false; + } + current += sizeof(uint32_t); + + out->len = (len == 0xFFFFFFFF) ? -1 : static_cast(len); + out->value = (len == 0xFFFFFFFF) ? nullptr : current; + + // Advance pointer if not NULL + if (out->len > 0) { + current += len; + } + } + remaining--; + return true; + } +private: + const unsigned char* current; ///< Current position in the buffer. + uint16_t remaining; ///< Number of fields remaining to read. +}; + + struct PgSQL_Parse_Data { const char* stmt_name; // The name of the prepared statement const char* query_string; // The query string to be prepared uint16_t num_param_types; // Number of parameter types specified - const uint32_t* param_types; // Array of parameter types (can be nullptr if none) + +private: + const unsigned char* param_types_start_ptr; // Array of parameter types (can be nullptr if none) + + friend class PgSQL_Parse_Message; }; class PgSQL_Parse_Message : public Base_Extended_Query_Message { public: + /** * @brief Parses the PgSQL_Parse_Message from the provided packet. * @@ -95,6 +187,9 @@ public: * @return True if parsing was successful, false otherwise. */ bool parse(PtrSize_t& pkt); + + // Initialize param type iterator + PgSQL_Field_Reader get_param_types_reader() const; }; struct PgSQL_Describe_Data { @@ -155,17 +250,6 @@ private: class PgSQL_Bind_Message : public Base_Extended_Query_Message { public: - typedef struct { - int32_t len; // Length of value (-1 for NULL) - const unsigned char* value; // Pointer to value data - } ParamValue_t; - - // Iterator context for parameter values - typedef struct { - const unsigned char* current; // Current position in values - uint16_t remaining; // Parameters remaining - } IteratorCtx; - /** * @brief Parses the PgSQL_Bind_Message from the provided packet. * @@ -179,16 +263,12 @@ public: */ bool parse(PtrSize_t& pkt); - // Initialize param format iterator - void init_param_format_iter(IteratorCtx* ctx) const; - // Initialize parameter value iterator - void init_param_value_iter(IteratorCtx* ctx) const; - // Get next parameter value - bool next_param_value(IteratorCtx* ctx, ParamValue_t* out) const; + // Initialize param type iterator + PgSQL_Field_Reader get_param_format_reader() const; // Initialize result format iterator - void init_result_format_iter(IteratorCtx* ctx) const; - // Get next format value - bool next_format(IteratorCtx* ctx, uint16_t* out) const; + PgSQL_Field_Reader get_result_format_reader() const; + // Initialize parameter value iterator + PgSQL_Field_Reader get_param_value_reader() const; }; struct PgSQL_Execute_Data { diff --git a/include/PgSQL_PreparedStatement.h b/include/PgSQL_PreparedStatement.h index 66a35b2c9..7e167933f 100644 --- a/include/PgSQL_PreparedStatement.h +++ b/include/PgSQL_PreparedStatement.h @@ -7,12 +7,11 @@ // class PgSQL_STMT_Global_info represents information about a PgSQL Prepared Statement // it is an internal representation of prepared statement // it include all metadata associated with it - class PgSQL_STMT_Global_info { public: uint64_t digest; PGSQL_QUERY_command PgQueryCmd; - char * digest_text; + char* digest_text; uint64_t hash; char *username; char *dbname; @@ -26,7 +25,9 @@ public: PgSQL_Describe_Prepared_Info* stmt_metadata; bool is_select_NOT_for_update; - PgSQL_STMT_Global_info(uint64_t id, char* u, char* d, char* q, unsigned int ql, char* fc, uint64_t _h); + Parse_Param_Types parse_param_types;// array of parameter types, used for prepared statements + + PgSQL_STMT_Global_info(uint64_t id, char* u, char* d, char* q, unsigned int ql, char* fc, Parse_Param_Types&& ppt, uint64_t _h); ~PgSQL_STMT_Global_info(); void update_stmt_metadata(PgSQL_Describe_Prepared_Info** new_stmt_metadata); @@ -41,11 +42,6 @@ private: }; class PgSQL_STMTs_local_v14 { -private: - bool is_client_; - std::stack free_backend_ids; - uint32_t local_max_stmt_id = 0; - public: // this map associate client_stmt_id to global_stmt_id : this is used only for client connections std::map stmt_name_to_global_ids; @@ -58,7 +54,7 @@ public: std::map global_stmt_to_backend_ids; PgSQL_Session *sess; - PgSQL_STMTs_local_v14(bool _ic) : is_client_(_ic), sess(NULL) { } + PgSQL_STMTs_local_v14(bool _ic) : sess(NULL), is_client_(_ic) { } ~PgSQL_STMTs_local_v14(); inline @@ -74,15 +70,38 @@ public: void backend_insert(uint64_t global_stmt_id, uint32_t backend_stmt_id); void client_insert(uint64_t global_stmt_id, const std::string& client_stmt_name); - uint64_t compute_hash(const char *user, const char *database, const char *query, unsigned int query_length); + uint64_t compute_hash(const char *user, const char *database, const char *query, unsigned int query_length, + const Parse_Param_Types& param_types); uint32_t generate_new_backend_stmt_id(); uint64_t find_global_id_from_stmt_name(const std::string& client_stmt_name); uint32_t find_backend_stmt_id_from_global_id(uint64_t global_id); bool client_close(const std::string& stmt_name); + +private: + bool is_client_; + std::stack free_backend_ids; + uint32_t local_max_stmt_id = 0; }; class PgSQL_STMT_Manager_v14 { +public: + PgSQL_STMT_Manager_v14(); + ~PgSQL_STMT_Manager_v14(); + PgSQL_STMT_Global_info* find_prepared_statement_by_hash(uint64_t hash, bool lock=true); + PgSQL_STMT_Global_info* find_prepared_statement_by_stmt_id(uint64_t id, bool lock=true); + inline void rdlock() { pthread_rwlock_rdlock(&rwlock_); } + inline void wrlock() { pthread_rwlock_wrlock(&rwlock_); } + inline void unlock() { pthread_rwlock_unlock(&rwlock_); } + void ref_count_client(uint64_t _stmt, int _v, bool lock=true); + void ref_count_server(uint64_t _stmt, int _v, bool lock=true); + PgSQL_STMT_Global_info* add_prepared_statement(char *user, char *database, char *query, unsigned int query_len, + char *fc, Parse_Param_Types&& ppt, bool lock=true); + void get_metrics(uint64_t *c_unique, uint64_t *c_total, uint64_t *stmt_max_stmt_id, uint64_t *cached, + uint64_t *s_unique, uint64_t *s_total); + SQLite3_result* get_prepared_statements_global_infos(); + void get_memory_usage(uint64_t& prep_stmt_metadata_mem_usage, uint64_t& prep_stmt_backend_mem_usage); + private: uint64_t next_statement_id; uint64_t num_stmt_with_ref_client_count_zero; @@ -100,20 +119,6 @@ private: uint64_t s_total; } statuses; time_t last_purge_time; -public: - PgSQL_STMT_Manager_v14(); - ~PgSQL_STMT_Manager_v14(); - PgSQL_STMT_Global_info* find_prepared_statement_by_hash(uint64_t hash, bool lock=true); - PgSQL_STMT_Global_info* find_prepared_statement_by_stmt_id(uint64_t id, bool lock=true); - inline void rdlock() { pthread_rwlock_rdlock(&rwlock_); } - inline void wrlock() { pthread_rwlock_wrlock(&rwlock_); } - inline void unlock() { pthread_rwlock_unlock(&rwlock_); } - void ref_count_client(uint64_t _stmt, int _v, bool lock=true); - void ref_count_server(uint64_t _stmt, int _v, bool lock=true); - PgSQL_STMT_Global_info * add_prepared_statement(char *user, char *database, char *query, unsigned int query_len, char *fc, bool lock=true); - void get_metrics(uint64_t *c_unique, uint64_t *c_total, uint64_t *stmt_max_stmt_id, uint64_t *cached, uint64_t *s_unique, uint64_t *s_total); - SQLite3_result* get_prepared_statements_global_infos(); - void get_memory_usage(uint64_t& prep_stmt_metadata_mem_usage, uint64_t& prep_stmt_backend_mem_usage); }; #endif /* CLASS_PGSQL_PREPARED_STATEMENT_H */ diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 6a852c983..e1eba0a2e 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -134,6 +134,7 @@ public: }; class PgSQL_STMT_Global_info; +using Parse_Param_Types = std::vector; // Vector of parameter types for prepared statements struct PgSQL_Extended_Query_Info { const char* stmt_client_name; @@ -143,6 +144,7 @@ struct PgSQL_Extended_Query_Info { uint64_t stmt_global_id; uint32_t stmt_backend_id; uint8_t stmt_type; + Parse_Param_Types parse_param_types; }; class PgSQL_Query_Info { diff --git a/include/gen_utils.h b/include/gen_utils.h index 10ba5fea8..68acb9ab0 100644 --- a/include/gen_utils.h +++ b/include/gen_utils.h @@ -337,14 +337,10 @@ inline T overflow_safe_multiply(T val) { * @param[out] dst_p A pointer where the extracted big endian 32-bit unsigned integer value will be stored. */ inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) { - int read_pos = 0; - unsigned a, b, c, d; - - a = pkt[read_pos++]; - b = pkt[read_pos++]; - c = pkt[read_pos++]; - d = pkt[read_pos++]; - *dst_p = (a << 24) | (b << 16) | (c << 8) | d; + *dst_p = ((uint32_t)pkt[0] << 24) | + ((uint32_t)pkt[1] << 16) | + ((uint32_t)pkt[2] << 8) | + ((uint32_t)pkt[3]); return true; } @@ -366,13 +362,8 @@ inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) { * The function uses post-increment to move the reading position after extracting each byte. */ inline bool get_uint16be(const unsigned char* pkt, uint16_t* dst_p) { - int read_pos = 0; ///< Current read position in the buffer. - unsigned a, b; - - // Read the two bytes from the buffer - a = pkt[read_pos++]; ///< First byte read from the buffer. - b = pkt[read_pos++]; ///< Second byte read from the buffer. - *dst_p = (a << 8) | b; + *dst_p = ((uint16_t)pkt[0] << 8) | + ((uint16_t)pkt[1]); return true; } diff --git a/lib/PgSQL_Connection.cpp b/lib/PgSQL_Connection.cpp index e492315c3..8dce0a9e9 100644 --- a/lib/PgSQL_Connection.cpp +++ b/lib/PgSQL_Connection.cpp @@ -1591,7 +1591,9 @@ void PgSQL_Connection::stmt_prepare_start() { PQsetNoticeReceiver(pgsql_conn, &PgSQL_Connection::notice_handler_cb, this); - if (PQsendPrepare(pgsql_conn, query.backend_stmt_name, query.ptr, 0, NULL) == 0) { + const Parse_Param_Types& parse_param_types = query.extended_query_info->parse_param_types; + + if (PQsendPrepare(pgsql_conn, query.backend_stmt_name, query.ptr, parse_param_types.size(), parse_param_types.data()) == 0) { set_error_from_PQerrorMessage(); proxy_error("Failed to send prepare. %s\n", get_error_code_with_message().c_str()); return; @@ -1702,54 +1704,49 @@ void PgSQL_Connection::stmt_execute_start() { std::vector result_formats; if (bind_data.num_param_values > 0) { - PgSQL_Bind_Message::IteratorCtx valCtx; - bind_msg->init_param_value_iter(&valCtx); + auto param_value_reader = bind_msg->get_param_value_reader(); param_values.resize(bind_data.num_param_values); param_lengths.resize(bind_data.num_param_values); for (int i = 0; i < bind_data.num_param_values; ++i) { - PgSQL_Bind_Message::ParamValue_t param; - if (!bind_msg->next_param_value(&valCtx, ¶m)) { + PgSQL_Param_Value param_val; + if (!param_value_reader.next(¶m_val)) { proxy_error("Failed to read param value at index %u\n", i); set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, "Failed to read param value", false); return; } - param_values[i] = (reinterpret_cast(param.value)); - param_lengths[i] = param.len; + param_values[i] = (reinterpret_cast(param_val.value)); + param_lengths[i] = param_val.len; } } if (bind_data.num_param_formats > 0) { - PgSQL_Bind_Message::IteratorCtx fmtCtx; - bind_msg->init_param_format_iter(&fmtCtx); + auto param_fmt_reader = bind_msg->get_param_format_reader(); param_formats.resize(bind_data.num_param_formats); for (int i = 0; i < bind_data.num_param_formats; ++i) { uint16_t format; - - if (!bind_msg->next_format(&fmtCtx, &format)) { + if (!param_fmt_reader.next(&format)) { proxy_error("Failed to read param format at index %u\n", i); set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, "Failed to read param format", false); return; return; } - param_formats[i] = format; } } if (bind_data.num_result_formats > 0) { - PgSQL_Bind_Message::IteratorCtx fmtCtx; - bind_msg->init_result_format_iter(&fmtCtx); + auto result_fmt_reader = bind_msg->get_result_format_reader(); result_formats.resize(bind_data.num_result_formats); for (int i = 0; i < bind_data.num_result_formats; ++i) { uint16_t format; - if (!bind_msg->next_format(&fmtCtx, &format)) { + if (!result_fmt_reader.next(&format)) { proxy_error("Failed to read result format at index %u\n", i); set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE, "Failed to read result format", false); diff --git a/lib/PgSQL_Extended_Query_Message.cpp b/lib/PgSQL_Extended_Query_Message.cpp index 3a30a6d8c..6f4a34d6e 100644 --- a/lib/PgSQL_Extended_Query_Message.cpp +++ b/lib/PgSQL_Extended_Query_Message.cpp @@ -117,7 +117,7 @@ bool PgSQL_Parse_Message::parse(PtrSize_t& pkt) { } // Read the parameter types array (each is 4 bytes, big-endian) - data.param_types = reinterpret_cast(packet + offset); + data.param_types_start_ptr = (packet + offset); // Move past the parameter types offset += data.num_param_types * sizeof(uint32_t); @@ -134,6 +134,11 @@ bool PgSQL_Parse_Message::parse(PtrSize_t& pkt) { return true; } +PgSQL_Field_Reader PgSQL_Parse_Message::get_param_types_reader() const { + const PgSQL_Parse_Data& parse_data = data(); + return PgSQL_Field_Reader(parse_data.param_types_start_ptr, parse_data.num_param_types); +} + bool PgSQL_Describe_Message::parse(PtrSize_t& pkt) { if (pkt.ptr == nullptr || pkt.size == 0) { @@ -413,60 +418,19 @@ bool PgSQL_Bind_Message::parse(PtrSize_t& pkt) { return true; } -// Initialize param format iterator -void PgSQL_Bind_Message::init_param_format_iter(IteratorCtx* ctx) const { +PgSQL_Field_Reader PgSQL_Bind_Message::get_param_format_reader() const { const PgSQL_Bind_Data& bind_data = data(); - ctx->current = bind_data.param_formats_start_ptr; - ctx->remaining = bind_data.num_param_formats; + return PgSQL_Field_Reader(bind_data.param_formats_start_ptr, bind_data.num_param_formats); } -void PgSQL_Bind_Message::init_param_value_iter(IteratorCtx* ctx) const { +PgSQL_Field_Reader PgSQL_Bind_Message::get_result_format_reader() const { const PgSQL_Bind_Data& bind_data = data(); - ctx->current = bind_data.param_values_start_ptr; - ctx->remaining = bind_data.num_param_values; + return PgSQL_Field_Reader(bind_data.result_formats_start_ptr, bind_data.num_result_formats); } -// Get next parameter value -bool PgSQL_Bind_Message::next_param_value(IteratorCtx* ctx, ParamValue_t* out) const { - if (ctx->remaining == 0) return false; - - // Read length (big-endian) - uint32_t len; - if (!get_uint32be(ctx->current, &len)) { - return false; - } - ctx->current += sizeof(uint32_t); - - out->len = (len == 0xFFFFFFFF) ? -1 : static_cast(len); - out->value = (len == 0xFFFFFFFF) ? nullptr : ctx->current; - - // Advance pointer if not NULL - if (out->len > 0) { - ctx->current += len; - } - - ctx->remaining--; - return true; -} - -// Initialize format iterator -void PgSQL_Bind_Message::init_result_format_iter(IteratorCtx* ctx) const { +PgSQL_Field_Reader PgSQL_Bind_Message::get_param_value_reader() const { const PgSQL_Bind_Data& bind_data = data(); - ctx->current = bind_data.result_formats_start_ptr; - ctx->remaining = bind_data.num_result_formats; -} - -// Get next format value -bool PgSQL_Bind_Message::next_format(IteratorCtx* ctx, uint16_t* out) const { - if (ctx->remaining == 0) return false; - - if (!get_uint16be(ctx->current, out)) { - return false; - } - - ctx->current += sizeof(uint16_t); - ctx->remaining--; - return true; + return PgSQL_Field_Reader(bind_data.param_values_start_ptr, bind_data.num_param_values); } // implement PgSQL_Execute_Message diff --git a/lib/PgSQL_PreparedStatement.cpp b/lib/PgSQL_PreparedStatement.cpp index 29495c3f8..cd688a7a6 100644 --- a/lib/PgSQL_PreparedStatement.cpp +++ b/lib/PgSQL_PreparedStatement.cpp @@ -14,10 +14,11 @@ extern PgSQL_STMT_Manager_v14 *GloPgStmt; const int PS_GLOBAL_STATUS_FIELD_NUM = 9; static uint64_t stmt_compute_hash(const char *user, - const char *database, const char *query, unsigned int query_length) { + const char *database, const char *query, unsigned int query_length, const Parse_Param_Types& param_types) { // two random seperators static const char DELIM1[] = "-ZiODNjvcNHTFaARXoqqSPDqQe-"; static const char DELIM2[] = "-aSfpWDoswfuRsJXqZKfcelzCL-"; + static const char DELIM3[] = "-rQkhRVXdvgVYsmiqZCMikjKmP-"; // NOSONAR: strlen is safe here size_t user_length = strlen(user); // NOSONAR @@ -25,6 +26,7 @@ static uint64_t stmt_compute_hash(const char *user, size_t database_length = strlen(database); // NOSONAR size_t delim1_length = sizeof(DELIM1) - 1; size_t delim2_length = sizeof(DELIM2) - 1; + size_t delim3_length = sizeof(DELIM3) - 1; size_t l = 0; l += user_length; @@ -32,6 +34,11 @@ static uint64_t stmt_compute_hash(const char *user, l += delim1_length; l += delim2_length; l += query_length; + if (!param_types.empty()) { + l += delim3_length; // add length for the third delimiter + l += sizeof(uint16_t); // add length for number of parameter types + l += (param_types.size() * sizeof(uint32_t)); // add length for parameter types + } auto buf = (char *)malloc(l); l = 0; @@ -40,7 +47,12 @@ static uint64_t stmt_compute_hash(const char *user, memcpy(buf + l, database, database_length); l += database_length; // write database memcpy(buf + l, DELIM2, delim2_length); l += delim2_length; // write delimiter2 memcpy(buf + l, query, query_length); l += query_length; // write query - + if (!param_types.empty()) { + uint16_t size = param_types.size(); + memcpy(buf + l, DELIM3, delim3_length); l += delim3_length; // write delimiter3 + memcpy(buf + l, &size, sizeof(uint16_t)); l += sizeof(uint16_t); // write number of parameter types + memcpy(buf + l, param_types.data(), size * sizeof(uint32_t)); l += (size * sizeof(uint32_t)); // write each parameter type + } uint64_t hash = SpookyHash::Hash64(buf, l, 0); free(buf); return hash; @@ -48,13 +60,14 @@ static uint64_t stmt_compute_hash(const char *user, void PgSQL_STMT_Global_info::compute_hash() { hash = stmt_compute_hash(username, dbname, query, - query_length); + query_length, parse_param_types); } PgSQL_STMT_Global_info::PgSQL_STMT_Global_info(uint64_t id, char *u, char *d, char *q, unsigned int ql, char *fc, + Parse_Param_Types&& ppt, uint64_t _h) { pthread_rwlock_init(&rwlock_, NULL); total_mem_usage = 0; @@ -69,11 +82,8 @@ PgSQL_STMT_Global_info::PgSQL_STMT_Global_info(uint64_t id, memcpy(query, q, ql); query[ql] = '\0'; // add NULL byte query_length = ql; - if (fc) { - first_comment = strdup(fc); - } else { - first_comment = nullptr; - } + first_comment = fc ? strdup(fc) : nullptr; + parse_param_types = std::move(ppt); PgQueryCmd = PGSQL_QUERY__UNINITIALIZED; if (_h) { @@ -235,6 +245,7 @@ PgSQL_STMT_Global_info::~PgSQL_STMT_Global_info() { free(first_comment); if (digest_text) free(digest_text); + parse_param_types.clear(); // clear the parameter types vector if (stmt_metadata) delete stmt_metadata; pthread_rwlock_destroy(&rwlock_); @@ -260,8 +271,8 @@ void PgSQL_STMTs_local_v14::client_insert(uint64_t global_stmt_id, const std::st } uint64_t PgSQL_STMTs_local_v14::compute_hash(const char *user, - const char *database, const char *query, unsigned int query_length) { - uint64_t hash = stmt_compute_hash(user, database, query, query_length); + const char *database, const char *query, unsigned int query_length, const Parse_Param_Types& param_types) { + uint64_t hash = stmt_compute_hash(user, database, query, query_length, param_types); return hash; } @@ -474,10 +485,10 @@ bool PgSQL_STMTs_local_v14::client_close(const std::string& stmt_name) { PgSQL_STMT_Global_info* PgSQL_STMT_Manager_v14::add_prepared_statement( char *u, char *d, char *q, unsigned int ql, - char *fc, bool lock) { + char *fc, Parse_Param_Types&& ppt, bool lock) { PgSQL_STMT_Global_info *ret = nullptr; uint64_t hash = stmt_compute_hash( - u, d, q, ql); // this identifies the prepared statement + u, d, q, ql, ppt); // this identifies the prepared statement if (lock) { wrlock(); } @@ -495,7 +506,7 @@ PgSQL_STMT_Global_info* PgSQL_STMT_Manager_v14::add_prepared_statement( next_statement_id++; } - auto stmt_info = std::make_unique(next_id, u, d, q, ql, fc, hash); + auto stmt_info = std::make_unique(next_id, u, d, q, ql, fc, std::move(ppt), hash); // insert it in both maps map_stmt_id_to_info.insert(std::make_pair(stmt_info->statement_id, stmt_info.get())); map_stmt_hash_to_info.insert(std::make_pair(stmt_info->hash, stmt_info.get())); diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index d744c6eea..92663ca39 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -347,6 +347,7 @@ void PgSQL_Query_Info::reset_extended_query_info() { extended_query_info.stmt_global_id = 0; extended_query_info.stmt_backend_id = 0; extended_query_info.stmt_type = 'S'; + extended_query_info.parse_param_types.clear(); } void PgSQL_Query_Info::init(unsigned char *_p, int len, bool header) { @@ -2811,7 +2812,7 @@ int PgSQL_Session::RunQuery(PgSQL_Data_Stream* myds, PgSQL_Connection* myconn) { // bind_waiting_for_execute in case the client sends a sequence like // Bind/Describe/Execute/Describe/Sync, so that a subsequent Describe Portal // does not incorrectly assume a pending Bind. - if (rc != 1 && type == PGSQL_EXTENDED_QUERY_TYPE_EXECUTE) { + if (rc == 0 && type == PGSQL_EXTENDED_QUERY_TYPE_EXECUTE) { bind_waiting_for_execute.reset(nullptr); } } @@ -3122,6 +3123,9 @@ handler_again: if (CurrentQuery.extended_query_info.stmt_info->first_comment) { CurrentQuery.QueryParserArgs.first_comment = strdup(CurrentQuery.extended_query_info.stmt_info->first_comment); } + if (CurrentQuery.extended_query_info.stmt_info->parse_param_types.empty() == false) { + CurrentQuery.extended_query_info.parse_param_types = CurrentQuery.extended_query_info.stmt_info->parse_param_types; + } if (CurrentQuery.extended_query_info.stmt_global_id != CurrentQuery.extended_query_info.stmt_info->statement_id) { PROXY_TRACE(); assert(0); @@ -5902,6 +5906,19 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg (begint.tv_sec * 1000000000 + begint.tv_nsec); } assert(qpo); // GloPgQPro->process_mysql_query() should always return a qpo + + if (parse_data.num_param_types > 0) { + Parse_Param_Types parse_param_type; + parse_param_type.resize(parse_data.num_param_types); + auto param_type_reader = parse_msg->get_param_types_reader(); // get the reader for the param types + for (uint16_t i = 0; i < parse_data.num_param_types; ++i) { + if (!param_type_reader.next(&parse_param_type[i])) { + proxy_error("Failed to read result format at index %u\n", i); + return 2; + } + } + CurrentQuery.extended_query_info.parse_param_types = std::move(parse_param_type); + } auto parse_pkt = parse_msg->detach(); // detach the packet from the parse message @@ -5971,7 +5988,8 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg client_myds->myconn->userinfo->username, client_myds->myconn->userinfo->dbname, (char*)CurrentQuery.QueryPointer, - CurrentQuery.QueryLength + CurrentQuery.QueryLength, + CurrentQuery.extended_query_info.parse_param_types ); // Check global statement cache @@ -6575,6 +6593,7 @@ bool PgSQL_Session::handler___rc0_PROCESSING_STMT_PREPARE(enum session_status& s (char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, CurrentQuery.QueryParserArgs.first_comment, + std::move(CurrentQuery.extended_query_info.parse_param_types), false); assert(stmt_info); // GloPgStmt->add_prepared_statement() should always return a valid pointer if (CurrentQuery.QueryParserArgs.digest_text) {