diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 098ddd14b..39ec44d48 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -20,6 +20,7 @@ class PgSQL_Describe_Message; class PgSQL_Close_Message; class PgSQL_Bind_Message; class PgSQL_Execute_Message; +struct PgSQL_Param_Value; #ifndef PROXYJSON #define PROXYJSON @@ -580,6 +581,7 @@ public: void Memory_Stats(); void create_new_session_and_reset_connection(PgSQL_Data_Stream* _myds) override; bool handle_command_query_kill(PtrSize_t*); + //void update_expired_conns(const std::vector>&); /** * @brief Performs the final operations after current query has finished to be executed. It updates the session @@ -603,6 +605,12 @@ public: void set_previous_status_mode3(bool allow_execute = true); char* get_current_query(int max_length = -1); +private: + int32_t extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const; + void send_parameter_error_response(const char* error_message); + bool handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt); + bool handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc); + #if defined(__clang__) template friend class Base_Session; diff --git a/include/gen_utils.h b/include/gen_utils.h index 34c260531..8556fd468 100644 --- a/include/gen_utils.h +++ b/include/gen_utils.h @@ -436,6 +436,31 @@ inline T overflow_safe_multiply(T val) { return (val * FACTOR); } +/** + * @brief Read a 64-bit unsigned integer from a big-endian byte buffer. + * + * Reads 8 bytes from the provided buffer and converts them from + * big-endian (network byte order) into host byte order. + * + * @param pkt Pointer to at least 8 bytes of input data. + * @param dst_p Pointer to the destination uint64_t where the result + * will be stored. + * + * @return true Always returns true. + */ +inline bool get_uint64be(const unsigned char* pkt, uint64_t* dst_p) { + *dst_p = + ((uint64_t)pkt[0] << 56) | + ((uint64_t)pkt[1] << 48) | + ((uint64_t)pkt[2] << 40) | + ((uint64_t)pkt[3] << 32) | + ((uint64_t)pkt[4] << 24) | + ((uint64_t)pkt[5] << 16) | + ((uint64_t)pkt[6] << 8) | + ((uint64_t)pkt[7]); + return true; +} + /* * @brief Reads and converts a big endian 32-bit unsigned integer from the provided packet buffer into the destination pointer. * @@ -448,9 +473,9 @@ inline T overflow_safe_multiply(T val) { */ inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) { *dst_p = ((uint32_t)pkt[0] << 24) | - ((uint32_t)pkt[1] << 16) | - ((uint32_t)pkt[2] << 8) | - ((uint32_t)pkt[3]); + ((uint32_t)pkt[1] << 16) | + ((uint32_t)pkt[2] << 8) | + ((uint32_t)pkt[3]); return true; } diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 919bfd882..36bebfae6 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -4425,11 +4425,10 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_Q } // Handle KILL command - //if (prepared == false) { if (handle_command_query_kill(pkt)) { return true; } - // + // Query cache handling if (qpo->cache_ttl > 0 && stmt_type == PGSQL_EXTENDED_QUERY_TYPE_NOT_SET) { const std::shared_ptr pgsql_qc_entry = GloPgQC->get( @@ -5171,55 +5170,284 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) { if (!CurrentQuery.QueryParserArgs.digest_text) return false; - if (client_myds && client_myds->myconn) { - PgSQL_Connection* mc = client_myds->myconn; - if (mc->userinfo && mc->userinfo->username) { - if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND || - CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) { - char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, - pgsql_thread___query_digests_lowercase); - string nq = string(qu, strlen(qu)); - re2::RE2::Options* opt2 = new re2::RE2::Options(RE2::Quiet); - opt2->set_case_sensitive(false); - char* pattern = (char*)"^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$"; - re2::RE2* re = new RE2(pattern, *opt2); - string tk; - int id = 0; - RE2::FullMatch(nq, *re, &tk, &id); - delete re; - delete opt2; - proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu); - free(qu); - - if (id) { - int tki = -1; - // Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match) - if (strcasecmp(tk.c_str(), "TERMINATE") == 0) { - tki = 0; // Connection terminate - } else if (strcasecmp(tk.c_str(), "CANCEL") == 0) { - tki = 1; // Query cancel - } - if (tki >= 0) { - proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n", (tki == 0 ? "CONNECTION" : "QUERY"), id); - GloPTH->kill_connection_or_query(id, 0, mc->userinfo->username, (tki == 0 ? false : true)); - client_myds->DSS = STATE_QUERY_SENT_NET; - - std::unique_ptr resultset = std::make_unique(1); - resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend"); - char* pta[1]; - pta[0] = (char*)"t"; - resultset->add_row(pta); - bool send_ready_packet = is_extended_query_ready_for_query(); - unsigned int nTxn = NumActiveTransactions(); - char txn_state = (nTxn ? 'T' : 'I'); - SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, (const char*)pkt->ptr + 5, send_ready_packet, txn_state); + if (!client_myds || + !client_myds->myconn || + !client_myds->myconn->userinfo || + !client_myds->myconn->userinfo->username) { + return false; + } - RequestEnd(NULL, false); + PgSQL_Connection* mc = client_myds->myconn; + if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND || + CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) { + + if (cmd == 'Q') { + // Simple query protocol - only handle literal values + // Parameterized queries in simple protocol are invalid and will be handled by PostgreSQL + return handle_literal_kill_query(pkt, mc); + } else { + // cmd == 'E' - Execute phase of extended query protocol + // Check if this is a parameterized query (contains $1) + // Note: This simple check might have false positives if $1 appears in comments or string literals + // but those cases would fail later when checking bind_msg or parameter validation + const char* digest_text = CurrentQuery.QueryParserArgs.digest_text; + bool is_parameterized = strstr(digest_text, "$1") != nullptr; + if (is_parameterized) { + // Check if there are multiple parameters (e.g., $1, $2) + // Look for $2, $3, etc. to reject multiple parameters + const char* p = digest_text; + int max_param = 0; + while ((p = strstr(p, "$")) != nullptr) { + p++; // Skip '$' + if (isdigit(*p)) { + int param_num = atoi(p); + if (param_num > max_param) max_param = param_num; + } + p++; + } + if (max_param > 1) { + // Multiple parameters not supported + send_parameter_error_response("function requires exactly one parameter"); + l_free(pkt->size, pkt->ptr); + return true; + } + + // Handle parameterized query + if (CurrentQuery.extended_query_info.bind_msg) { + const PgSQL_Bind_Message* bind_msg = CurrentQuery.extended_query_info.bind_msg; + auto param_reader = bind_msg->get_param_value_reader(); + PgSQL_Param_Value param; + + // Check that we have exactly one parameter + if (bind_msg->data().num_param_values != 1) { + send_parameter_error_response("function requires exactly one parameter"); l_free(pkt->size, pkt->ptr); return true; } + + if (param_reader.next(¶m)) { + // Get parameter format (default to text format 0) + uint16_t param_format = 0; + if (bind_msg->data().num_param_formats == 1) { + // Single format applies to all parameters + auto format_reader = bind_msg->get_param_format_reader(); + format_reader.next(¶m_format); + } + + // Extract PID from parameter + int32_t pid = extract_pid_from_param(param, param_format); + if (pid > 0) { + // Determine if this is terminate or cancel + int tki = -1; + if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) { + tki = 0; // Connection terminate + } else if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND) { + tki = 1; // Query cancel + } + + if (tki >= 0) { + return handle_kill_success(pid, tki, digest_text, mc, pkt); + } + } else { + // Invalid parameter - send appropriate error response + if (pid == -2) { + // NULL parameter + send_parameter_error_response("NULL is not allowed"); + } else if (pid == -1) { + // Invalid format (not a valid integer) + send_parameter_error_response("invalid input syntax for integer"); + } else if (pid == 0) { + // PID <= 0 (non-positive) + send_parameter_error_response("PID must be a positive integer"); + } + l_free(pkt->size, pkt->ptr); + return true; + } + } else { + // No parameter available - this shouldn't happen + return false; + } + } else { + // No bind message available (shouldn't happen for Execute phase) + return false; } + } else { + // Literal query in extended protocol + return handle_literal_kill_query(pkt, mc); + } + } + } + + return false; +} + +int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const { + + if (param.len == -1) { + // NULL parameter + return -2; // Special value for NULL + } + + /* ---------------- TEXT FORMAT ---------------- */ + if (format == 0) { + // Text format + if (param.len == 0) { + // Empty string + return -1; + } + + // Convert text to integer + std::string_view str_val(reinterpret_cast(param.value), param.len); + + // Validate that the string contains only digits + for (size_t i = 0; i < str_val.size(); i++) { + if (!isdigit(str_val[i])) { + return -1; + } + } + + // Parse the integer + char* endptr; + long pid = strtol(str_val.data(), &endptr, 10); + + // Check for conversion errors + if (endptr != str_val.data() + str_val.size()) { + return -1; + } + + // Check valid range + if (pid <= 0) { + return 0; // Special value for non-positive + } + if (pid > INT_MAX) { + return -1; // Out of range + } + + return static_cast(pid); + } + + /* ---------------- BINARY FORMAT ---------------- */ + // PostgreSQL sends int4 or int8 for integer parameters + if (format == 1) { // Binary format (format == 1) + + if (param.len == 4) { + // uint32 in network byte order + uint32_t host_u32; + get_uint32be(reinterpret_cast(param.value), &host_u32); + int32_t pid = static_cast(host_u32); + + // Validate positive PID + if (pid <= 0) { + return 0; // Special value for non-positive } + return pid; + } + + if (param.len == 8) { + // int64 in network byte order (PostgreSQL sends int8 for some integer types) + uint64_t host_u64 = 0; + get_uint64be(reinterpret_cast(param.value), &host_u64); + int64_t pid = static_cast(host_u64); + + // Validate positive PID and within int32 range + if (pid <= 0) { + return 0; // Special value for non-positive + } + if (pid > INT_MAX) { + return -1; // Out of range + } + return static_cast(pid); + } + + // Invalid integer width for Bind + return -1; + } + + char buf[INET6_ADDRSTRLEN]; + switch (client_myds->client_addr->sa_family) { + case AF_INET: { + struct sockaddr_in* ipv4 = (struct sockaddr_in*)client_myds->client_addr; + inet_ntop(client_myds->client_addr->sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN); + break; + } + case AF_INET6: { + struct sockaddr_in6* ipv6 = (struct sockaddr_in6*)client_myds->client_addr; + inet_ntop(client_myds->client_addr->sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN); + break; + } + default: + sprintf(buf, "localhost"); + break; + } + + // Unknown format code + proxy_error("Unknown parameter format code: %u from client %s", format, buf); + return -1; +} + +void PgSQL_Session::send_parameter_error_response(const char* error_message) { + if (!client_myds) return; + + // Create proper PostgreSQL error message + std::string full_error = std::string("invalid input syntax for integer: \"") + + (error_message ? error_message : "parameter error") + "\""; + client_myds->setDSS_STATE_QUERY_SENT_NET(); + // Generate and send error packet using PostgreSQL protocol + client_myds->myprot.generate_error_packet(true, is_extended_query_ready_for_query(), + full_error.c_str(), PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION, false, true); + + RequestEnd(NULL, true); +} + +bool PgSQL_Session::handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt) { + + proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n", + (tki == 0 ? "CONNECTION" : "QUERY"), pid); + GloPTH->kill_connection_or_query(pid, 0, mc->userinfo->username, (tki == 0 ? false : true)); + client_myds->DSS = STATE_QUERY_SENT_NET; + + std::unique_ptr resultset = std::make_unique(1); + resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend"); + char* pta[1]; + pta[0] = (char*)"t"; + resultset->add_row(pta); + bool send_ready_packet = is_extended_query_ready_for_query(); + unsigned int nTxn = NumActiveTransactions(); + char txn_state = (nTxn ? 'T' : 'I'); + SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, digest_text, send_ready_packet, txn_state); + + RequestEnd(NULL, false); + l_free(pkt->size, pkt->ptr); + return true; +} + +bool PgSQL_Session::handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc) { + // Handle literal query (original implementation) + char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, + pgsql_thread___query_digests_lowercase); + std::string nq { qu, strlen(qu) }; + + re2::RE2::Options opt2(RE2::Quiet); + opt2.set_case_sensitive(false); + const char* pattern = "^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$"; + re2::RE2 re(pattern, opt2); + std::string tk; + uint32_t id = 0; + RE2::FullMatch(nq, re, &tk, &id); + + proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu); + free(qu); + + if (id > 0) { + int tki = -1; + // Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match) + if (strcasecmp(tk.c_str(), "TERMINATE") == 0) { + tki = 0; // Connection terminate + } else if (strcasecmp(tk.c_str(), "CANCEL") == 0) { + tki = 1; // Query cancel + } + if (tki >= 0) { + return handle_kill_success(id, tki, CurrentQuery.QueryParserArgs.digest_text, mc, pkt); } } return false; @@ -6124,6 +6352,17 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu // if we are here, it means we have handled the special command return 0; } + + PGSQL_QUERY_command pg_query_cmd = extended_query_info.stmt_info->PgQueryCmd; + if (pg_query_cmd == PGSQL_QUERY_CANCEL_BACKEND || + pg_query_cmd == PGSQL_QUERY_TERMINATE_BACKEND) { + CurrentQuery.PgQueryCmd = pg_query_cmd; + auto execute_pkt = execute_msg->get_raw_pkt(); // detach the packet from the describe message + if (handle_command_query_kill(&execute_pkt)) { + execute_msg->detach(); // detach the packet from the execute message + return 0; + } + } } current_hostgroup = previous_hostgroup; // reset current hostgroup to previous hostgroup proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Session=%p client_myds=%p. Using previous hostgroup '%d'\n",