Address AI code review feedback for PR #5313

This commit addresses concerns raised by AI code reviewers (gemini-code-assist,
Copilot, coderabbitai) on the initial security fixes.

Critical fixes:
- Fix lock.release() → lock.unlock() in GenAI_Thread.cpp worker_loop
  (lock.release() detaches without unlocking, causing deadlock)
- Add missing early return after schema validation failure in Query_Tool_Handler.cpp

Code quality improvements:
- Improve escape_string() memory management in MySQL_Tool_Handler.cpp:
  - Use std::string instead of new[]/delete[] for buffer management
  - Check return value of mysql_real_escape_string() for errors
- Remove redundant validation checks in validate_sql_identifier functions
  (character class loop already rejects unsafe characters)
- Add backslash escaping to escape_string_literal() for defense-in-depth
- Improve column list validation in MySQL_Tool_Handler sample_rows():
  - Replace blacklist approach with proper column identifier parsing
  - Allow qualified identifiers (table.column)
  - Allow AS aliases (column AS alias)
  - No longer rejects legitimate column names containing "JOIN"

These changes improve robustness while maintaining the security posture
of the original SQL injection fixes.
pr-5312-fixes
Rene Cannao 3 months ago
parent c914feb230
commit 3ccfa2bcc1

@ -1465,7 +1465,7 @@ void GenAI_Threads_Handler::worker_loop(int worker_id) {
}
break;
}
lock.release();
lock.unlock(); // Release the lock (not release() which would detach without unlocking)
// Process request
auto start_time = std::chrono::steady_clock::now();

