diff --git a/include/PgSQL_Session.h b/include/PgSQL_Session.h index 39ec44d48..967515fa5 100644 --- a/include/PgSQL_Session.h +++ b/include/PgSQL_Session.h @@ -607,7 +607,7 @@ public: private: int32_t extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const; - void send_parameter_error_response(const char* error_message); + void send_parameter_error_response(const char* error_message, PGSQL_ERROR_CODES code = PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION); bool handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt); bool handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc); diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 36bebfae6..aac4326df 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -5200,8 +5200,15 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) { while ((p = strstr(p, "$")) != nullptr) { p++; // Skip '$' if (isdigit(*p)) { - int param_num = atoi(p); - if (param_num > max_param) max_param = param_num; + char* end; + long param_num = strtol(p, &end, 10); + if (p != end) { // check if any digits were parsed + if (param_num > max_param) { + max_param = static_cast(param_num); + } + p = end; + continue; + } } p++; } @@ -5252,13 +5259,13 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) { // Invalid parameter - send appropriate error response if (pid == -2) { // NULL parameter - send_parameter_error_response("NULL is not allowed"); + send_parameter_error_response("NULL is not allowed", PGSQL_ERROR_CODES::ERRCODE_NULL_VALUE_NOT_ALLOWED); } else if (pid == -1) { // Invalid format (not a valid integer) - send_parameter_error_response("invalid input syntax for integer"); + send_parameter_error_response("invalid input syntax for integer", PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE); } else if (pid == 0) { // PID <= 0 (non-positive) - send_parameter_error_response("PID must be a positive integer"); + send_parameter_error_response("PID must be a positive integer", PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE); } l_free(pkt->size, pkt->ptr); return true; @@ -5297,7 +5304,7 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui } // Convert text to integer - std::string_view str_val(reinterpret_cast(param.value), param.len); + std::string str_val(reinterpret_cast(param.value), param.len); // Validate that the string contains only digits for (size_t i = 0; i < str_val.size(); i++) { @@ -5308,10 +5315,10 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui // Parse the integer char* endptr; - long pid = strtol(str_val.data(), &endptr, 10); + long pid = strtol(str_val.c_str(), &endptr, 10); // Check for conversion errors - if (endptr != str_val.data() + str_val.size()) { + if (endptr != str_val.c_str() + str_val.size()) { return -1; } @@ -5334,12 +5341,10 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui // uint32 in network byte order uint32_t host_u32; get_uint32be(reinterpret_cast(param.value), &host_u32); - int32_t pid = static_cast(host_u32); - - // Validate positive PID - if (pid <= 0) { - return 0; // Special value for non-positive + if (host_u32 & 0x80000000u) { // negative int4 + return 0; } + int32_t pid = static_cast(host_u32); return pid; } @@ -5347,15 +5352,13 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui // int64 in network byte order (PostgreSQL sends int8 for some integer types) uint64_t host_u64 = 0; get_uint64be(reinterpret_cast(param.value), &host_u64); - int64_t pid = static_cast(host_u64); - - // Validate positive PID and within int32 range - if (pid <= 0) { - return 0; // Special value for non-positive + if (host_u64 & 0x8000000000000000ull) { // negative int8 + return 0; } - if (pid > INT_MAX) { - return -1; // Out of range + if (host_u64 > static_cast(INT32_MAX)) { + return -1; // out of range for PID } + int64_t pid = static_cast(host_u64); return static_cast(pid); } @@ -5379,13 +5382,12 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui sprintf(buf, "localhost"); break; } - // Unknown format code - proxy_error("Unknown parameter format code: %u from client %s", format, buf); + proxy_error("Unknown parameter format code: %u received from client %s:%d", format, buf, client_myds->addr.port); return -1; } -void PgSQL_Session::send_parameter_error_response(const char* error_message) { +void PgSQL_Session::send_parameter_error_response(const char* error_message, PGSQL_ERROR_CODES error_code) { if (!client_myds) return; // Create proper PostgreSQL error message @@ -5394,7 +5396,7 @@ void PgSQL_Session::send_parameter_error_response(const char* error_message) { client_myds->setDSS_STATE_QUERY_SENT_NET(); // Generate and send error packet using PostgreSQL protocol client_myds->myprot.generate_error_packet(true, is_extended_query_ready_for_query(), - full_error.c_str(), PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION, false, true); + full_error.c_str(), error_code, false, true); RequestEnd(NULL, true); } @@ -5425,7 +5427,7 @@ bool PgSQL_Session::handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* // Handle literal query (original implementation) char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength, pgsql_thread___query_digests_lowercase); - std::string nq { qu, strlen(qu) }; + std::string nq(qu); re2::RE2::Options opt2(RE2::Quiet); opt2.set_case_sensitive(false); diff --git a/test/tap/tests/pgsql-parameterized_kill_queries_test-t.cpp b/test/tap/tests/pgsql-parameterized_kill_queries_test-t.cpp index 076633f64..8e6931765 100644 --- a/test/tap/tests/pgsql-parameterized_kill_queries_test-t.cpp +++ b/test/tap/tests/pgsql-parameterized_kill_queries_test-t.cpp @@ -502,7 +502,7 @@ int main(int argc, char** argv) { { auto test_conn = createNewConnection(BACKEND); if (!test_conn || PQstatus(test_conn.get()) != CONNECTION_OK) { - skip(2, "Connection failed"); + skip(1, "Connection failed"); } else { // Test with SELECT pg_cancel_backend($1, $2) std::string stmt_name = "multi_param_" + std::to_string(std::chrono::system_clock::now().time_since_epoch().count()); @@ -512,7 +512,6 @@ int main(int argc, char** argv) { bool prepare_failed = (PQresultStatus(res) == PGRES_FATAL_ERROR); PQclear(res); - ok(prepare_failed, "Multiple parameters should return error"); // Clean up