Add parameterized PID support for pg_cancel_backend/pg_terminate_backend

This commit extends the existing pg_cancel_backend() and pg_terminate_backend()
support to work with parameterized queries in the extended query protocol.
While literal PID values were already supported in both simple and extended
query protocols, this enhancement adds support for parameterized queries like
SELECT pg_cancel_backend($1).
pull/5299/head
Rahim Kanji 4 months ago
parent 7e9e00997d
commit a1e10e3055

@ -20,6 +20,7 @@ class PgSQL_Describe_Message;
class PgSQL_Close_Message;
class PgSQL_Bind_Message;
class PgSQL_Execute_Message;
struct PgSQL_Param_Value;
#ifndef PROXYJSON
#define PROXYJSON
@ -580,6 +581,7 @@ public:
void Memory_Stats();
void create_new_session_and_reset_connection(PgSQL_Data_Stream* _myds) override;
bool handle_command_query_kill(PtrSize_t*);
//void update_expired_conns(const std::vector<std::function<bool(PgSQL_Connection*)>>&);
/**
* @brief Performs the final operations after current query has finished to be executed. It updates the session
@ -603,6 +605,12 @@ public:
void set_previous_status_mode3(bool allow_execute = true);
char* get_current_query(int max_length = -1);
private:
int32_t extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const;
void send_parameter_error_response(const char* error_message);
bool handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt);
bool handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc);
#if defined(__clang__)
template<typename SESS, typename DS, typename BE, typename THD>
friend class Base_Session;

@ -436,6 +436,31 @@ inline T overflow_safe_multiply(T val) {
return (val * FACTOR);
}
/**
* @brief Read a 64-bit unsigned integer from a big-endian byte buffer.
*
* Reads 8 bytes from the provided buffer and converts them from
* big-endian (network byte order) into host byte order.
*
* @param pkt Pointer to at least 8 bytes of input data.
* @param dst_p Pointer to the destination uint64_t where the result
* will be stored.
*
* @return true Always returns true.
*/
inline bool get_uint64be(const unsigned char* pkt, uint64_t* dst_p) {
*dst_p =
((uint64_t)pkt[0] << 56) |
((uint64_t)pkt[1] << 48) |
((uint64_t)pkt[2] << 40) |
((uint64_t)pkt[3] << 32) |
((uint64_t)pkt[4] << 24) |
((uint64_t)pkt[5] << 16) |
((uint64_t)pkt[6] << 8) |
((uint64_t)pkt[7]);
return true;
}
/*
* @brief Reads and converts a big endian 32-bit unsigned integer from the provided packet buffer into the destination pointer.
*
@ -448,9 +473,9 @@ inline T overflow_safe_multiply(T val) {
*/
inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) {
*dst_p = ((uint32_t)pkt[0] << 24) |
((uint32_t)pkt[1] << 16) |
((uint32_t)pkt[2] << 8) |
((uint32_t)pkt[3]);
((uint32_t)pkt[1] << 16) |
((uint32_t)pkt[2] << 8) |
((uint32_t)pkt[3]);
return true;
}

