Added support for protocol-supplied (out-of-band) parameter typing (argument-based)

* This parameter types will be part of hash calculation
* Few refactoring and optimizations
pull/5044/head
Rahim Kanji 10 months ago
parent a8a2799187
commit 4ebff4c0cc

@ -75,15 +75,107 @@ private:
PtrSize_t _pkt = {}; ///< Packet data pointer.
};
struct PgSQL_Param_Value {
int32_t len; ///< Length of value (-1 for NULL)
const unsigned char* value; ///< Pointer to value data
};
/**
* @brief Reads fields from a PostgreSQL extended query message.
*
* This template class provides an iterator-like interface for reading a sequence of fields
* from a buffer, converting each field from network byte order (big-endian) to host byte order.
* It supports reading different field types such as uint32_t, uint16_t, and PgSQL_Param_Value.
*
* Note: The buffer pointer passed to this reader may be nullptr if count is zero (valid case).
* If count is non-zero but the buffer is invalid (malformed packet), this is detected and handled
* during parsing before constructing the reader.
*
* @tparam T The type of field to read (e.g., uint32_t, uint16_t, PgSQL_Param_Value).
*/
template<class T>
class PgSQL_Field_Reader {
public:
/**
* @brief Constructs a field reader.
* @param start Pointer to the start of the field array.
* @param count Number of fields to read.
*/
PgSQL_Field_Reader(const unsigned char* start, uint16_t count) : current(start), remaining(count) {}
~PgSQL_Field_Reader() = default;
PgSQL_Field_Reader() = delete;
PgSQL_Field_Reader(const PgSQL_Field_Reader&) = default;
PgSQL_Field_Reader& operator=(const PgSQL_Field_Reader&) = default;
PgSQL_Field_Reader(PgSQL_Field_Reader&&) = default;
PgSQL_Field_Reader& operator=(PgSQL_Field_Reader&&) = default;
/**
* @brief Checks if there are more fields to read.
* @return True if more fields are available, false otherwise.
*/
bool has_next() const { return remaining > 0; }
/**
* @brief Reads the next field from the buffer.
* @param out Pointer to the output variable to store the field value.
* @return True if the field was successfully read, false otherwise.
*
* For uint32_t and uint16_t, reads the value in big-endian order.
* For PgSQL_Param_Value, reads the length and value pointer, handling NULL values.
*/
bool next(T* out) {
if (remaining == 0) return false;
if constexpr (std::is_same_v<T, uint32_t>) {
if (!get_uint32be(current, out)) {
return false;
}
current += sizeof(uint32_t);
} else if constexpr (std::is_same_v<T, uint16_t>) {
if (!get_uint16be(current, out)) {
return false;
}
current += sizeof(uint16_t);
} else if constexpr (std::is_same_v<T, PgSQL_Param_Value>) {
// Read length (big-endian)
uint32_t len;
if (!get_uint32be(current, &len)) {
return false;
}
current += sizeof(uint32_t);
out->len = (len == 0xFFFFFFFF) ? -1 : static_cast<int32_t>(len);
out->value = (len == 0xFFFFFFFF) ? nullptr : current;
// Advance pointer if not NULL
if (out->len > 0) {
current += len;
}
}
remaining--;
return true;
}
private:
const unsigned char* current; ///< Current position in the buffer.
uint16_t remaining; ///< Number of fields remaining to read.
};
struct PgSQL_Parse_Data {
const char* stmt_name; // The name of the prepared statement
const char* query_string; // The query string to be prepared
uint16_t num_param_types; // Number of parameter types specified
const uint32_t* param_types; // Array of parameter types (can be nullptr if none)
private:
const unsigned char* param_types_start_ptr; // Array of parameter types (can be nullptr if none)
friend class PgSQL_Parse_Message;
};
class PgSQL_Parse_Message : public Base_Extended_Query_Message<PgSQL_Parse_Data,PgSQL_Parse_Message> {
public:
/**
* @brief Parses the PgSQL_Parse_Message from the provided packet.
*
@ -95,6 +187,9 @@ public:
* @return True if parsing was successful, false otherwise.
*/
bool parse(PtrSize_t& pkt);
// Initialize param type iterator
PgSQL_Field_Reader<uint32_t> get_param_types_reader() const;
};
struct PgSQL_Describe_Data {
@ -155,17 +250,6 @@ private:
class PgSQL_Bind_Message : public Base_Extended_Query_Message<PgSQL_Bind_Data,PgSQL_Bind_Message> {
public:
typedef struct {
int32_t len; // Length of value (-1 for NULL)
const unsigned char* value; // Pointer to value data
} ParamValue_t;
// Iterator context for parameter values
typedef struct {
const unsigned char* current; // Current position in values
uint16_t remaining; // Parameters remaining
} IteratorCtx;
/**
* @brief Parses the PgSQL_Bind_Message from the provided packet.
*
@ -179,16 +263,12 @@ public:
*/
bool parse(PtrSize_t& pkt);
// Initialize param format iterator
void init_param_format_iter(IteratorCtx* ctx) const;
// Initialize parameter value iterator
void init_param_value_iter(IteratorCtx* ctx) const;
// Get next parameter value
bool next_param_value(IteratorCtx* ctx, ParamValue_t* out) const;
// Initialize param type iterator
PgSQL_Field_Reader<uint16_t> get_param_format_reader() const;
// Initialize result format iterator
void init_result_format_iter(IteratorCtx* ctx) const;
// Get next format value
bool next_format(IteratorCtx* ctx, uint16_t* out) const;
PgSQL_Field_Reader<uint16_t> get_result_format_reader() const;
// Initialize parameter value iterator
PgSQL_Field_Reader<PgSQL_Param_Value> get_param_value_reader() const;
};
struct PgSQL_Execute_Data {

@ -7,12 +7,11 @@
// class PgSQL_STMT_Global_info represents information about a PgSQL Prepared Statement
// it is an internal representation of prepared statement
// it include all metadata associated with it
class PgSQL_STMT_Global_info {
public:
uint64_t digest;
PGSQL_QUERY_command PgQueryCmd;
char * digest_text;
char* digest_text;
uint64_t hash;
char *username;
char *dbname;
@ -26,7 +25,9 @@ public:
PgSQL_Describe_Prepared_Info* stmt_metadata;
bool is_select_NOT_for_update;
PgSQL_STMT_Global_info(uint64_t id, char* u, char* d, char* q, unsigned int ql, char* fc, uint64_t _h);
Parse_Param_Types parse_param_types;// array of parameter types, used for prepared statements
PgSQL_STMT_Global_info(uint64_t id, char* u, char* d, char* q, unsigned int ql, char* fc, Parse_Param_Types&& ppt, uint64_t _h);
~PgSQL_STMT_Global_info();
void update_stmt_metadata(PgSQL_Describe_Prepared_Info** new_stmt_metadata);
@ -41,11 +42,6 @@ private:
};
class PgSQL_STMTs_local_v14 {
private:
bool is_client_;
std::stack<uint32_t> free_backend_ids;
uint32_t local_max_stmt_id = 0;
public:
// this map associate client_stmt_id to global_stmt_id : this is used only for client connections
std::map<std::string, uint64_t> stmt_name_to_global_ids;
@ -58,7 +54,7 @@ public:
std::map<uint64_t, uint32_t> global_stmt_to_backend_ids;
PgSQL_Session *sess;
PgSQL_STMTs_local_v14(bool _ic) : is_client_(_ic), sess(NULL) { }
PgSQL_STMTs_local_v14(bool _ic) : sess(NULL), is_client_(_ic) { }
~PgSQL_STMTs_local_v14();
inline
@ -74,15 +70,38 @@ public:
void backend_insert(uint64_t global_stmt_id, uint32_t backend_stmt_id);
void client_insert(uint64_t global_stmt_id, const std::string& client_stmt_name);
uint64_t compute_hash(const char *user, const char *database, const char *query, unsigned int query_length);
uint64_t compute_hash(const char *user, const char *database, const char *query, unsigned int query_length,
const Parse_Param_Types& param_types);
uint32_t generate_new_backend_stmt_id();
uint64_t find_global_id_from_stmt_name(const std::string& client_stmt_name);
uint32_t find_backend_stmt_id_from_global_id(uint64_t global_id);
bool client_close(const std::string& stmt_name);
private:
bool is_client_;
std::stack<uint32_t> free_backend_ids;
uint32_t local_max_stmt_id = 0;
};
class PgSQL_STMT_Manager_v14 {
public:
PgSQL_STMT_Manager_v14();
~PgSQL_STMT_Manager_v14();
PgSQL_STMT_Global_info* find_prepared_statement_by_hash(uint64_t hash, bool lock=true);
PgSQL_STMT_Global_info* find_prepared_statement_by_stmt_id(uint64_t id, bool lock=true);
inline void rdlock() { pthread_rwlock_rdlock(&rwlock_); }
inline void wrlock() { pthread_rwlock_wrlock(&rwlock_); }
inline void unlock() { pthread_rwlock_unlock(&rwlock_); }
void ref_count_client(uint64_t _stmt, int _v, bool lock=true);
void ref_count_server(uint64_t _stmt, int _v, bool lock=true);
PgSQL_STMT_Global_info* add_prepared_statement(char *user, char *database, char *query, unsigned int query_len,
char *fc, Parse_Param_Types&& ppt, bool lock=true);
void get_metrics(uint64_t *c_unique, uint64_t *c_total, uint64_t *stmt_max_stmt_id, uint64_t *cached,
uint64_t *s_unique, uint64_t *s_total);
SQLite3_result* get_prepared_statements_global_infos();
void get_memory_usage(uint64_t& prep_stmt_metadata_mem_usage, uint64_t& prep_stmt_backend_mem_usage);
private:
uint64_t next_statement_id;
uint64_t num_stmt_with_ref_client_count_zero;
@ -100,20 +119,6 @@ private:
uint64_t s_total;
} statuses;
time_t last_purge_time;
public:
PgSQL_STMT_Manager_v14();
~PgSQL_STMT_Manager_v14();
PgSQL_STMT_Global_info* find_prepared_statement_by_hash(uint64_t hash, bool lock=true);
PgSQL_STMT_Global_info* find_prepared_statement_by_stmt_id(uint64_t id, bool lock=true);
inline void rdlock() { pthread_rwlock_rdlock(&rwlock_); }
inline void wrlock() { pthread_rwlock_wrlock(&rwlock_); }
inline void unlock() { pthread_rwlock_unlock(&rwlock_); }
void ref_count_client(uint64_t _stmt, int _v, bool lock=true);
void ref_count_server(uint64_t _stmt, int _v, bool lock=true);
PgSQL_STMT_Global_info * add_prepared_statement(char *user, char *database, char *query, unsigned int query_len, char *fc, bool lock=true);
void get_metrics(uint64_t *c_unique, uint64_t *c_total, uint64_t *stmt_max_stmt_id, uint64_t *cached, uint64_t *s_unique, uint64_t *s_total);
SQLite3_result* get_prepared_statements_global_infos();
void get_memory_usage(uint64_t& prep_stmt_metadata_mem_usage, uint64_t& prep_stmt_backend_mem_usage);
};
#endif /* CLASS_PGSQL_PREPARED_STATEMENT_H */

@ -134,6 +134,7 @@ public:
};
class PgSQL_STMT_Global_info;
using Parse_Param_Types = std::vector<uint32_t>; // Vector of parameter types for prepared statements
struct PgSQL_Extended_Query_Info {
const char* stmt_client_name;
@ -143,6 +144,7 @@ struct PgSQL_Extended_Query_Info {
uint64_t stmt_global_id;
uint32_t stmt_backend_id;
uint8_t stmt_type;
Parse_Param_Types parse_param_types;
};
class PgSQL_Query_Info {

@ -337,14 +337,10 @@ inline T overflow_safe_multiply(T val) {
* @param[out] dst_p A pointer where the extracted big endian 32-bit unsigned integer value will be stored.
*/
inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) {
int read_pos = 0;
unsigned a, b, c, d;
a = pkt[read_pos++];
b = pkt[read_pos++];
c = pkt[read_pos++];
d = pkt[read_pos++];
*dst_p = (a << 24) | (b << 16) | (c << 8) | d;
*dst_p = ((uint32_t)pkt[0] << 24) |
((uint32_t)pkt[1] << 16) |
((uint32_t)pkt[2] << 8) |
((uint32_t)pkt[3]);
return true;
}
@ -366,13 +362,8 @@ inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) {
* The function uses post-increment to move the reading position after extracting each byte.
*/
inline bool get_uint16be(const unsigned char* pkt, uint16_t* dst_p) {
int read_pos = 0; ///< Current read position in the buffer.
unsigned a, b;
// Read the two bytes from the buffer
a = pkt[read_pos++]; ///< First byte read from the buffer.
b = pkt[read_pos++]; ///< Second byte read from the buffer.
*dst_p = (a << 8) | b;
*dst_p = ((uint16_t)pkt[0] << 8) |
((uint16_t)pkt[1]);
return true;
}

@ -1591,7 +1591,9 @@ void PgSQL_Connection::stmt_prepare_start() {
PQsetNoticeReceiver(pgsql_conn, &PgSQL_Connection::notice_handler_cb, this);
if (PQsendPrepare(pgsql_conn, query.backend_stmt_name, query.ptr, 0, NULL) == 0) {
const Parse_Param_Types& parse_param_types = query.extended_query_info->parse_param_types;
if (PQsendPrepare(pgsql_conn, query.backend_stmt_name, query.ptr, parse_param_types.size(), parse_param_types.data()) == 0) {
set_error_from_PQerrorMessage();
proxy_error("Failed to send prepare. %s\n", get_error_code_with_message().c_str());
return;
@ -1702,54 +1704,49 @@ void PgSQL_Connection::stmt_execute_start() {
std::vector<int> result_formats;
if (bind_data.num_param_values > 0) {
PgSQL_Bind_Message::IteratorCtx valCtx;
bind_msg->init_param_value_iter(&valCtx);
auto param_value_reader = bind_msg->get_param_value_reader();
param_values.resize(bind_data.num_param_values);
param_lengths.resize(bind_data.num_param_values);
for (int i = 0; i < bind_data.num_param_values; ++i) {
PgSQL_Bind_Message::ParamValue_t param;
if (!bind_msg->next_param_value(&valCtx, &param)) {
PgSQL_Param_Value param_val;
if (!param_value_reader.next(&param_val)) {
proxy_error("Failed to read param value at index %u\n", i);
set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE,
"Failed to read param value", false);
return;
}
param_values[i] = (reinterpret_cast<const char*>(param.value));
param_lengths[i] = param.len;
param_values[i] = (reinterpret_cast<const char*>(param_val.value));
param_lengths[i] = param_val.len;
}
}
if (bind_data.num_param_formats > 0) {
PgSQL_Bind_Message::IteratorCtx fmtCtx;
bind_msg->init_param_format_iter(&fmtCtx);
auto param_fmt_reader = bind_msg->get_param_format_reader();
param_formats.resize(bind_data.num_param_formats);
for (int i = 0; i < bind_data.num_param_formats; ++i) {
uint16_t format;
if (!bind_msg->next_format(&fmtCtx, &format)) {
if (!param_fmt_reader.next(&format)) {
proxy_error("Failed to read param format at index %u\n", i);
set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE,
"Failed to read param format", false);
return;
return;
}
param_formats[i] = format;
}
}
if (bind_data.num_result_formats > 0) {
PgSQL_Bind_Message::IteratorCtx fmtCtx;
bind_msg->init_result_format_iter(&fmtCtx);
auto result_fmt_reader = bind_msg->get_result_format_reader();
result_formats.resize(bind_data.num_result_formats);
for (int i = 0; i < bind_data.num_result_formats; ++i) {
uint16_t format;
if (!bind_msg->next_format(&fmtCtx, &format)) {
if (!result_fmt_reader.next(&format)) {
proxy_error("Failed to read result format at index %u\n", i);
set_error(PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE,
"Failed to read result format", false);

@ -117,7 +117,7 @@ bool PgSQL_Parse_Message::parse(PtrSize_t& pkt) {
}
// Read the parameter types array (each is 4 bytes, big-endian)
data.param_types = reinterpret_cast<const uint32_t*>(packet + offset);
data.param_types_start_ptr = (packet + offset);
// Move past the parameter types
offset += data.num_param_types * sizeof(uint32_t);
@ -134,6 +134,11 @@ bool PgSQL_Parse_Message::parse(PtrSize_t& pkt) {
return true;
}
PgSQL_Field_Reader<uint32_t> PgSQL_Parse_Message::get_param_types_reader() const {
const PgSQL_Parse_Data& parse_data = data();
return PgSQL_Field_Reader<uint32_t>(parse_data.param_types_start_ptr, parse_data.num_param_types);
}
bool PgSQL_Describe_Message::parse(PtrSize_t& pkt) {
if (pkt.ptr == nullptr || pkt.size == 0) {
@ -413,60 +418,19 @@ bool PgSQL_Bind_Message::parse(PtrSize_t& pkt) {
return true;
}
// Initialize param format iterator
void PgSQL_Bind_Message::init_param_format_iter(IteratorCtx* ctx) const {
PgSQL_Field_Reader<uint16_t> PgSQL_Bind_Message::get_param_format_reader() const {
const PgSQL_Bind_Data& bind_data = data();
ctx->current = bind_data.param_formats_start_ptr;
ctx->remaining = bind_data.num_param_formats;
return PgSQL_Field_Reader<uint16_t>(bind_data.param_formats_start_ptr, bind_data.num_param_formats);
}
void PgSQL_Bind_Message::init_param_value_iter(IteratorCtx* ctx) const {
PgSQL_Field_Reader<uint16_t> PgSQL_Bind_Message::get_result_format_reader() const {
const PgSQL_Bind_Data& bind_data = data();
ctx->current = bind_data.param_values_start_ptr;
ctx->remaining = bind_data.num_param_values;
return PgSQL_Field_Reader<uint16_t>(bind_data.result_formats_start_ptr, bind_data.num_result_formats);
}
// Get next parameter value
bool PgSQL_Bind_Message::next_param_value(IteratorCtx* ctx, ParamValue_t* out) const {
if (ctx->remaining == 0) return false;
// Read length (big-endian)
uint32_t len;
if (!get_uint32be(ctx->current, &len)) {
return false;
}
ctx->current += sizeof(uint32_t);
out->len = (len == 0xFFFFFFFF) ? -1 : static_cast<int32_t>(len);
out->value = (len == 0xFFFFFFFF) ? nullptr : ctx->current;
// Advance pointer if not NULL
if (out->len > 0) {
ctx->current += len;
}
ctx->remaining--;
return true;
}
// Initialize format iterator
void PgSQL_Bind_Message::init_result_format_iter(IteratorCtx* ctx) const {
PgSQL_Field_Reader<PgSQL_Param_Value> PgSQL_Bind_Message::get_param_value_reader() const {
const PgSQL_Bind_Data& bind_data = data();
ctx->current = bind_data.result_formats_start_ptr;
ctx->remaining = bind_data.num_result_formats;
}
// Get next format value
bool PgSQL_Bind_Message::next_format(IteratorCtx* ctx, uint16_t* out) const {
if (ctx->remaining == 0) return false;
if (!get_uint16be(ctx->current, out)) {
return false;
}
ctx->current += sizeof(uint16_t);
ctx->remaining--;
return true;
return PgSQL_Field_Reader<PgSQL_Param_Value>(bind_data.param_values_start_ptr, bind_data.num_param_values);
}
// implement PgSQL_Execute_Message

@ -14,10 +14,11 @@ extern PgSQL_STMT_Manager_v14 *GloPgStmt;
const int PS_GLOBAL_STATUS_FIELD_NUM = 9;
static uint64_t stmt_compute_hash(const char *user,
const char *database, const char *query, unsigned int query_length) {
const char *database, const char *query, unsigned int query_length, const Parse_Param_Types& param_types) {
// two random seperators
static const char DELIM1[] = "-ZiODNjvcNHTFaARXoqqSPDqQe-";
static const char DELIM2[] = "-aSfpWDoswfuRsJXqZKfcelzCL-";
static const char DELIM3[] = "-rQkhRVXdvgVYsmiqZCMikjKmP-";
// NOSONAR: strlen is safe here
size_t user_length = strlen(user); // NOSONAR
@ -25,6 +26,7 @@ static uint64_t stmt_compute_hash(const char *user,
size_t database_length = strlen(database); // NOSONAR
size_t delim1_length = sizeof(DELIM1) - 1;
size_t delim2_length = sizeof(DELIM2) - 1;
size_t delim3_length = sizeof(DELIM3) - 1;
size_t l = 0;
l += user_length;
@ -32,6 +34,11 @@ static uint64_t stmt_compute_hash(const char *user,
l += delim1_length;
l += delim2_length;
l += query_length;
if (!param_types.empty()) {
l += delim3_length; // add length for the third delimiter
l += sizeof(uint16_t); // add length for number of parameter types
l += (param_types.size() * sizeof(uint32_t)); // add length for parameter types
}
auto buf = (char *)malloc(l);
l = 0;
@ -40,7 +47,12 @@ static uint64_t stmt_compute_hash(const char *user,
memcpy(buf + l, database, database_length); l += database_length; // write database
memcpy(buf + l, DELIM2, delim2_length); l += delim2_length; // write delimiter2
memcpy(buf + l, query, query_length); l += query_length; // write query
if (!param_types.empty()) {
uint16_t size = param_types.size();
memcpy(buf + l, DELIM3, delim3_length); l += delim3_length; // write delimiter3
memcpy(buf + l, &size, sizeof(uint16_t)); l += sizeof(uint16_t); // write number of parameter types
memcpy(buf + l, param_types.data(), size * sizeof(uint32_t)); l += (size * sizeof(uint32_t)); // write each parameter type
}
uint64_t hash = SpookyHash::Hash64(buf, l, 0);
free(buf);
return hash;
@ -48,13 +60,14 @@ static uint64_t stmt_compute_hash(const char *user,
void PgSQL_STMT_Global_info::compute_hash() {
hash = stmt_compute_hash(username, dbname, query,
query_length);
query_length, parse_param_types);
}
PgSQL_STMT_Global_info::PgSQL_STMT_Global_info(uint64_t id,
char *u, char *d, char *q,
unsigned int ql,
char *fc,
Parse_Param_Types&& ppt,
uint64_t _h) {
pthread_rwlock_init(&rwlock_, NULL);
total_mem_usage = 0;
@ -69,11 +82,8 @@ PgSQL_STMT_Global_info::PgSQL_STMT_Global_info(uint64_t id,
memcpy(query, q, ql);
query[ql] = '\0'; // add NULL byte
query_length = ql;
if (fc) {
first_comment = strdup(fc);
} else {
first_comment = nullptr;
}
first_comment = fc ? strdup(fc) : nullptr;
parse_param_types = std::move(ppt);
PgQueryCmd = PGSQL_QUERY__UNINITIALIZED;
if (_h) {
@ -235,6 +245,7 @@ PgSQL_STMT_Global_info::~PgSQL_STMT_Global_info() {
free(first_comment);
if (digest_text)
free(digest_text);
parse_param_types.clear(); // clear the parameter types vector
if (stmt_metadata)
delete stmt_metadata;
pthread_rwlock_destroy(&rwlock_);
@ -260,8 +271,8 @@ void PgSQL_STMTs_local_v14::client_insert(uint64_t global_stmt_id, const std::st
}
uint64_t PgSQL_STMTs_local_v14::compute_hash(const char *user,
const char *database, const char *query, unsigned int query_length) {
uint64_t hash = stmt_compute_hash(user, database, query, query_length);
const char *database, const char *query, unsigned int query_length, const Parse_Param_Types& param_types) {
uint64_t hash = stmt_compute_hash(user, database, query, query_length, param_types);
return hash;
}
@ -474,10 +485,10 @@ bool PgSQL_STMTs_local_v14::client_close(const std::string& stmt_name) {
PgSQL_STMT_Global_info* PgSQL_STMT_Manager_v14::add_prepared_statement(
char *u, char *d, char *q, unsigned int ql,
char *fc, bool lock) {
char *fc, Parse_Param_Types&& ppt, bool lock) {
PgSQL_STMT_Global_info *ret = nullptr;
uint64_t hash = stmt_compute_hash(
u, d, q, ql); // this identifies the prepared statement
u, d, q, ql, ppt); // this identifies the prepared statement
if (lock) {
wrlock();
}
@ -495,7 +506,7 @@ PgSQL_STMT_Global_info* PgSQL_STMT_Manager_v14::add_prepared_statement(
next_statement_id++;
}
auto stmt_info = std::make_unique<PgSQL_STMT_Global_info>(next_id, u, d, q, ql, fc, hash);
auto stmt_info = std::make_unique<PgSQL_STMT_Global_info>(next_id, u, d, q, ql, fc, std::move(ppt), hash);
// insert it in both maps
map_stmt_id_to_info.insert(std::make_pair(stmt_info->statement_id, stmt_info.get()));
map_stmt_hash_to_info.insert(std::make_pair(stmt_info->hash, stmt_info.get()));

@ -347,6 +347,7 @@ void PgSQL_Query_Info::reset_extended_query_info() {
extended_query_info.stmt_global_id = 0;
extended_query_info.stmt_backend_id = 0;
extended_query_info.stmt_type = 'S';
extended_query_info.parse_param_types.clear();
}
void PgSQL_Query_Info::init(unsigned char *_p, int len, bool header) {
@ -2811,7 +2812,7 @@ int PgSQL_Session::RunQuery(PgSQL_Data_Stream* myds, PgSQL_Connection* myconn) {
// bind_waiting_for_execute in case the client sends a sequence like
// Bind/Describe/Execute/Describe/Sync, so that a subsequent Describe Portal
// does not incorrectly assume a pending Bind.
if (rc != 1 && type == PGSQL_EXTENDED_QUERY_TYPE_EXECUTE) {
if (rc == 0 && type == PGSQL_EXTENDED_QUERY_TYPE_EXECUTE) {
bind_waiting_for_execute.reset(nullptr);
}
}
@ -3122,6 +3123,9 @@ handler_again:
if (CurrentQuery.extended_query_info.stmt_info->first_comment) {
CurrentQuery.QueryParserArgs.first_comment = strdup(CurrentQuery.extended_query_info.stmt_info->first_comment);
}
if (CurrentQuery.extended_query_info.stmt_info->parse_param_types.empty() == false) {
CurrentQuery.extended_query_info.parse_param_types = CurrentQuery.extended_query_info.stmt_info->parse_param_types;
}
if (CurrentQuery.extended_query_info.stmt_global_id != CurrentQuery.extended_query_info.stmt_info->statement_id) {
PROXY_TRACE();
assert(0);
@ -5902,6 +5906,19 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg
(begint.tv_sec * 1000000000 + begint.tv_nsec);
}
assert(qpo); // GloPgQPro->process_mysql_query() should always return a qpo
if (parse_data.num_param_types > 0) {
Parse_Param_Types parse_param_type;
parse_param_type.resize(parse_data.num_param_types);
auto param_type_reader = parse_msg->get_param_types_reader(); // get the reader for the param types
for (uint16_t i = 0; i < parse_data.num_param_types; ++i) {
if (!param_type_reader.next(&parse_param_type[i])) {
proxy_error("Failed to read result format at index %u\n", i);
return 2;
}
}
CurrentQuery.extended_query_info.parse_param_types = std::move(parse_param_type);
}
auto parse_pkt = parse_msg->detach(); // detach the packet from the parse message
@ -5971,7 +5988,8 @@ int PgSQL_Session::handle_post_sync_parse_message(PgSQL_Parse_Message* parse_msg
client_myds->myconn->userinfo->username,
client_myds->myconn->userinfo->dbname,
(char*)CurrentQuery.QueryPointer,
CurrentQuery.QueryLength
CurrentQuery.QueryLength,
CurrentQuery.extended_query_info.parse_param_types
);
// Check global statement cache
@ -6575,6 +6593,7 @@ bool PgSQL_Session::handler___rc0_PROCESSING_STMT_PREPARE(enum session_status& s
(char*)CurrentQuery.QueryPointer,
CurrentQuery.QueryLength,
CurrentQuery.QueryParserArgs.first_comment,
std::move(CurrentQuery.extended_query_info.parse_param_types),
false);
assert(stmt_info); // GloPgStmt->add_prepared_statement() should always return a valid pointer
if (CurrentQuery.QueryParserArgs.digest_text) {

Loading…
Cancel
Save