Address AI code review feedback for PR #5313

This commit addresses concerns raised by AI code reviewers (gemini-code-assist,
Copilot, coderabbitai) on the initial security fixes.

Critical fixes:
- Fix lock.release() → lock.unlock() in GenAI_Thread.cpp worker_loop()
  (release() only disassociates the lock object from the mutex without
  unlocking it, so the mutex stays held and other threads deadlock)
- Add missing early return after schema validation failure in Query_Tool_Handler.cpp

Code quality improvements:
- Improve escape_string() memory management in MySQL_Tool_Handler.cpp:
  - Use std::string instead of new[]/delete[] for buffer management
  - Check return value of mysql_real_escape_string() for errors
- Remove the redundant quote/backtick/'-'/';'/'/' checks in
  validate_sql_identifier() and validate_sql_identifier_sqlite()
  (the character-class loop already rejects these unsafe characters)
- Add backslash escaping to escape_string_literal() for defense-in-depth
- Improve column list validation in MySQL_Tool_Handler sample_rows():
  - Replace blacklist approach with proper column identifier parsing
  - Allow qualified identifiers (table.column)
  - Allow AS aliases (column AS alias)
  - No longer rejects legitimate column names containing "JOIN" or "UNION"

These changes improve robustness while maintaining the security posture
of the original SQL injection fixes.
pull/5313/head
Rene Cannao 3 months ago
parent c914feb230
commit 3ccfa2bcc1

@ -1465,7 +1465,7 @@ void GenAI_Threads_Handler::worker_loop(int worker_id) {
} }
break; break;
} }
lock.release(); lock.unlock(); // Release the lock (not release() which would detach without unlocking)
// Process request // Process request
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();

