@ -234,8 +234,7 @@ int MySQL_Tool_Handler::init_connection_pool() {
* @ brief Validate SQL identifier ( table name , column name , schema name )
* @ brief Validate SQL identifier ( table name , column name , schema name )
*
*
* Checks that the identifier contains only valid characters ( alphanumeric ,
* Checks that the identifier contains only valid characters ( alphanumeric ,
* underscore , dollar sign ) and doesn ' t start with a digit . Also checks
* underscore , dollar sign ) and doesn ' t start with a digit .
* for SQL injection attempts .
*
*
* @ param identifier The identifier to validate
* @ param identifier The identifier to validate
* @ return true if valid , false otherwise
* @ return true if valid , false otherwise
@ -262,15 +261,6 @@ static bool validate_sql_identifier(const std::string& identifier) {
}
}
}
}
// Check for SQL injection patterns (quoted identifiers, comments, etc.)
if ( identifier . find ( ' " ' ) ! = std : : string : : npos | |
identifier . find ( ' \' ' ) ! = std : : string : : npos | |
identifier . find ( ' ` ' ) ! = std : : string : : npos | |
identifier . find ( ' - ' ) ! = std : : string : : npos | |
identifier . find ( ' ; ' ) ! = std : : string : : npos ) {
return false ;
}
return true ;
return true ;
}
}
@ -290,16 +280,19 @@ static std::string escape_string(MYSQL* conn, const std::string& value) {
}
}
// Allocate buffer for escaped string (2 * input + 1 for null terminator)
// Allocate buffer for escaped string (2 * input + 1 for null terminator)
unsigned long escaped_length = value . length ( ) * 2 + 1 ;
std : : string escaped ( value . length ( ) * 2 + 1 , ' \0 ' ) ;
char * escaped = new char [ escaped_length ] ;
// Escape the string
// Escape the string and check for errors
mysql_real_escape_string ( conn , escaped , value . c_str ( ) , value . length ( ) ) ;
unsigned long result_len = mysql_real_escape_string ( conn , & escaped [ 0 ] , value . c_str ( ) , value . length ( ) ) ;
if ( result_len = = ( unsigned long ) - 1 ) {
// Error during escaping (e.g., invalid character set)
return " " ;
}
std : : string result ( escaped ) ;
// Resize to actual escaped length
delete [ ] escaped ;
escaped . resize ( result_len ) ;
return result ;
return escaped ;
}
}
/**
/**
@ -822,18 +815,85 @@ std::string MySQL_Tool_Handler::sample_rows(
return result . dump ( ) ;
return result . dump ( ) ;
}
}
// Validate columns parameter (if provided) - check for common SQL injection patterns
// Validate columns parameter (if provided) - parse and validate each column
if ( ! columns . empty ( ) ) {
if ( ! columns . empty ( ) ) {
// Check for dangerous patterns in columns
// Helper lambda to validate a single column identifier
std : : string upper_columns = columns ;
auto validate_column_identifier = [ ] ( const std : : string & col ) - > bool {
std : : transform ( upper_columns . begin ( ) , upper_columns . end ( ) , upper_columns . begin ( ) , : : toupper ) ;
if ( col . empty ( ) ) return false ;
if ( upper_columns . find ( " -- " ) ! = std : : string : : npos | |
upper_columns . find ( " /* " ) ! = std : : string : : npos | |
// Check for basic SQL injection patterns first
upper_columns . find ( " ; " ) ! = std : : string : : npos | |
if ( col . find ( " -- " ) ! = std : : string : : npos | |
upper_columns . find ( " UNION " ) ! = std : : string : : npos | |
col . find ( " /* " ) ! = std : : string : : npos | |
upper_columns . find ( " JOIN " ) ! = std : : string : : npos ) {
col . find ( " ; " ) ! = std : : string : : npos ) {
result [ " error " ] = " Invalid columns parameter: contains unsafe patterns " ;
return false ;
return result . dump ( ) ;
}
// Allow: identifier, identifier.identifier, or identifier AS identifier
// This is a simplified check - we validate character by character
bool has_dot = false ;
bool in_identifier = true ;
int identifier_count = 0 ;
for ( size_t i = 0 ; i < col . length ( ) ; i + + ) {
char c = col [ i ] ;
// Skip whitespace
if ( isspace ( c ) ) {
in_identifier = false ;
continue ;
}
// Check for "AS" keyword (case-insensitive)
if ( ! in_identifier & & i + 1 < col . length ( ) ) {
if ( ( c = = ' A ' | | c = = ' a ' ) & &
( col [ i + 1 ] = = ' S ' | | col [ i + 1 ] = = ' s ' ) ) {
i + + ; // Skip the 'S'
in_identifier = false ;
continue ;
}
}
// Check for dot (qualified identifier like table.column)
if ( c = = ' . ' ) {
if ( has_dot | | identifier_count = = 0 ) {
return false ; // Multiple dots or dot at start
}
has_dot = true ;
in_identifier = true ;
continue ;
}
// Must be valid identifier character
if ( ! isalnum ( c ) & & c ! = ' _ ' & & c ! = ' $ ' ) {
return false ;
}
if ( ! in_identifier ) {
in_identifier = true ;
identifier_count + + ;
}
}
// Must have at least one valid identifier
return identifier_count > 0 ;
} ;
// Parse comma-separated column list and validate each part
std : : stringstream col_stream ( columns ) ;
std : : string col_part ;
while ( std : : getline ( col_stream , col_part , ' , ' ) ) {
// Trim whitespace
size_t start = col_part . find_first_not_of ( " \t \n \r " ) ;
size_t end = col_part . find_last_not_of ( " \t \n \r " ) ;
if ( start = = std : : string : : npos | | end = = std : : string : : npos ) {
continue ; // Skip empty segments
}
std : : string trimmed = col_part . substr ( start , end - start + 1 ) ;
if ( ! validate_column_identifier ( trimmed ) ) {
result [ " error " ] = " Invalid columns parameter: ' " + trimmed + " ' is not a valid column specification " ;
return result . dump ( ) ;
}
}
}
}
}