Improvements

pull/5299/head
Rahim Kanji 3 months ago
parent a892d9a05b
commit ce42c188f5

@ -607,7 +607,7 @@ public:
private:
int32_t extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const;
void send_parameter_error_response(const char* error_message);
void send_parameter_error_response(const char* error_message, PGSQL_ERROR_CODES code = PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION);
bool handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt);
bool handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc);

@ -5200,8 +5200,15 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) {
while ((p = strstr(p, "$")) != nullptr) {
p++; // Skip '$'
if (isdigit(*p)) {
int param_num = atoi(p);
if (param_num > max_param) max_param = param_num;
char* end;
long param_num = strtol(p, &end, 10);
if (p != end) { // check if any digits were parsed
if (param_num > max_param) {
max_param = static_cast<int>(param_num);
}
p = end;
continue;
}
}
p++;
}
@ -5252,13 +5259,13 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) {
// Invalid parameter - send appropriate error response
if (pid == -2) {
// NULL parameter
send_parameter_error_response("NULL is not allowed");
send_parameter_error_response("NULL is not allowed", PGSQL_ERROR_CODES::ERRCODE_NULL_VALUE_NOT_ALLOWED);
} else if (pid == -1) {
// Invalid format (not a valid integer)
send_parameter_error_response("invalid input syntax for integer");
send_parameter_error_response("invalid input syntax for integer", PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE);
} else if (pid == 0) {
// PID <= 0 (non-positive)
send_parameter_error_response("PID must be a positive integer");
send_parameter_error_response("PID must be a positive integer", PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE);
}
l_free(pkt->size, pkt->ptr);
return true;
@ -5297,7 +5304,7 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui
}
// Convert text to integer
std::string_view str_val(reinterpret_cast<const char*>(param.value), param.len);
std::string str_val(reinterpret_cast<const char*>(param.value), param.len);
// Validate that the string contains only digits
for (size_t i = 0; i < str_val.size(); i++) {
@ -5308,10 +5315,10 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui
// Parse the integer
char* endptr;
long pid = strtol(str_val.data(), &endptr, 10);
long pid = strtol(str_val.c_str(), &endptr, 10);
// Check for conversion errors
if (endptr != str_val.data() + str_val.size()) {
if (endptr != str_val.c_str() + str_val.size()) {
return -1;
}
@ -5334,12 +5341,10 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui
// uint32 in network byte order
uint32_t host_u32;
get_uint32be(reinterpret_cast<const unsigned char*>(param.value), &host_u32);
int32_t pid = static_cast<int32_t>(host_u32);
// Validate positive PID
if (pid <= 0) {
return 0; // Special value for non-positive
if (host_u32 & 0x80000000u) { // negative int4
return 0;
}
int32_t pid = static_cast<int32_t>(host_u32);
return pid;
}
@ -5347,15 +5352,13 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui
// int64 in network byte order (PostgreSQL sends int8 for some integer types)
uint64_t host_u64 = 0;
get_uint64be(reinterpret_cast<const unsigned char*>(param.value), &host_u64);
int64_t pid = static_cast<int64_t>(host_u64);
// Validate positive PID and within int32 range
if (pid <= 0) {
return 0; // Special value for non-positive
if (host_u64 & 0x8000000000000000ull) { // negative int8
return 0;
}
if (pid > INT_MAX) {
return -1; // Out of range
if (host_u64 > static_cast<uint64_t>(INT32_MAX)) {
return -1; // out of range for PID
}
int64_t pid = static_cast<int64_t>(host_u64);
return static_cast<int32_t>(pid);
}
@ -5379,13 +5382,12 @@ int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, ui
sprintf(buf, "localhost");
break;
}
// Unknown format code
proxy_error("Unknown parameter format code: %u from client %s", format, buf);
proxy_error("Unknown parameter format code: %u received from client %s:%d", format, buf, client_myds->addr.port);
return -1;
}
void PgSQL_Session::send_parameter_error_response(const char* error_message) {
void PgSQL_Session::send_parameter_error_response(const char* error_message, PGSQL_ERROR_CODES error_code) {
if (!client_myds) return;
// Create proper PostgreSQL error message
@ -5394,7 +5396,7 @@ void PgSQL_Session::send_parameter_error_response(const char* error_message) {
client_myds->setDSS_STATE_QUERY_SENT_NET();
// Generate and send error packet using PostgreSQL protocol
client_myds->myprot.generate_error_packet(true, is_extended_query_ready_for_query(),
full_error.c_str(), PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION, false, true);
full_error.c_str(), error_code, false, true);
RequestEnd(NULL, true);
}
@ -5425,7 +5427,7 @@ bool PgSQL_Session::handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection*
// Handle literal query (original implementation)
char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength,
pgsql_thread___query_digests_lowercase);
std::string nq { qu, strlen(qu) };
std::string nq(qu);
re2::RE2::Options opt2(RE2::Quiet);
opt2.set_case_sensitive(false);

@ -502,7 +502,7 @@ int main(int argc, char** argv) {
{
auto test_conn = createNewConnection(BACKEND);
if (!test_conn || PQstatus(test_conn.get()) != CONNECTION_OK) {
skip(2, "Connection failed");
skip(1, "Connection failed");
} else {
// Test with SELECT pg_cancel_backend($1, $2)
std::string stmt_name = "multi_param_" + std::to_string(std::chrono::system_clock::now().time_since_epoch().count());
@ -512,7 +512,6 @@ int main(int argc, char** argv) {
bool prepare_failed = (PQresultStatus(res) == PGRES_FATAL_ERROR);
PQclear(res);
ok(prepare_failed, "Multiple parameters should return error");
// Clean up

Loading…
Cancel
Save