@ -234,8 +234,7 @@ int MySQL_Tool_Handler::init_connection_pool() {
* @brief Validate SQL identifier (table name, column name, schema name) * @brief Validate SQL identifier (table name, column name, schema name)
* *
* Checks that the identifier contains only valid characters (alphanumeric, * Checks that the identifier contains only valid characters (alphanumeric,
* underscore, dollar sign) and doesn't start with a digit. Also checks * underscore, dollar sign) and doesn't start with a digit.
* for SQL injection attempts.
* *
* @param identifier The identifier to validate * @param identifier The identifier to validate
* @return true if valid, false otherwise * @return true if valid, false otherwise
@ -262,15 +261,6 @@ static bool validate_sql_identifier(const std::string& identifier) {
} }
} }
// Check for SQL injection patterns (quoted identifiers, comments, etc.)
if (identifier.find('"') != std::string::npos ||
identifier.find('\'') != std::string::npos ||
identifier.find('`') != std::string::npos ||
identifier.find('-') != std::string::npos ||
identifier.find(';') != std::string::npos) {
return false;
}
return true; return true;
} }
@ -290,16 +280,19 @@ static std::string escape_string(MYSQL* conn, const std::string& value) {
} }
// Allocate buffer for escaped string (2 * input + 1 for null terminator) // Allocate buffer for escaped string (2 * input + 1 for null terminator)
unsigned long escaped_length = value.length() * 2 + 1; std::string escaped(value.length() * 2 + 1, '\0');
char* escaped = new char[escaped_length];
// Escape the string // Escape the string and check for errors
mysql_real_escape_string(conn, escaped, value.c_str(), value.length()); unsigned long result_len = mysql_real_escape_string(conn, &escaped[0], value.c_str(), value.length());
if (result_len == (unsigned long)-1) {
// Error during escaping (e.g., invalid character set)
return "";
}
std::string result(escaped); // Resize to actual escaped length
delete[] escaped; escaped.resize(result_len);
return result; return escaped;
} }
/** /**
@ -822,18 +815,85 @@ std::string MySQL_Tool_Handler::sample_rows(
return result.dump(); return result.dump();
} }
// Validate columns parameter (if provided) - check for common SQL injection patterns // Validate columns parameter (if provided) - parse and validate each column
if (!columns.empty()) { if (!columns.empty()) {
// Check for dangerous patterns in columns // Helper lambda to validate a single column identifier
std::string upper_columns = columns; auto validate_column_identifier = [](const std::string& col) -> bool {
std::transform(upper_columns.begin(), upper_columns.end(), upper_columns.begin(), ::toupper); if (col.empty()) return false;
if (upper_columns.find("--") != std::string::npos ||
upper_columns.find("/*") != std::string::npos || // Check for basic SQL injection patterns first
upper_columns.find(";") != std::string::npos || if (col.find("--") != std::string::npos ||
upper_columns.find("UNION") != std::string::npos || col.find("/*") != std::string::npos ||
upper_columns.find("JOIN") != std::string::npos) { col.find(";") != std::string::npos) {
result["error"] = "Invalid columns parameter: contains unsafe patterns"; return false;
return result.dump(); }
// Allow: identifier, identifier.identifier, or identifier AS identifier
// This is a simplified check - we validate character by character
bool has_dot = false;
bool in_identifier = true;
int identifier_count = 0;
for (size_t i = 0; i < col.length(); i++) {
char c = col[i];
// Skip whitespace
if (isspace(c)) {
in_identifier = false;
continue;
}
// Check for "AS" keyword (case-insensitive)
if (!in_identifier && i + 1 < col.length()) {
if ((c == 'A' || c == 'a') &&
(col[i + 1] == 'S' || col[i + 1] == 's')) {
i++; // Skip the 'S'
in_identifier = false;
continue;
}
}
// Check for dot (qualified identifier like table.column)
if (c == '.') {
if (has_dot || identifier_count == 0) {
return false; // Multiple dots or dot at start
}
has_dot = true;
in_identifier = true;
continue;
}
// Must be valid identifier character
if (!isalnum(c) && c != '_' && c != '$') {
return false;
}
if (!in_identifier) {
in_identifier = true;
identifier_count++;
}
}
// Must have at least one valid identifier
return identifier_count > 0;
};
// Parse comma-separated column list and validate each part
std::stringstream col_stream(columns);
std::string col_part;
while (std::getline(col_stream, col_part, ',')) {
// Trim whitespace
size_t start = col_part.find_first_not_of(" \t\n\r");
size_t end = col_part.find_last_not_of(" \t\n\r");
if (start == std::string::npos || end == std::string::npos) {
continue; // Skip empty segments
}
std::string trimmed = col_part.substr(start, end - start + 1);
if (!validate_column_identifier(trimmed)) {
result["error"] = "Invalid columns parameter: '" + trimmed + "' is not a valid column specification";
return result.dump();
}
} }
} }

@ -130,9 +130,8 @@ static double json_double(const json& j, const std::string& key, double default_
/** /**
* @brief Validate and escape a SQL identifier (table name, column name, etc.) * @brief Validate and escape a SQL identifier (table name, column name, etc.)
* *
* For SQLite, we validate that the identifier contains only safe characters * For SQLite, we validate that the identifier contains only safe characters.
* and quote it with double quotes if needed. This prevents SQL injection * This prevents SQL injection while allowing valid identifiers.
* while allowing valid identifiers.
* *
* @param identifier The identifier to validate/escape * @param identifier The identifier to validate/escape
* @return Empty string if unsafe, otherwise the validated identifier * @return Empty string if unsafe, otherwise the validated identifier
@ -159,35 +158,27 @@ static std::string validate_sql_identifier_sqlite(const std::string& identifier)
} }
} }
// Check for SQL injection patterns
if (identifier.find('"') != std::string::npos ||
identifier.find('\'') != std::string::npos ||
identifier.find('`') != std::string::npos ||
identifier.find('-') != std::string::npos ||
identifier.find(';') != std::string::npos ||
identifier.find('/') != std::string::npos) {
return "";
}
return identifier; return identifier;
} }
/** /**
* @brief Escape a SQL string literal for use in queries * @brief Escape a SQL string literal for use in queries
* *
* Doubles single quotes to escape them for SQL string literals. * Escapes single quotes by doubling them (standard SQL) and also escapes
* This is the standard SQL escaping mechanism. * backslashes for defense-in-depth (important for MySQL with certain modes).
* *
* @param value The string value to escape * @param value The string value to escape
* @return Escaped string safe for use in SQL queries * @return Escaped string safe for use in SQL queries
*/ */
static std::string escape_string_literal(const std::string& value) { static std::string escape_string_literal(const std::string& value) {
std::string escaped; std::string escaped;
escaped.reserve(value.length() * 2); escaped.reserve(value.length() * 2 + 1);
for (char c : value) { for (char c : value) {
if (c == '\'') { if (c == '\'') {
escaped += "''"; // Double single quotes to escape escaped += "''"; // Double single quotes to escape (SQL standard)
} else if (c == '\\') {
escaped += "\\\\"; // Escape backslash (defense-in-depth)
} else { } else {
escaped += c; escaped += c;
} }
@ -1017,6 +1008,7 @@ json Query_Tool_Handler::execute_tool(const std::string& tool_name, const json&
std::string validated = validate_sql_identifier_sqlite(schema); std::string validated = validate_sql_identifier_sqlite(schema);
if (validated.empty()) { if (validated.empty()) {
result = create_error_response("Invalid schema name: contains unsafe characters"); result = create_error_response("Invalid schema name: contains unsafe characters");
return result; // Early return on validation failure
} else { } else {
schema = validated; schema = validated;
} }

Loading…
Cancel
Save