From a1e10e30558e40863bbaf343a18b879a3cc640e1 Mon Sep 17 00:00:00 2001 From: Rahim Kanji Date: Wed, 14 Jan 2026 14:11:34 +0500 Subject: [PATCH] Add parameterized PID support for pg_cancel_backend/pg_terminate_backend This commit extends the existing pg_cancel_backend() and pg_terminate_backend() support to work with parameterized queries in the extended query protocol. While literal PID values were already supported in both simple and extended query protocols, this enhancement adds support for parameterized queries like SELECT pg_cancel_backend($1). --- include/PgSQL_Session.h | 8 + include/gen_utils.h | 31 +++- lib/PgSQL_Session.cpp | 329 ++++++++++++++++++++++++++++++++++------ 3 files changed, 320 insertions(+), 48 deletions(-) diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 098ddd14b..39ec44d48 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -20,6 +20,7 @@ class PgSQL_Describe_Message; class PgSQL_Close_Message; class PgSQL_Bind_Message; class PgSQL_Execute_Message; +struct PgSQL_Param_Value; #ifndef PROXYJSON #define PROXYJSON @@ -580,6 +581,7 @@ public: void Memory_Stats(); void create_new_session_and_reset_connection(PgSQL_Data_Stream* _myds) override; bool handle_command_query_kill(PtrSize_t*); + //void update_expired_conns(const std::vector>&); /** * @brief Performs the final operations after current query has finished to be executed. It updates the session @@ -603,6 +605,12 @@ public: void set_previous_status_mode3(bool allow_execute = true); char* get_current_query(int max_length = -1); +private: + int32_t extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const; + void send_parameter_error_response(const char* error_message); + bool handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt); + bool handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc); + #if defined(__clang__) template friend class Base_Session; diff --git a/include/gen_utils.h b/include/gen_utils.h index 34c260531..8556fd468 100644 --- a/include/gen_utils.h +++ b/include/gen_utils.h @@ -436,6 +436,31 @@ inline T overflow_safe_multiply(T val) { return (val * FACTOR); } +/** + * @brief Read a 64-bit unsigned integer from a big-endian byte buffer. + * + * Reads 8 bytes from the provided buffer and converts them from + * big-endian (network byte order) into host byte order. + * + * @param pkt Pointer to at least 8 bytes of input data. + * @param dst_p Pointer to the destination uint64_t where the result + * will be stored. + * + * @return true Always returns true. + */ +inline bool get_uint64be(const unsigned char* pkt, uint64_t* dst_p) { + *dst_p = + ((uint64_t)pkt[0] << 56) | + ((uint64_t)pkt[1] << 48) | + ((uint64_t)pkt[2] << 40) | + ((uint64_t)pkt[3] << 32) | + ((uint64_t)pkt[4] << 24) | + ((uint64_t)pkt[5] << 16) | + ((uint64_t)pkt[6] << 8) | + ((uint64_t)pkt[7]); + return true; +} + /* * @brief Reads and converts a big endian 32-bit unsigned integer from the provided packet buffer into the destination pointer. * @@ -448,9 +473,9 @@ inline T overflow_safe_multiply(T val) { */ inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) { *dst_p = ((uint32_t)pkt[0] << 24) | - ((uint32_t)pkt[1] << 16) | - ((uint32_t)pkt[2] << 8) | - ((uint32_t)pkt[3]); + ((uint32_t)pkt[1] << 16) | + ((uint32_t)pkt[2] << 8) | + ((uint32_t)pkt[3]); return true; } diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 919bfd882..36bebfae6 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -4425,11 +4425,10 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_Q } // Handle KILL command - //if (prepared == false) { if (handle_command_query_kill(pkt)) { return true; } - // + // Query cache handling if (qpo->cache_ttl > 0 && stmt_type == PGSQL_EXTENDED_QUERY_TYPE_NOT_SET) { const std::shared_ptr pgsql_qc_entry = GloPgQC->get( @@ -5171,55 +5170,284 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) { if (!CurrentQuery.QueryParserArgs.digest_text) return false; - if (client_myds && client_myds->myconn) { - PgSQL_Connection* mc = client_myds->myconn; - if (mc->userinfo && mc->userinfo->username) { - if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND || - CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) { - char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, - pgsql_thread___query_digests_lowercase); - string nq = string(qu, strlen(qu)); - re2::RE2::Options* opt2 = new re2::RE2::Options(RE2::Quiet); - opt2->set_case_sensitive(false); - char* pattern = (char*)"^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$"; - re2::RE2* re = new RE2(pattern, *opt2); - string tk; - int id = 0; - RE2::FullMatch(nq, *re, &tk, &id); - delete re; - delete opt2; - proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu); - free(qu); - - if (id) { - int tki = -1; - // Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match) - if (strcasecmp(tk.c_str(), "TERMINATE") == 0) { - tki = 0; // Connection terminate - } else if (strcasecmp(tk.c_str(), "CANCEL") == 0) { - tki = 1; // Query cancel - } - if (tki >= 0) { - proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n", (tki == 0 ? "CONNECTION" : "QUERY"), id); - GloPTH->kill_connection_or_query(id, 0, mc->userinfo->username, (tki == 0 ? false : true)); - client_myds->DSS = STATE_QUERY_SENT_NET; - - std::unique_ptr resultset = std::make_unique(1); - resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend"); - char* pta[1]; - pta[0] = (char*)"t"; - resultset->add_row(pta); - bool send_ready_packet = is_extended_query_ready_for_query(); - unsigned int nTxn = NumActiveTransactions(); - char txn_state = (nTxn ? 'T' : 'I'); - SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, (const char*)pkt->ptr + 5, send_ready_packet, txn_state); + if (!client_myds || + !client_myds->myconn || + !client_myds->myconn->userinfo || + !client_myds->myconn->userinfo->username) { + return false; + } - RequestEnd(NULL, false); + PgSQL_Connection* mc = client_myds->myconn; + if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND || + CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) { + + if (cmd == 'Q') { + // Simple query protocol - only handle literal values + // Parameterized queries in simple protocol are invalid and will be handled by PostgreSQL + return handle_literal_kill_query(pkt, mc); + } else { + // cmd == 'E' - Execute phase of extended query protocol + // Check if this is a parameterized query (contains $1) + // Note: This simple check might have false positives if $1 appears in comments or string literals + // but those cases would fail later when checking bind_msg or parameter validation + const char* digest_text = CurrentQuery.QueryParserArgs.digest_text; + bool is_parameterized = strstr(digest_text, "$1") != nullptr; + if (is_parameterized) { + // Check if there are multiple parameters (e.g., $1, $2) + // Look for $2, $3, etc. to reject multiple parameters + const char* p = digest_text; + int max_param = 0; + while ((p = strstr(p, "$")) != nullptr) { + p++; // Skip '$' + if (isdigit(*p)) { + int param_num = atoi(p); + if (param_num > max_param) max_param = param_num; + } + p++; + } + if (max_param > 1) { + // Multiple parameters not supported + send_parameter_error_response("function requires exactly one parameter"); + l_free(pkt->size, pkt->ptr); + return true; + } + + // Handle parameterized query + if (CurrentQuery.extended_query_info.bind_msg) { + const PgSQL_Bind_Message* bind_msg = CurrentQuery.extended_query_info.bind_msg; + auto param_reader = bind_msg->get_param_value_reader(); + PgSQL_Param_Value param; + + // Check that we have exactly one parameter + if (bind_msg->data().num_param_values != 1) { + send_parameter_error_response("function requires exactly one parameter"); l_free(pkt->size, pkt->ptr); return true; } + + if (param_reader.next(¶m)) { + // Get parameter format (default to text format 0) + uint16_t param_format = 0; + if (bind_msg->data().num_param_formats == 1) { + // Single format applies to all parameters + auto format_reader = bind_msg->get_param_format_reader(); + format_reader.next(¶m_format); + } + + // Extract PID from parameter + int32_t pid = extract_pid_from_param(param, param_format); + if (pid > 0) { + // Determine if this is terminate or cancel + int tki = -1; + if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) { + tki = 0; // Connection terminate + } else if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND) { + tki = 1; // Query cancel + } + + if (tki >= 0) { + return handle_kill_success(pid, tki, digest_text, mc, pkt); + } + } else { + // Invalid parameter - send appropriate error response + if (pid == -2) { + // NULL parameter + send_parameter_error_response("NULL is not allowed"); + } else if (pid == -1) { + // Invalid format (not a valid integer) + send_parameter_error_response("invalid input syntax for integer"); + } else if (pid == 0) { + // PID <= 0 (non-positive) + send_parameter_error_response("PID must be a positive integer"); + } + l_free(pkt->size, pkt->ptr); + return true; + } + } else { + // No parameter available - this shouldn't happen + return false; + } + } else { + // No bind message available (shouldn't happen for Execute phase) + return false; } + } else { + // Literal query in extended protocol + return handle_literal_kill_query(pkt, mc); + } + } + } + + return false; +} + +int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const { + + if (param.len == -1) { + // NULL parameter + return -2; // Special value for NULL + } + + /* ---------------- TEXT FORMAT ---------------- */ + if (format == 0) { + // Text format + if (param.len == 0) { + // Empty string + return -1; + } + + // Convert text to integer + std::string_view str_val(reinterpret_cast(param.value), param.len); + + // Validate that the string contains only digits + for (size_t i = 0; i < str_val.size(); i++) { + if (!isdigit(str_val[i])) { + return -1; + } + } + + // Parse the integer + char* endptr; + long pid = strtol(str_val.data(), &endptr, 10); + + // Check for conversion errors + if (endptr != str_val.data() + str_val.size()) { + return -1; + } + + // Check valid range + if (pid <= 0) { + return 0; // Special value for non-positive + } + if (pid > INT_MAX) { + return -1; // Out of range + } + + return static_cast(pid); + } + + /* ---------------- BINARY FORMAT ---------------- */ + // PostgreSQL sends int4 or int8 for integer parameters + if (format == 1) { // Binary format (format == 1) + + if (param.len == 4) { + // uint32 in network byte order + uint32_t host_u32; + get_uint32be(reinterpret_cast(param.value), &host_u32); + int32_t pid = static_cast(host_u32); + + // Validate positive PID + if (pid <= 0) { + return 0; // Special value for non-positive } + return pid; + } + + if (param.len == 8) { + // int64 in network byte order (PostgreSQL sends int8 for some integer types) + uint64_t host_u64 = 0; + get_uint64be(reinterpret_cast(param.value), &host_u64); + int64_t pid = static_cast(host_u64); + + // Validate positive PID and within int32 range + if (pid <= 0) { + return 0; // Special value for non-positive + } + if (pid > INT_MAX) { + return -1; // Out of range + } + return static_cast(pid); + } + + // Invalid integer width for Bind + return -1; + } + + char buf[INET6_ADDRSTRLEN]; + switch (client_myds->client_addr->sa_family) { + case AF_INET: { + struct sockaddr_in* ipv4 = (struct sockaddr_in*)client_myds->client_addr; + inet_ntop(client_myds->client_addr->sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN); + break; + } + case AF_INET6: { + struct sockaddr_in6* ipv6 = (struct sockaddr_in6*)client_myds->client_addr; + inet_ntop(client_myds->client_addr->sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN); + break; + } + default: + sprintf(buf, "localhost"); + break; + } + + // Unknown format code + proxy_error("Unknown parameter format code: %u from client %s", format, buf); + return -1; +} + +void PgSQL_Session::send_parameter_error_response(const char* error_message) { + if (!client_myds) return; + + // Create proper PostgreSQL error message + std::string full_error = std::string("invalid input syntax for integer: \"") + + (error_message ? error_message : "parameter error") + "\""; + client_myds->setDSS_STATE_QUERY_SENT_NET(); + // Generate and send error packet using PostgreSQL protocol + client_myds->myprot.generate_error_packet(true, is_extended_query_ready_for_query(), + full_error.c_str(), PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION, false, true); + + RequestEnd(NULL, true); +} + +bool PgSQL_Session::handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt) { + + proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n", + (tki == 0 ? "CONNECTION" : "QUERY"), pid); + GloPTH->kill_connection_or_query(pid, 0, mc->userinfo->username, (tki == 0 ? false : true)); + client_myds->DSS = STATE_QUERY_SENT_NET; + + std::unique_ptr resultset = std::make_unique(1); + resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend"); + char* pta[1]; + pta[0] = (char*)"t"; + resultset->add_row(pta); + bool send_ready_packet = is_extended_query_ready_for_query(); + unsigned int nTxn = NumActiveTransactions(); + char txn_state = (nTxn ? 'T' : 'I'); + SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, digest_text, send_ready_packet, txn_state); + + RequestEnd(NULL, false); + l_free(pkt->size, pkt->ptr); + return true; +} + +bool PgSQL_Session::handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc) { + // Handle literal query (original implementation) + char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, + pgsql_thread___query_digests_lowercase); + std::string nq { qu, strlen(qu) }; + + re2::RE2::Options opt2(RE2::Quiet); + opt2.set_case_sensitive(false); + const char* pattern = "^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$"; + re2::RE2 re(pattern, opt2); + std::string tk; + uint32_t id = 0; + RE2::FullMatch(nq, re, &tk, &id); + + proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu); + free(qu); + + if (id > 0) { + int tki = -1; + // Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match) + if (strcasecmp(tk.c_str(), "TERMINATE") == 0) { + tki = 0; // Connection terminate + } else if (strcasecmp(tk.c_str(), "CANCEL") == 0) { + tki = 1; // Query cancel + } + if (tki >= 0) { + return handle_kill_success(id, tki, CurrentQuery.QueryParserArgs.digest_text, mc, pkt); } } return false; @@ -6124,6 +6352,17 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu // if we are here, it means we have handled the special command return 0; } + + PGSQL_QUERY_command pg_query_cmd = extended_query_info.stmt_info->PgQueryCmd; + if (pg_query_cmd == PGSQL_QUERY_CANCEL_BACKEND || + pg_query_cmd == PGSQL_QUERY_TERMINATE_BACKEND) { + CurrentQuery.PgQueryCmd = pg_query_cmd; + auto execute_pkt = execute_msg->get_raw_pkt(); // detach the packet from the describe message + if (handle_command_query_kill(&execute_pkt)) { + execute_msg->detach(); // detach the packet from the execute message + return 0; + } + } } current_hostgroup = previous_hostgroup; // reset current hostgroup to previous hostgroup proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Session=%p client_myds=%p. Using previous hostgroup '%d'\n",