@ -4425,11 +4425,10 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_Q
}
// Handle KILL command
//if (prepared == false) {
if (handle_command_query_kill(pkt)) {
return true;
}
//
// Query cache handling
if (qpo->cache_ttl > 0 && stmt_type == PGSQL_EXTENDED_QUERY_TYPE_NOT_SET) {
const std::shared_ptr<PgSQL_QC_entry_t> pgsql_qc_entry = GloPgQC->get(
@ -5171,55 +5170,284 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) {
if (!CurrentQuery.QueryParserArgs.digest_text)
return false;
if (client_myds && client_myds->myconn) {
PgSQL_Connection* mc = client_myds->myconn;
if (mc->userinfo && mc->userinfo->username) {
if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND ||
CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) {
char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength,
pgsql_thread___query_digests_lowercase);
string nq = string(qu, strlen(qu));
re2::RE2::Options* opt2 = new re2::RE2::Options(RE2::Quiet);
opt2->set_case_sensitive(false);
char* pattern = (char*)"^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$";
re2::RE2* re = new RE2(pattern, *opt2);
string tk;
int id = 0;
RE2::FullMatch(nq, *re, &tk, &id);
delete re;
delete opt2;
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu);
free(qu);
if (id) {
int tki = -1;
// Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match)
if (strcasecmp(tk.c_str(), "TERMINATE") == 0) {
tki = 0; // Connection terminate
} else if (strcasecmp(tk.c_str(), "CANCEL") == 0) {
tki = 1; // Query cancel
}
if (tki >= 0) {
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n", (tki == 0 ? "CONNECTION" : "QUERY"), id);
GloPTH->kill_connection_or_query(id, 0, mc->userinfo->username, (tki == 0 ? false : true));
client_myds->DSS = STATE_QUERY_SENT_NET;
std::unique_ptr<SQLite3_result> resultset = std::make_unique<SQLite3_result>(1);
resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend");
char* pta[1];
pta[0] = (char*)"t";
resultset->add_row(pta);
bool send_ready_packet = is_extended_query_ready_for_query();
unsigned int nTxn = NumActiveTransactions();
char txn_state = (nTxn ? 'T' : 'I');
SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, (const char*)pkt->ptr + 5, send_ready_packet, txn_state);
if (!client_myds ||
!client_myds->myconn ||
!client_myds->myconn->userinfo ||
!client_myds->myconn->userinfo->username) {
return false;
}
RequestEnd(NULL, false);
PgSQL_Connection* mc = client_myds->myconn;
if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND ||
CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) {
if (cmd == 'Q') {
// Simple query protocol - only handle literal values
// Parameterized queries in simple protocol are invalid and will be handled by PostgreSQL
return handle_literal_kill_query(pkt, mc);
} else {
// cmd == 'E' - Execute phase of extended query protocol
// Check if this is a parameterized query (contains $1)
// Note: This simple check might have false positives if $1 appears in comments or string literals
// but those cases would fail later when checking bind_msg or parameter validation
const char* digest_text = CurrentQuery.QueryParserArgs.digest_text;
bool is_parameterized = strstr(digest_text, "$1") != nullptr;
if (is_parameterized) {
// Check if there are multiple parameters (e.g., $1, $2)
// Look for $2, $3, etc. to reject multiple parameters
const char* p = digest_text;
int max_param = 0;
while ((p = strstr(p, "$")) != nullptr) {
p++; // Skip '$'
if (isdigit(*p)) {
int param_num = atoi(p);
if (param_num > max_param) max_param = param_num;
}
p++;
}
if (max_param > 1) {
// Multiple parameters not supported
send_parameter_error_response("function requires exactly one parameter");
l_free(pkt->size, pkt->ptr);
return true;
}
// Handle parameterized query
if (CurrentQuery.extended_query_info.bind_msg) {
const PgSQL_Bind_Message* bind_msg = CurrentQuery.extended_query_info.bind_msg;
auto param_reader = bind_msg->get_param_value_reader();
PgSQL_Param_Value param;
// Check that we have exactly one parameter
if (bind_msg->data().num_param_values != 1) {
send_parameter_error_response("function requires exactly one parameter");
l_free(pkt->size, pkt->ptr);
return true;
}
if (param_reader.next(&param)) {
// Get parameter format (default to text format 0)
uint16_t param_format = 0;
if (bind_msg->data().num_param_formats == 1) {
// Single format applies to all parameters
auto format_reader = bind_msg->get_param_format_reader();
format_reader.next(&param_format);
}
// Extract PID from parameter
int32_t pid = extract_pid_from_param(param, param_format);
if (pid > 0) {
// Determine if this is terminate or cancel
int tki = -1;
if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) {
tki = 0; // Connection terminate
} else if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND) {
tki = 1; // Query cancel
}
if (tki >= 0) {
return handle_kill_success(pid, tki, digest_text, mc, pkt);
}
} else {
// Invalid parameter - send appropriate error response
if (pid == -2) {
// NULL parameter
send_parameter_error_response("NULL is not allowed");
} else if (pid == -1) {
// Invalid format (not a valid integer)
send_parameter_error_response("invalid input syntax for integer");
} else if (pid == 0) {
// PID <= 0 (non-positive)
send_parameter_error_response("PID must be a positive integer");
}
l_free(pkt->size, pkt->ptr);
return true;
}
} else {
// No parameter available - this shouldn't happen
return false;
}
} else {
// No bind message available (shouldn't happen for Execute phase)
return false;
}
} else {
// Literal query in extended protocol
return handle_literal_kill_query(pkt, mc);
}
}
}
return false;
}
int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const {
if (param.len == -1) {
// NULL parameter
return -2; // Special value for NULL
}
/* ---------------- TEXT FORMAT ---------------- */
if (format == 0) {
// Text format
if (param.len == 0) {
// Empty string
return -1;
}
// Convert text to integer
std::string_view str_val(reinterpret_cast<const char*>(param.value), param.len);
// Validate that the string contains only digits
for (size_t i = 0; i < str_val.size(); i++) {
if (!isdigit(str_val[i])) {
return -1;
}
}
// Parse the integer
char* endptr;
long pid = strtol(str_val.data(), &endptr, 10);
// Check for conversion errors
if (endptr != str_val.data() + str_val.size()) {
return -1;
}
// Check valid range
if (pid <= 0) {
return 0; // Special value for non-positive
}
if (pid > INT_MAX) {
return -1; // Out of range
}
return static_cast<int32_t>(pid);
}
/* ---------------- BINARY FORMAT ---------------- */
// PostgreSQL sends int4 or int8 for integer parameters
if (format == 1) { // Binary format (format == 1)
if (param.len == 4) {
// uint32 in network byte order
uint32_t host_u32;
get_uint32be(reinterpret_cast<const unsigned char*>(param.value), &host_u32);
int32_t pid = static_cast<int32_t>(host_u32);
// Validate positive PID
if (pid <= 0) {
return 0; // Special value for non-positive
}
return pid;
}
if (param.len == 8) {
// int64 in network byte order (PostgreSQL sends int8 for some integer types)
uint64_t host_u64 = 0;
get_uint64be(reinterpret_cast<const unsigned char*>(param.value), &host_u64);
int64_t pid = static_cast<int64_t>(host_u64);
// Validate positive PID and within int32 range
if (pid <= 0) {
return 0; // Special value for non-positive
}
if (pid > INT_MAX) {
return -1; // Out of range
}
return static_cast<int32_t>(pid);
}
// Invalid integer width for Bind
return -1;
}
char buf[INET6_ADDRSTRLEN];
switch (client_myds->client_addr->sa_family) {
case AF_INET: {
struct sockaddr_in* ipv4 = (struct sockaddr_in*)client_myds->client_addr;
inet_ntop(client_myds->client_addr->sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN);
break;
}
case AF_INET6: {
struct sockaddr_in6* ipv6 = (struct sockaddr_in6*)client_myds->client_addr;
inet_ntop(client_myds->client_addr->sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN);
break;
}
default:
sprintf(buf, "localhost");
break;
}
// Unknown format code
proxy_error("Unknown parameter format code: %u from client %s", format, buf);
return -1;
}
void PgSQL_Session::send_parameter_error_response(const char* error_message) {
if (!client_myds) return;
// Create proper PostgreSQL error message
std::string full_error = std::string("invalid input syntax for integer: \"") +
(error_message ? error_message : "parameter error") + "\"";
client_myds->setDSS_STATE_QUERY_SENT_NET();
// Generate and send error packet using PostgreSQL protocol
client_myds->myprot.generate_error_packet(true, is_extended_query_ready_for_query(),
full_error.c_str(), PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION, false, true);
RequestEnd(NULL, true);
}
bool PgSQL_Session::handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt) {
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n",
(tki == 0 ? "CONNECTION" : "QUERY"), pid);
GloPTH->kill_connection_or_query(pid, 0, mc->userinfo->username, (tki == 0 ? false : true));
client_myds->DSS = STATE_QUERY_SENT_NET;
std::unique_ptr<SQLite3_result> resultset = std::make_unique<SQLite3_result>(1);
resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend");
char* pta[1];
pta[0] = (char*)"t";
resultset->add_row(pta);
bool send_ready_packet = is_extended_query_ready_for_query();
unsigned int nTxn = NumActiveTransactions();
char txn_state = (nTxn ? 'T' : 'I');
SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, digest_text, send_ready_packet, txn_state);
RequestEnd(NULL, false);
l_free(pkt->size, pkt->ptr);
return true;
}
bool PgSQL_Session::handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc) {
// Handle literal query (original implementation)
char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength,
pgsql_thread___query_digests_lowercase);
std::string nq { qu, strlen(qu) };
re2::RE2::Options opt2(RE2::Quiet);
opt2.set_case_sensitive(false);
const char* pattern = "^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$";
re2::RE2 re(pattern, opt2);
std::string tk;
uint32_t id = 0;
RE2::FullMatch(nq, re, &tk, &id);
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu);
free(qu);
if (id > 0) {
int tki = -1;
// Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match)
if (strcasecmp(tk.c_str(), "TERMINATE") == 0) {
tki = 0; // Connection terminate
} else if (strcasecmp(tk.c_str(), "CANCEL") == 0) {
tki = 1; // Query cancel
}
if (tki >= 0) {
return handle_kill_success(id, tki, CurrentQuery.QueryParserArgs.digest_text, mc, pkt);
}
}
return false;
@ -6124,6 +6352,17 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu
// if we are here, it means we have handled the special command
return 0;
}
PGSQL_QUERY_command pg_query_cmd = extended_query_info.stmt_info->PgQueryCmd;
if (pg_query_cmd == PGSQL_QUERY_CANCEL_BACKEND ||
pg_query_cmd == PGSQL_QUERY_TERMINATE_BACKEND) {
CurrentQuery.PgQueryCmd = pg_query_cmd;
auto execute_pkt = execute_msg->get_raw_pkt(); // detach the packet from the describe message
if (handle_command_query_kill(&execute_pkt)) {
execute_msg->detach(); // detach the packet from the execute message
return 0;
}
}
}
current_hostgroup = previous_hostgroup; // reset current hostgroup to previous hostgroup
proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Session=%p client_myds=%p. Using previous hostgroup '%d'\n",

Loading…
Cancel
Save