Added pgsql_variable_validate_maintenance_work_mem_v3, it now supports decimal values too

pull/4799/head
Rahim Kanji 1 year ago
parent 19297e37d6
commit 9f7cb598f7

@ -328,6 +328,101 @@ bool pgsql_variable_validate_maintenance_work_mem_v2(const char* value, const pa
return true;
}
bool pgsql_variable_validate_maintenance_work_mem_v3(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) {
(void)session;
// Trim leading whitespace
while (isspace((unsigned char)*value)) value++;
char* endptr;
const char* num_start = value;
errno = 0;
double number = strtod(value, &endptr);
// Basic numeric validation
if (endptr == num_start || errno == ERANGE || number <= 0)
return false;
// Validate numeric format (digits and single decimal point)
int dot_count = 0;
const char* p = num_start;
while (p < endptr) {
if (*p == '.') {
if (++dot_count > 1) return false;
}
else if (!isdigit((unsigned char)*p)) {
return false;
}
p++;
}
// Parse unit
const char* unit_ptr = endptr;
uint64_t multiplier;
char unit[3] = { 0 };
size_t unit_len = strlen(unit_ptr);
// Default to kB if no unit specified
if (unit_len == 0) {
strcpy(unit, "kB");
multiplier = 1024;
}
else {
// Convert unit to lowercase for validation
char u[3] = { 0 };
for (int i = 0; i < 2 && unit_ptr[i]; i++)
u[i] = tolower((unsigned char)unit_ptr[i]);
// Validate units and set multipliers
if (unit_len == 1 && u[0] == 'b') {
strcpy(unit, "B");
multiplier = 1;
}
else if (strcmp(u, "kb") == 0) {
strcpy(unit, "kB");
multiplier = 1024;
}
else if (strcmp(u, "mb") == 0) {
strcpy(unit, "MB");
multiplier = 1024 * 1024;
}
else if (strcmp(u, "gb") == 0) {
strcpy(unit, "GB");
multiplier = 1024ULL * 1024 * 1024;
}
else if (strcmp(u, "tb") == 0) {
strcpy(unit, "TB");
multiplier = 1024ULL * 1024 * 1024 * 1024;
}
else {
return false;
}
// Validate unit length matches parsed characters
size_t expected_len = (unit[1] == 'B') ? 2 : (unit[0] == 'B') ? 1 : 0;
if (strlen(unit_ptr) != expected_len)
return false;
}
// Calculate total bytes with floating point
uint64_t total_bytes = (uint64_t)number * multiplier;
/* Validate PostgreSQL's requirements */
if ((total_bytes / 1024ULL) < params->uint_range.min || (total_bytes / 1024ULL) > params->uint_range.max)
return false;
char output[128];
/* Format output without leading zeros */
int needed = snprintf(output, sizeof(output), "%.15g%s", number, unit);
if (needed < 0 || needed >= (int)sizeof(output)) return false;
if (transformed_value)
*transformed_value = strdup(output);
return true;
}
const pgsql_variable_validator pgsql_variable_validator_bool = {
.type = VARIABLE_TYPE_BOOL,
.validate = &pgsql_variable_validate_bool,
@ -382,7 +477,7 @@ const pgsql_variable_validator pgsql_variable_validator_datestyle = {
const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem = {
.type = VARIABLE_TYPE_MAINTENANCE_WORK_MEM,
.validate = &pgsql_variable_validate_maintenance_work_mem_v2,
.validate = &pgsql_variable_validate_maintenance_work_mem_v3,
.params = {
.uint_range = {.min = 1024, .max = 2147483647 } // this range is in kB
}

@ -273,7 +273,7 @@ std::vector<SetTestCase> test_cases = {
{"SET maintenance_work_mem = '1GB'", true, "1GB"},
{"SET maintenance_work_mem = '1024kB'", true, "1MB"},
{"SET maintenance_work_mem = '1TB'", true, "1TB"},
//{"SET maintenance_work_mem = '1.5GB'", true, "1536MB"}, decimal values not yet supported
{"SET maintenance_work_mem = '1.5GB'", true, "1536MB"},
// Invalid values
{"SET maintenance_work_mem = '64XB'", false, ""},
{"SET maintenance_work_mem = '-128MB'", false, ""},

Loading…
Cancel
Save