Improved Session parameter value validator

pull/4799/head
Rahim Kanji 1 year ago
parent 94399faaf9
commit c5fb8da375

@ -5,27 +5,28 @@
typedef enum {
VARIABLE_TYPE_NONE = 0, /**< No variable type. */
VARIABLE_TYPE_INT, /**< Integer variable type. */
VARIABLE_TYPE_FLOAT, /**< Float variable type. */
VARIABLE_TYPE_BOOL, /**< Boolean variable type. */
VARIABLE_TYPE_STRING, /**< String variable type. */
VARIABLE_TYPE_DATESTYLE, /**< DateStyle variable type. */
VARIABLE_TYPE_MAINTENANCE_WORK_MEM
} pgsql_variable_type_t;
/**
* @struct int_range_t
* @brief Struct representing a range of integer values.
*/
typedef struct {
int min; /**< Minimum value of the range. */
int max; /**< Maximum value of the range. */
} int_range_t;
template<typename T>
struct range_t {
T min; /**< Minimum value of the range. */
T max; /**< Maximum value of the range. */
};
/**
* @union params_t
* @brief Union representing the parameters for variable validation.
*/
typedef union {
int_range_t int_range; /**< Integer range parameters. */
range_t<int> int_range; /**< Integer range parameters. */
range_t<unsigned int> uint_range; /**< Integer range parameters. */
range_t<float> float_range; /**< Float range parameters. */
const char** string_allowed; /**< Allowed string values. */
} params_t;