@ -234,8 +234,7 @@ int MySQL_Tool_Handler::init_connection_pool() {
* @brief Validate SQL identifier (table name, column name, schema name)
*
* Checks that the identifier contains only valid characters (alphanumeric,
* underscore, dollar sign) and doesn't start with a digit. Also checks
* for SQL injection attempts.
* underscore, dollar sign) and doesn't start with a digit.
*
* @param identifier The identifier to validate
* @return true if valid, false otherwise
@ -262,15 +261,6 @@ static bool validate_sql_identifier(const std::string& identifier) {
}
}
// Check for SQL injection patterns (quoted identifiers, comments, etc.)
if (identifier.find('"') != std::string::npos ||
identifier.find('\'') != std::string::npos ||
identifier.find('`') != std::string::npos ||
identifier.find('-') != std::string::npos ||
identifier.find(';') != std::string::npos) {
return false;
}
return true;
}
@ -290,16 +280,19 @@ static std::string escape_string(MYSQL* conn, const std::string& value) {
}
// Allocate buffer for escaped string (2 * input + 1 for null terminator)
unsigned long escaped_length = value.length() * 2 + 1;
char* escaped = new char[escaped_length];
std::string escaped(value.length() * 2 + 1, '\0');
// Escape the string
mysql_real_escape_string(conn, escaped, value.c_str(), value.length());
// Escape the string and check for errors
unsigned long result_len = mysql_real_escape_string(conn, &escaped[0], value.c_str(), value.length());
if (result_len == (unsigned long)-1) {
// Error during escaping (e.g., invalid character set)
return "";
}
std::string result(escaped);
delete[] escaped;
// Resize to actual escaped length
escaped.resize(result_len);
return result;
return escaped;
}
/**
@ -822,18 +815,85 @@ std::string MySQL_Tool_Handler::sample_rows(
return result.dump();
}
// Validate columns parameter (if provided) - check for common SQL injection patterns
// Validate columns parameter (if provided) - parse and validate each column
if (!columns.empty()) {
// Check for dangerous patterns in columns
std::string upper_columns = columns;
std::transform(upper_columns.begin(), upper_columns.end(), upper_columns.begin(), ::toupper);
if (upper_columns.find("--") != std::string::npos ||
upper_columns.find("/*") != std::string::npos ||
upper_columns.find(";") != std::string::npos ||
upper_columns.find("UNION") != std::string::npos ||
upper_columns.find("JOIN") != std::string::npos) {
result["error"] = "Invalid columns parameter: contains unsafe patterns";
return result.dump();
// Helper lambda to validate a single column identifier
auto validate_column_identifier = [](const std::string& col) -> bool {
if (col.empty()) return false;
// Check for basic SQL injection patterns first
if (col.find("--") != std::string::npos ||
col.find("/*") != std::string::npos ||
col.find(";") != std::string::npos) {
return false;
}
// Allow: identifier, identifier.identifier, or identifier AS identifier
// This is a simplified check - we validate character by character
bool has_dot = false;
bool in_identifier = true;
int identifier_count = 0;
for (size_t i = 0; i < col.length(); i++) {
char c = col[i];
// Skip whitespace
if (isspace(c)) {
in_identifier = false;
continue;
}
// Check for "AS" keyword (case-insensitive)
if (!in_identifier && i + 1 < col.length()) {
if ((c == 'A' || c == 'a') &&
(col[i + 1] == 'S' || col[i + 1] == 's')) {
i++; // Skip the 'S'
in_identifier = false;
continue;
}
}
// Check for dot (qualified identifier like table.column)
if (c == '.') {
if (has_dot || identifier_count == 0) {
return false; // Multiple dots or dot at start
}
has_dot = true;
in_identifier = true;
continue;
}
// Must be valid identifier character
if (!isalnum(c) && c != '_' && c != '$') {
return false;
}
if (!in_identifier) {
in_identifier = true;
identifier_count++;
}
}
// Must have at least one valid identifier
return identifier_count > 0;
};
// Parse comma-separated column list and validate each part
std::stringstream col_stream(columns);
std::string col_part;
while (std::getline(col_stream, col_part, ',')) {
// Trim whitespace
size_t start = col_part.find_first_not_of(" \t\n\r");
size_t end = col_part.find_last_not_of(" \t\n\r");
if (start == std::string::npos || end == std::string::npos) {
continue; // Skip empty segments
}
std::string trimmed = col_part.substr(start, end - start + 1);
if (!validate_column_identifier(trimmed)) {
result["error"] = "Invalid columns parameter: '" + trimmed + "' is not a valid column specification";
return result.dump();
}
}
}

@ -130,9 +130,8 @@ static double json_double(const json& j, const std::string& key, double default_
/**
* @brief Validate and escape a SQL identifier (table name, column name, etc.)
*
* For SQLite, we validate that the identifier contains only safe characters
* and quote it with double quotes if needed. This prevents SQL injection
* while allowing valid identifiers.
* For SQLite, we validate that the identifier contains only safe characters.
* This prevents SQL injection while allowing valid identifiers.
*
* @param identifier The identifier to validate/escape
* @return Empty string if unsafe, otherwise the validated identifier
@ -159,35 +158,27 @@ static std::string validate_sql_identifier_sqlite(const std::string& identifier)
}
}
// Check for SQL injection patterns
if (identifier.find('"') != std::string::npos ||
identifier.find('\'') != std::string::npos ||
identifier.find('`') != std::string::npos ||
identifier.find('-') != std::string::npos ||
identifier.find(';') != std::string::npos ||
identifier.find('/') != std::string::npos) {
return "";
}
return identifier;
}
/**
* @brief Escape a SQL string literal for use in queries
*
* Doubles single quotes to escape them for SQL string literals.
* This is the standard SQL escaping mechanism.
* Escapes single quotes by doubling them (standard SQL) and also escapes
* backslashes for defense-in-depth (important for MySQL with certain modes).
*
* @param value The string value to escape
* @return Escaped string safe for use in SQL queries
*/
static std::string escape_string_literal(const std::string& value) {
std::string escaped;
escaped.reserve(value.length() * 2);
escaped.reserve(value.length() * 2 + 1);
for (char c : value) {
if (c == '\'') {
escaped += "''"; // Double single quotes to escape
escaped += "''"; // Double single quotes to escape (SQL standard)
} else if (c == '\\') {
escaped += "\\\\"; // Escape backslash (defense-in-depth)
} else {
escaped += c;
}
@ -1017,6 +1008,7 @@ json Query_Tool_Handler::execute_tool(const std::string& tool_name, const json&
std::string validated = validate_sql_identifier_sqlite(schema);
if (validated.empty()) {
result = create_error_response("Invalid schema name: contains unsafe characters");
return result; // Early return on validation failure
} else {
schema = validated;
}

Loading…
Cancel
Save