From 3ccfa2bcc19b899dfb70fbbf4c572bdb42faf343 Mon Sep 17 00:00:00 2001 From: Rene Cannao Date: Fri, 23 Jan 2026 01:03:37 +0000 Subject: [PATCH] Address AI code review feedback for PR #5313 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit addresses concerns raised by AI code reviewers (gemini-code-assist, Copilot, coderabbitai) on the initial security fixes. Critical fixes: - Fix lock.release() → lock.unlock() in GenAI_Thread.cpp worker_loop (lock.release() detaches without unlocking, causing deadlock) - Add missing early return after schema validation failure in Query_Tool_Handler.cpp Code quality improvements: - Improve escape_string() memory management in MySQL_Tool_Handler.cpp: - Use std::string instead of new[]/delete[] for buffer management - Check return value of mysql_real_escape_string() for errors - Remove redundant validation checks in validate_sql_identifier functions (character class loop already rejects unsafe characters) - Add backslash escaping to escape_string_literal() for defense-in-depth - Improve column list validation in MySQL_Tool_Handler sample_rows(): - Replace blacklist approach with proper column identifier parsing - Allow qualified identifiers (table.column) - Allow AS aliases (column AS alias) - No longer rejects legitimate column names containing "JOIN" These changes improve robustness while maintaining the security posture of the original SQL injection fixes. --- lib/GenAI_Thread.cpp | 2 +- lib/MySQL_Tool_Handler.cpp | 118 ++++++++++++++++++++++++++++--------- lib/Query_Tool_Handler.cpp | 26 +++----- 3 files changed, 99 insertions(+), 47 deletions(-) diff --git a/lib/GenAI_Thread.cpp b/lib/GenAI_Thread.cpp index c107435bd..dc763fdbb 100644 --- a/lib/GenAI_Thread.cpp +++ b/lib/GenAI_Thread.cpp @@ -1465,7 +1465,7 @@ void GenAI_Threads_Handler::worker_loop(int worker_id) { } break; } - lock.release(); + lock.unlock(); // Release the lock (not release() which would detach without unlocking) // Process request auto start_time = std::chrono::steady_clock::now(); diff --git a/lib/MySQL_Tool_Handler.cpp b/lib/MySQL_Tool_Handler.cpp index a4c345b0e..9209dd4d8 100644 --- a/lib/MySQL_Tool_Handler.cpp +++ b/lib/MySQL_Tool_Handler.cpp @@ -234,8 +234,7 @@ int MySQL_Tool_Handler::init_connection_pool() { * @brief Validate SQL identifier (table name, column name, schema name) * * Checks that the identifier contains only valid characters (alphanumeric, - * underscore, dollar sign) and doesn't start with a digit. Also checks - * for SQL injection attempts. + * underscore, dollar sign) and doesn't start with a digit. * * @param identifier The identifier to validate * @return true if valid, false otherwise @@ -262,15 +261,6 @@ static bool validate_sql_identifier(const std::string& identifier) { } } - // Check for SQL injection patterns (quoted identifiers, comments, etc.) - if (identifier.find('"') != std::string::npos || - identifier.find('\'') != std::string::npos || - identifier.find('`') != std::string::npos || - identifier.find('-') != std::string::npos || - identifier.find(';') != std::string::npos) { - return false; - } - return true; } @@ -290,16 +280,19 @@ static std::string escape_string(MYSQL* conn, const std::string& value) { } // Allocate buffer for escaped string (2 * input + 1 for null terminator) - unsigned long escaped_length = value.length() * 2 + 1; - char* escaped = new char[escaped_length]; + std::string escaped(value.length() * 2 + 1, '\0'); - // Escape the string - mysql_real_escape_string(conn, escaped, value.c_str(), value.length()); + // Escape the string and check for errors + unsigned long result_len = mysql_real_escape_string(conn, &escaped[0], value.c_str(), value.length()); + if (result_len == (unsigned long)-1) { + // Error during escaping (e.g., invalid character set) + return ""; + } - std::string result(escaped); - delete[] escaped; + // Resize to actual escaped length + escaped.resize(result_len); - return result; + return escaped; } /** @@ -822,18 +815,85 @@ std::string MySQL_Tool_Handler::sample_rows( return result.dump(); } - // Validate columns parameter (if provided) - check for common SQL injection patterns + // Validate columns parameter (if provided) - parse and validate each column if (!columns.empty()) { - // Check for dangerous patterns in columns - std::string upper_columns = columns; - std::transform(upper_columns.begin(), upper_columns.end(), upper_columns.begin(), ::toupper); - if (upper_columns.find("--") != std::string::npos || - upper_columns.find("/*") != std::string::npos || - upper_columns.find(";") != std::string::npos || - upper_columns.find("UNION") != std::string::npos || - upper_columns.find("JOIN") != std::string::npos) { - result["error"] = "Invalid columns parameter: contains unsafe patterns"; - return result.dump(); + // Helper lambda to validate a single column identifier + auto validate_column_identifier = [](const std::string& col) -> bool { + if (col.empty()) return false; + + // Check for basic SQL injection patterns first + if (col.find("--") != std::string::npos || + col.find("/*") != std::string::npos || + col.find(";") != std::string::npos) { + return false; + } + + // Allow: identifier, identifier.identifier, or identifier AS identifier + // This is a simplified check - we validate character by character + bool has_dot = false; + bool in_identifier = true; + int identifier_count = 0; + + for (size_t i = 0; i < col.length(); i++) { + char c = col[i]; + + // Skip whitespace + if (isspace(c)) { + in_identifier = false; + continue; + } + + // Check for "AS" keyword (case-insensitive) + if (!in_identifier && i + 1 < col.length()) { + if ((c == 'A' || c == 'a') && + (col[i + 1] == 'S' || col[i + 1] == 's')) { + i++; // Skip the 'S' + in_identifier = false; + continue; + } + } + + // Check for dot (qualified identifier like table.column) + if (c == '.') { + if (has_dot || identifier_count == 0) { + return false; // Multiple dots or dot at start + } + has_dot = true; + in_identifier = true; + continue; + } + + // Must be valid identifier character + if (!isalnum(c) && c != '_' && c != '$') { + return false; + } + + if (!in_identifier) { + in_identifier = true; + identifier_count++; + } + } + + // Must have at least one valid identifier + return identifier_count > 0; + }; + + // Parse comma-separated column list and validate each part + std::stringstream col_stream(columns); + std::string col_part; + while (std::getline(col_stream, col_part, ',')) { + // Trim whitespace + size_t start = col_part.find_first_not_of(" \t\n\r"); + size_t end = col_part.find_last_not_of(" \t\n\r"); + if (start == std::string::npos || end == std::string::npos) { + continue; // Skip empty segments + } + std::string trimmed = col_part.substr(start, end - start + 1); + + if (!validate_column_identifier(trimmed)) { + result["error"] = "Invalid columns parameter: '" + trimmed + "' is not a valid column specification"; + return result.dump(); + } } } diff --git a/lib/Query_Tool_Handler.cpp b/lib/Query_Tool_Handler.cpp index 5de200ece..79f2abcd9 100644 --- a/lib/Query_Tool_Handler.cpp +++ b/lib/Query_Tool_Handler.cpp @@ -130,9 +130,8 @@ static double json_double(const json& j, const std::string& key, double default_ /** * @brief Validate and escape a SQL identifier (table name, column name, etc.) * - * For SQLite, we validate that the identifier contains only safe characters - * and quote it with double quotes if needed. This prevents SQL injection - * while allowing valid identifiers. + * For SQLite, we validate that the identifier contains only safe characters. + * This prevents SQL injection while allowing valid identifiers. * * @param identifier The identifier to validate/escape * @return Empty string if unsafe, otherwise the validated identifier @@ -159,35 +158,27 @@ static std::string validate_sql_identifier_sqlite(const std::string& identifier) } } - // Check for SQL injection patterns - if (identifier.find('"') != std::string::npos || - identifier.find('\'') != std::string::npos || - identifier.find('`') != std::string::npos || - identifier.find('-') != std::string::npos || - identifier.find(';') != std::string::npos || - identifier.find('/') != std::string::npos) { - return ""; - } - return identifier; } /** * @brief Escape a SQL string literal for use in queries * - * Doubles single quotes to escape them for SQL string literals. - * This is the standard SQL escaping mechanism. + * Escapes single quotes by doubling them (standard SQL) and also escapes + * backslashes for defense-in-depth (important for MySQL with certain modes). * * @param value The string value to escape * @return Escaped string safe for use in SQL queries */ static std::string escape_string_literal(const std::string& value) { std::string escaped; - escaped.reserve(value.length() * 2); + escaped.reserve(value.length() * 2 + 1); for (char c : value) { if (c == '\'') { - escaped += "''"; // Double single quotes to escape + escaped += "''"; // Double single quotes to escape (SQL standard) + } else if (c == '\\') { + escaped += "\\\\"; // Escape backslash (defense-in-depth) } else { escaped += c; } @@ -1017,6 +1008,7 @@ json Query_Tool_Handler::execute_tool(const std::string& tool_name, const json& std::string validated = validate_sql_identifier_sqlite(schema); if (validated.empty()) { result = create_error_response("Invalid schema name: contains unsafe characters"); + return result; // Early return on validation failure } else { schema = validated; }