diff --git a/include/PgSQL_Variables_Validator.h b/include/PgSQL_Variables_Validator.h index a3483cd0e..ae5469847 100644 --- a/include/PgSQL_Variables_Validator.h +++ b/include/PgSQL_Variables_Validator.h @@ -5,27 +5,28 @@ typedef enum { VARIABLE_TYPE_NONE = 0, /**< No variable type. */ VARIABLE_TYPE_INT, /**< Integer variable type. */ + VARIABLE_TYPE_FLOAT, /**< Float variable type. */ VARIABLE_TYPE_BOOL, /**< Boolean variable type. */ VARIABLE_TYPE_STRING, /**< String variable type. */ VARIABLE_TYPE_DATESTYLE, /**< DateStyle variable type. */ VARIABLE_TYPE_MAINTENANCE_WORK_MEM } pgsql_variable_type_t; -/** - * @struct int_range_t - * @brief Struct representing a range of integer values. - */ -typedef struct { - int min; /**< Minimum value of the range. */ - int max; /**< Maximum value of the range. */ -} int_range_t; + +template +struct range_t { + T min; /**< Minimum value of the range. */ + T max; /**< Maximum value of the range. */ +}; /** * @union params_t * @brief Union representing the parameters for variable validation. */ typedef union { - int_range_t int_range; /**< Integer range parameters. */ + range_t int_range; /**< Integer range parameters. */ + range_t uint_range; /**< Integer range parameters. */ + range_t float_range; /**< Float range parameters. */ const char** string_allowed; /**< Allowed string values. */ } params_t; diff --git a/lib/PgSQL_Variables_Validator.cpp b/lib/PgSQL_Variables_Validator.cpp index d04c22359..eab6caff5 100644 --- a/lib/PgSQL_Variables_Validator.cpp +++ b/lib/PgSQL_Variables_Validator.cpp @@ -25,7 +25,8 @@ bool pgsql_variable_validate_bool(const char* value, const params_t* params, PgS (strcasecmp(value, (char*)"0") == 0) || (strcasecmp(value, (char*)"f") == 0) || (strcasecmp(value, (char*)"false") == 0) || - (strcasecmp(value, (char*)"off") == 0)) { + (strcasecmp(value, (char*)"off") == 0) || + (strcasecmp(value, (char*)"no") == 0)) { if (transformed_value) *transformed_value = strdup("off"); result = true; @@ -33,7 +34,8 @@ bool pgsql_variable_validate_bool(const char* value, const params_t* params, PgS (strcasecmp(value, (char*)"1") == 0) || (strcasecmp(value, (char*)"t") == 0) || (strcasecmp(value, (char*)"true") == 0) || - (strcasecmp(value, (char*)"on") == 0)) { + (strcasecmp(value, (char*)"on") == 0) || + (strcasecmp(value, (char*)"yes") == 0)) { if (transformed_value) *transformed_value = strdup("on"); result = true; @@ -42,25 +44,26 @@ bool pgsql_variable_validate_bool(const char* value, const params_t* params, PgS } /** -* @brief Validates an integer variable for PostgreSQL. +* @brief Validates an float variable for PostgreSQL. * -* This function checks if the provided value is a valid integer representation +* This function checks if the provided value is a valid float representation * and falls within the specified range. The range is defined by the params * parameter. * * @param value The value to validate. -* @param params The parameter structure containing the integer range. +* @param params The parameter structure containing the float range. * @param session Unused parameter. * @param transformed_value If not null, will be set to null. -* @return true if the value is a valid integer representation within the specified range, false otherwise. +* @return true if the value is a valid float representation within the specified range, false otherwise. */ -bool pgsql_variable_validate_integer(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) { +bool pgsql_variable_validate_float(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) { (void)session; if (transformed_value) *transformed_value = nullptr; char* end = nullptr; - long num = strtol(value, &end, 10); + //long num = strtol(value, &end, 10); + double num = strtod(value, &end); if (end == value || *end != '\0') return false; - if (num < params->int_range.min || num > params->int_range.max) return false; + if (num < params->float_range.min || num > params->float_range.max) return false; return true; } @@ -197,9 +200,14 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param // Parse unit if (*p != '\0') { - switch (tolower(*p)) { - case 'k': case 'm': case 'g': case 't': - unit = tolower(*p++); + char tmp_unit = tolower(*p); + switch (tmp_unit) { + case 'k': + case 'm': + case 'g': + case 't': + if (tmp_unit != 'k') + unit = toupper(*p++); has_unit = true; // Check optional 'b'/'B' if (tolower(*p) == 'b') p++; @@ -217,7 +225,7 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param char output[128]; // Format normalized string (always show unit in lowercase) - int written = snprintf(output, sizeof(output), has_unit ? "%lld%cb" : "%lldkb", + int written = snprintf(output, sizeof(output), has_unit ? "%lld%cB" : "%lldkB", num, unit); if (written < 0 || written >= (int)sizeof(output)) return false; @@ -228,6 +236,98 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param return true; } +bool pgsql_variable_validate_maintenance_work_mem_v2(const char* value, const params_t* params, PgSQL_Session* session, char** transformed_value) { + (void)session; + const char* input = value; + + /* Trim leading whitespace */ + while (isspace((unsigned char)*input)) input++; + + /* Parse numeric part */ + uint64_t number; + char* endptr; + //size_t num_len = 0; + errno = 0; + number = strtoull(input, &endptr, 10); + + if (endptr == input || errno == ERANGE || number == 0) + return false; + + //num_len = endptr - input; + + // Skip whitespace after number + while (isspace((unsigned char)*endptr)) endptr++; + + /* Parse unit part */ + const char* unit_ptr = endptr; + uint64_t multiplier; + char unit[3] = { 0 }; + size_t unit_len = strlen(unit_ptr); + + /* Handle default unit (kB) if no unit specified */ + if (unit_len == 0) { + strcpy(unit, "kB"); + multiplier = 1024; + } + else { + /* Convert unit to lowercase for validation */ + char u[3] = { 0 }; + for (int i = 0; i < 2 && unit_ptr[i]; i++) + u[i] = tolower((unsigned char)unit_ptr[i]); + + /* Validate unit and set multiplier */ + if (unit_len == 1 && u[0] == 'b') { + strcpy(unit, "B"); + multiplier = 1; + } + else if (strcmp(u, "kb") == 0) { + strcpy(unit, "kB"); + multiplier = 1024; + } + else if (strcmp(u, "mb") == 0) { + strcpy(unit, "MB"); + multiplier = 1024 * 1024; + } + else if (strcmp(u, "gb") == 0) { + strcpy(unit, "GB"); + multiplier = 1024ULL * 1024 * 1024; + } + else if (strcmp(u, "tb") == 0) { + strcpy(unit, "TB"); + multiplier = 1024ULL * 1024 * 1024 * 1024; + } + else { + return false; + } + + /* Validate unit length matches parsed characters */ + size_t actual_unit_len = (unit[1] == 'B') ? 2 : (unit[0] == 'B') ? 1 : 0; + if (strlen(unit_ptr) != actual_unit_len) + return false; + } + + /* Check for multiplication overflow */ + if (number > UINT64_MAX / multiplier) + return false; + + uint64_t total_bytes = number * multiplier; + + /* Validate PostgreSQL's requirements */ + if ((total_bytes / 1024ULL) < params->uint_range.min || (total_bytes / 1024ULL) > params->uint_range.max) + return false; + + char output[128]; + /* Format output without leading zeros */ + int needed = snprintf(output, sizeof(output), "%lu%s", number, unit); + + if (needed < 0 || needed >= (int)sizeof(output)) return false; + + if (transformed_value) + *transformed_value = strdup(output); + + return true; +} + const pgsql_variable_validator pgsql_variable_validator_bool = { .type = VARIABLE_TYPE_BOOL, .validate = &pgsql_variable_validate_bool, @@ -235,10 +335,10 @@ const pgsql_variable_validator pgsql_variable_validator_bool = { }; const pgsql_variable_validator pgsql_variable_validator_extra_float_digits = { - .type = VARIABLE_TYPE_INT, - .validate = &pgsql_variable_validate_integer, + .type = VARIABLE_TYPE_FLOAT, + .validate = &pgsql_variable_validate_float, .params = { - .int_range = { .min = -15, .max = 3 } + .float_range = { .min = -15.0, .max = 3.0 } } }; @@ -282,6 +382,8 @@ const pgsql_variable_validator pgsql_variable_validator_datestyle = { const pgsql_variable_validator pgsql_variable_validator_maintenance_work_mem = { .type = VARIABLE_TYPE_MAINTENANCE_WORK_MEM, - .validate = &pgsql_variable_validate_maintenance_work_mem, - .params = {} + .validate = &pgsql_variable_validate_maintenance_work_mem_v2, + .params = { + .uint_range = {.min = 1024, .max = 2147483647 } // this range is in kB + } };