@ -25,7 +25,8 @@ bool pgsql_variable_validate_bool(const char* value, const params_t* params, PgS
(strcasecmp(value, (char*)"0") == 0) ||
(strcasecmp(value, (char*)"f") == 0) ||
(strcasecmp(value, (char*)"false") == 0) ||
(strcasecmp(value, (char*)"off") == 0)) {
(strcasecmp(value, (char*)"off") == 0) ||
(strcasecmp(value, (char*)"no") == 0)) {
if (transformed_value)
*transformed_value = strdup("off");
result = true;
@ -33,7 +34,8 @@ bool pgsql_variable_validate_bool(const char* value, const params_t* params, PgS
(strcasecmp(value, (char*)"1") == 0) ||
(strcasecmp(value, (char*)"t") == 0) ||
(strcasecmp(value, (char*)"true") == 0) ||
(strcasecmp(value, (char*)"on") == 0)) {
(strcasecmp(value, (char*)"on") == 0) ||
(strcasecmp(value, (char*)"yes") == 0)) {
if (transformed_value)
*transformed_value = strdup("on");
result = true;
@ -42,25 +44,26 @@ bool pgsql_variable_validate_bool(const char* value, const params_t* params, PgS
}
/**
* @brief Validates an integer variable for PostgreSQL.
* @brief Validates an float variable for PostgreSQL.
*
* This function checks if the provided value is a valid integer representation
* This function checks if the provided value is a valid float representation
* and falls within the specified range. The range is defined by the params
* parameter.
*
* @param value The value to validate.
* @param params The parameter structure containing the integer range.
* @param params The parameter structure containing the float range.
* @param session Unused parameter.
* @param transformed_value If not null, will be set to null.
* @return true if the value is a valid integer representation within the specified range, false otherwise.
* @return true if the value is a valid float representation within the specified range, false otherwise.
*/
bool pgsql_variable_validate_integer(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) {
bool pgsql_variable_validate_float(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) {
(void)session;
if (transformed_value) *transformed_value = nullptr;
char* end = nullptr;
long num = strtol(value, &end, 10);
//long num = strtol(value, &end, 10);
double num = strtod(value, &end);
if (end == value || *end != '\0') return false;
if (num < params->int_range.min || num > params->int_range.max) return false;
if (num < params->float_range.min || num > params->float_range.max) return false;
return true;
}
@ -197,9 +200,14 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param
// Parse unit
if (*p != '\0') {
switch (tolower(*p)) {
case 'k': case 'm': case 'g': case 't':
unit = tolower(*p++);
char tmp_unit = tolower(*p);
switch (tmp_unit) {
case 'k':
case 'm':
case 'g':
case 't':
if (tmp_unit != 'k')
unit = toupper(*p++);
has_unit = true;
// Check optional 'b'/'B'
if (tolower(*p) == 'b') p++;
@ -217,7 +225,7 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param
char output[128];
// Format normalized string (always show unit in lowercase)
int written = snprintf(output, sizeof(output), has_unit ? "%lld%cb" : "%lldkb",
int written = snprintf(output, sizeof(output), has_unit ? "%lld%cB" : "%lldkB",
num, unit);
if (written < 0 || written >= (int)sizeof(output)) return false;
@ -228,6 +236,98 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param
return true;
}
bool pgsql_variable_validate_maintenance_work_mem_v2(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) {
(void)session;
const char* input = value;
/* Trim leading whitespace */
while (isspace((unsigned char)*input)) input++;
/* Parse numeric part */
uint64_t number;
char* endptr;
//size_t num_len = 0;
errno = 0;
number = strtoull(input, &endptr, 10);
if (endptr == input || errno == ERANGE || number == 0)
return false;
//num_len = endptr - input;
// Skip whitespace after number
while (isspace((unsigned char)*endptr)) endptr++;
/* Parse unit part */
const char* unit_ptr = endptr;
uint64_t multiplier;
char unit[3] = { 0 };
size_t unit_len = strlen(unit_ptr);
/* Handle default unit (kB) if no unit specified */
if (unit_len == 0) {
strcpy(unit, "kB");
multiplier = 1024;
}
else {
/* Convert unit to lowercase for validation */
char u[3] = { 0 };
for (int i = 0; i < 2 && unit_ptr[i]; i++)
u[i] = tolower((unsigned char)unit_ptr[i]);
/* Validate unit and set multiplier */
if (unit_len == 1 && u[0] == 'b') {
strcpy(unit, "B");
multiplier = 1;
}
else if (strcmp(u, "kb") == 0) {
strcpy(unit, "kB");
multiplier = 1024;
}
else if (strcmp(u, "mb") == 0) {
strcpy(unit, "MB");
multiplier = 1024 * 1024;
}
else if (strcmp(u, "gb") == 0) {
strcpy(unit, "GB");
multiplier = 1024ULL * 1024 * 1024;
}
else if (strcmp(u, "tb") == 0) {
strcpy(unit, "TB");
multiplier = 1024ULL * 1024 * 1024 * 1024;
}
else {
return false;
}
/* Validate unit length matches parsed characters */
size_t actual_unit_len = (unit[1] == 'B') ? 2 : (unit[0] == 'B') ? 1 : 0;
if (strlen(unit_ptr) != actual_unit_len)
return false;
}
/* Check for multiplication overflow */
if (number > UINT64_MAX / multiplier)
return false;
uint64_t total_bytes = number * multiplier;
/* Validate PostgreSQL's requirements */
if ((total_bytes / 1024ULL) < params->uint_range.min || (total_bytes / 1024ULL) > params->uint_range.max)
return false;
char output[128];
/* Format output without leading zeros */
int needed = snprintf(output, sizeof(output), "%lu%s", number, unit);
if (needed < 0 || needed >= (int)sizeof(output)) return false;
if (transformed_value)
*transformed_value = strdup(output);
return true;
}
const pgsql_variable_validator pgsql_variable_validator_bool = {
.type = VARIABLE_TYPE_BOOL,
.validate = &pgsql_variable_validate_bool,
@ -235,10 +335,10 @@ const pgsql_variable_validator pgsql_variable_validator_bool = {
};
const pgsql_variable_validator pgsql_variable_validator_extra_float_digits = {
.type = VARIABLE_TYPE_INT,
.validate = &pgsql_variable_validate_integer,
.type = VARIABLE_TYPE_FLOAT,
.validate = &pgsql_variable_validate_float,
.params = {
.int_range = { .min = -15, .max = 3 }
.float_range = { .min = -15.0, .max = 3.0 }
}
};
@ -282,6 +382,8 @@ const pgsql_variable_validator pgsql_variable_validator_datestyle = {
const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem = {
.type = VARIABLE_TYPE_MAINTENANCE_WORK_MEM,
.validate = &pgsql_variable_validate_maintenance_work_mem,
.params = {}
.validate = &pgsql_variable_validate_maintenance_work_mem_v2,
.params = {
.uint_range = {.min = 1024, .max = 2147483647 } // this range is in kB
}
};

Loading…
Cancel
Save