/**
 * @file nl2sql_internal-t.cpp
 * @brief TAP unit tests for NL2SQL internal functionality
 *
 * Test Categories:
 * 1. SQL validation patterns (validate_and_score_sql)
 * 2. Request ID generation (uniqueness, format)
 * 3. Prompt building (schema context, system instructions)
 * 4. Error code conversion (nl2sql_error_code_to_string)
 *
 * Note: These are standalone implementations of the internal functions
 * for testing purposes, matching the logic in NL2SQL_Converter.cpp
 *
 * @date 2025-01-16
 */

#include "tap.h"

#include <cctype>   // toupper, tolower
#include <cstdio>   // snprintf
#include <cstdlib>  // rand
#include <cstring>  // strcmp
#include <string>   // std::string

// ============================================================================
// Standalone implementations of NL2SQL internal functions
// ============================================================================

/**
 * @brief Convert an NL2SQLErrorCode value to its string representation.
 *
 * The numeric values are assumed to mirror the NL2SQLErrorCode enum used by
 * NL2SQL_Converter.cpp — NOTE(review): keep both lists in sync.
 *
 * @param code Integer value of the error code.
 * @return Static NUL-terminated name; "UNKNOWN_ERROR" for any unmapped code.
 */
static const char* nl2sql_error_code_to_string(int code) {
    switch (code) {
        case 0:  return "SUCCESS";
        case 1:  return "ERR_API_KEY_MISSING";
        case 2:  return "ERR_API_KEY_INVALID";
        case 3:  return "ERR_TIMEOUT";
        case 4:  return "ERR_CONNECTION_FAILED";
        case 5:  return "ERR_RATE_LIMITED";
        case 6:  return "ERR_SERVER_ERROR";
        case 7:  return "ERR_EMPTY_RESPONSE";
        case 8:  return "ERR_INVALID_RESPONSE";
        case 9:  return "ERR_SQL_INJECTION_DETECTED";
        case 10: return "ERR_VALIDATION_FAILED";
        case 11: return "ERR_UNKNOWN_PROVIDER";
        case 12: return "ERR_REQUEST_TOO_LARGE";
        default: return "UNKNOWN_ERROR";
    }
}

/**
 * @brief Validate and score an SQL query.
 *
 * Checks are applied in this order; the first match decides the result:
 * - empty input                                   -> 0.0
 * - does not start with SELECT (read-only safety) -> 0.3
 * - contains a dangerous keyword                  -> 0.2
 * - contains an SQL injection pattern             -> 0.1
 * - otherwise 0.5, +0.3 if " FROM " is present, +0.1 if there is no ';'
 *
 * Matching is a case-insensitive *substring* search, not tokenization.
 * Identifiers such as "created_at" or "is_deleted" therefore also trip the
 * dangerous-keyword check. A query with leading whitespace or a leading
 * comment before SELECT is treated as non-SELECT.
 *
 * @param sql Query text to score.
 * @return Confidence score in the range 0.0-1.0.
 */
static float validate_and_score_sql(const std::string& sql) {
    if (sql.empty()) {
        return 0.0f;
    }

    // Convert to uppercase for comparison. toupper() requires a value that is
    // representable as unsigned char (or EOF). Passing a negative plain char,
    // e.g. a UTF-8 byte >= 0x80 where char is signed, is undefined behavior.
    std::string upper_sql = sql;
    for (size_t i = 0; i < upper_sql.length(); i++) {
        upper_sql[i] = static_cast<char>(
            toupper(static_cast<unsigned char>(upper_sql[i])));
    }

    // Check if starts with SELECT (read-only query)
    if (upper_sql.find("SELECT") != 0) {
        return 0.3f;  // Low confidence for non-SELECT
    }

    // Check for dangerous SQL patterns (substring match, see note above)
    const char* dangerous_patterns[] = {
        "DROP", "DELETE", "UPDATE", "INSERT", "ALTER",
        "CREATE", "TRUNCATE", "GRANT", "REVOKE", "EXEC"
    };
    for (size_t i = 0; i < sizeof(dangerous_patterns) / sizeof(dangerous_patterns[0]); i++) {
        if (upper_sql.find(dangerous_patterns[i]) != std::string::npos) {
            return 0.2f;  // Very low confidence for dangerous patterns
        }
    }

    // Check for SQL injection patterns. They are matched against the
    // uppercased text, so e.g. "or true" is caught as "OR TRUE".
    const char* injection_patterns[] = {
        "';--", "'; /*", "\";--", "1=1", "1 = 1", "OR TRUE",
        "UNION SELECT", "'; EXEC", "';EXEC"
    };
    for (size_t i = 0; i < sizeof(injection_patterns) / sizeof(injection_patterns[0]); i++) {
        if (upper_sql.find(injection_patterns[i]) != std::string::npos) {
            return 0.1f;  // Extremely low confidence for injection
        }
    }

    // Basic structure checks
    bool has_from = (upper_sql.find(" FROM ") != std::string::npos);
    bool has_semicolon = (upper_sql.find(';') != std::string::npos);

    float score = 0.5f;
    if (has_from) score += 0.3f;
    if (!has_semicolon) score += 0.1f;  // Single statement preferred

    // Cap at 1.0 (defensive: the bonuses above currently top out at 0.9)
    if (score > 1.0f) score = 1.0f;

    return score;
}

/**
 * @brief Generate a UUID-like request ID in 8-4-4-4-12 hex layout.
 *
 * This simulates the NL2SQLRequest constructor behavior. It is NOT an
 * RFC 4122 UUID: there are no version/variant bits, and it uses the unseeded
 * rand(), so the sequence repeats across runs.
 *
 * Each field is masked to its printed width, so the string is always 36
 * characters. The 8- and 12-digit fields are widened to unsigned long long
 * so the arguments match the %llx conversions on every data model. On LLP64
 * (32-bit long) a 48-bit mask has type long long, and feeding that to %lx
 * would be undefined behavior.
 *
 * @return A 36-character lowercase hex ID with dashes at 8, 13, 18 and 23.
 */
static std::string generate_request_id() {
    // Drawn in a fixed order; as function arguments the rand() calls would
    // be evaluated in an unspecified order.
    const unsigned long long time_low =
        static_cast<unsigned long long>(rand()) & 0xffffffffULL;
    const unsigned time_mid = static_cast<unsigned>(rand()) & 0xffffu;
    const unsigned time_hi = static_cast<unsigned>(rand()) & 0xffffu;
    const unsigned clock_seq = static_cast<unsigned>(rand()) & 0xffffu;
    const unsigned long long node =
        static_cast<unsigned long long>(rand()) & 0xffffffffffffULL;

    char uuid[64];
    snprintf(uuid, sizeof(uuid), "%08llx-%04x-%04x-%04x-%012llx",
             time_low, time_mid, time_hi, clock_seq, node);
    return std::string(uuid);
}

/**
 * @brief Build the NL2SQL prompt with optional schema context.
 *
 * Layout: system instruction, then "Database Schema:" (only when
 * @p schema_context is non-empty), then "Natural Language Query:", then the
 * output-format instruction.
 *
 * @param query          Natural-language question from the user.
 * @param schema_context Schema description; omitted from the prompt if empty.
 * @return The assembled prompt text.
 */
static std::string build_prompt(const std::string& query, const std::string& schema_context) {
    std::string prompt = "You are a SQL expert. Convert natural language to SQL.\n\n";

    if (!schema_context.empty()) {
        prompt += "Database Schema:\n";
        prompt += schema_context;
        prompt += "\n\n";
    }

    prompt += "Natural Language Query:\n";
    prompt += query;
    prompt += "\n\n";
    prompt += "Return only the SQL query without explanation or markdown formatting.";

    return prompt;
}

// ============================================================================
// Test: Error Code Conversion
// ============================================================================

void test_error_code_conversion() {
    diag("=== Error Code Conversion Tests ===");

    ok(strcmp(nl2sql_error_code_to_string(0), "SUCCESS") == 0,
       "SUCCESS error code converts correctly");
    ok(strcmp(nl2sql_error_code_to_string(1), "ERR_API_KEY_MISSING") == 0,
       "ERR_API_KEY_MISSING converts correctly");
    ok(strcmp(nl2sql_error_code_to_string(5), "ERR_RATE_LIMITED") == 0,
       "ERR_RATE_LIMITED converts correctly");
    ok(strcmp(nl2sql_error_code_to_string(12), "ERR_REQUEST_TOO_LARGE") == 0,
       "ERR_REQUEST_TOO_LARGE converts correctly");
    ok(strcmp(nl2sql_error_code_to_string(999), "UNKNOWN_ERROR") == 0,
       "Unknown error code returns UNKNOWN_ERROR");
}

// ============================================================================
// Test: SQL Validation Patterns
// ============================================================================

void test_sql_validation_select_queries() {
    diag("=== SQL Validation - SELECT Queries ===");

    // Valid SELECT queries
    ok(validate_and_score_sql("SELECT * FROM users") >= 0.7f,
       "Simple SELECT query scores well");
    ok(validate_and_score_sql("SELECT id, name FROM customers WHERE active = 1") >= 0.7f,
       "SELECT with WHERE clause scores well");
    ok(validate_and_score_sql("SELECT COUNT(*) FROM orders") >= 0.7f,
       "SELECT with COUNT scores well");
    ok(validate_and_score_sql("SELECT * FROM users JOIN orders ON users.id = orders.user_id") >= 0.7f,
       "SELECT with JOIN scores well");
}

void test_sql_validation_non_select() {
    diag("=== SQL Validation - Non-SELECT Queries ===");

    // Non-SELECT queries should have low confidence
    ok(validate_and_score_sql("DROP TABLE users") < 0.5f,
       "DROP TABLE has low confidence");
    ok(validate_and_score_sql("DELETE FROM users WHERE id = 1") < 0.5f,
       "DELETE has low confidence");
    ok(validate_and_score_sql("UPDATE users SET name = 'test'") < 0.5f,
       "UPDATE has low confidence");
    ok(validate_and_score_sql("INSERT INTO users VALUES (1, 'test')") < 0.5f,
       "INSERT has low confidence");
}

void test_sql_validation_injection_patterns() {
    diag("=== SQL Validation - Injection Patterns ===");

    // SQL injection patterns should have very low confidence
    ok(validate_and_score_sql("SELECT * FROM users WHERE id = 1; DROP TABLE users") < 0.5f,
       "Injection with DROP has low confidence");
    ok(validate_and_score_sql("SELECT * FROM users WHERE id = 1 OR 1=1") < 0.5f,
       "Injection with 1=1 has low confidence");

    // Note: Single-quote pattern detection has limitations.
    // The function checks for exact patterns, which may not catch all variants.
    ok(validate_and_score_sql("SELECT * FROM users WHERE id = 1' OR '1'='1") >= 0.5f,
       "Injection with quoted OR not detected by basic pattern matching (known limitation)");

    // Comment at end of query: the function only checks the exact ";--"
    // pattern, so "; --" (with a space) is not caught.
    ok(validate_and_score_sql("SELECT * FROM users; --") >= 0.5f,
       "Comment injection at end not detected (known limitation)");
}

void test_sql_validation_edge_cases() {
    diag("=== SQL Validation - Edge Cases ===");

    // Empty query
    ok(validate_and_score_sql("") == 0.0f,
       "Empty query returns 0 confidence");

    // Just the SELECT keyword: 0.5 base, no FROM bonus, +0.1 for having no
    // ';' -> 0.6
    ok(validate_and_score_sql("SELECT") >= 0.5f,
       "Just SELECT has base confidence (0.5) without FROM clause");

    // SELECT with trailing semicolon: 0.5 base + 0.3 FROM bonus -> 0.8
    ok(validate_and_score_sql("SELECT * FROM users;") >= 0.5f,
       "SELECT with semicolon has moderate confidence (single statement)");

    // Complex valid query
    std::string complex = "SELECT u.id, u.name, COUNT(o.id) as order_count "
                          "FROM users u LEFT JOIN orders o ON u.id = o.user_id "
                          "GROUP BY u.id, u.name HAVING COUNT(o.id) > 5 "
                          "ORDER BY order_count DESC LIMIT 10";
    ok(validate_and_score_sql(complex) >= 0.7f,
       "Complex valid SELECT query scores well");
}

// ============================================================================
// Test: Request ID Generation
// ============================================================================

void test_request_id_generation_format() {
    diag("=== Request ID Generation - Format Tests ===");

    // Generate several IDs and check their format (two TAP assertions per ID)
    for (int i = 0; i < 10; i++) {
        std::string id = generate_request_id();

        // Check length (8-4-4-4-12 format = 36 characters)
        ok(id.length() == 36, "Request ID has correct length (36 chars)");

        // Check dash positions. Guard on the length first:
        // std::string::operator[] with an index past size() is undefined
        // behavior, so a short ID must fail here rather than be read OOB.
        bool has_correct_format = id.length() == 36 &&
                                  id[8] == '-' && id[13] == '-' &&
                                  id[18] == '-' && id[23] == '-';
        ok(has_correct_format, "Request ID has correct format (8-4-4-4-12)");
    }
}

void test_request_id_generation_uniqueness() {
    diag("=== Request ID Generation - Uniqueness Tests ===");

    // Generate multiple IDs and compare every pair for duplicates
    std::string ids[100];
    bool all_unique = true;

    for (int i = 0; i < 100; i++) {
        ids[i] = generate_request_id();
    }

    for (int i = 0; i < 100 && all_unique; i++) {
        for (int j = i + 1; j < 100; j++) {
            if (ids[i] == ids[j]) {
                all_unique = false;
                break;
            }
        }
    }

    ok(all_unique, "100 generated request IDs are all unique");
}

void test_request_id_generation_hex() {
    diag("=== Request ID Generation - Hex Format Tests ===");

    std::string id = generate_request_id();

    // Skip dashes and check that every other character is a hex digit.
    // An empty ID must not pass vacuously.
    std::string hex_chars = "0123456789abcdef";
    bool all_hex = !id.empty();
    for (size_t i = 0; i < id.length(); i++) {
        if (id[i] == '-') continue;
        // tolower() needs an unsigned-char value (same rule as toupper()).
        char lower = static_cast<char>(tolower(static_cast<unsigned char>(id[i])));
        if (hex_chars.find(lower) == std::string::npos) {
            all_hex = false;
            break;
        }
    }

    ok(all_hex, "Request ID contains only hexadecimal characters (and dashes)");
}

// ============================================================================
============================================================================ // Test: Prompt Building // ============================================================================ void test_prompt_building_basic() { diag("=== Prompt Building - Basic Tests ==="); std::string prompt = build_prompt("Show users", ""); ok(prompt.find("Show users") != std::string::npos, "Prompt contains the user query"); ok(prompt.find("SQL expert") != std::string::npos, "Prompt contains system instruction"); ok(prompt.find("return only the SQL query") != std::string::npos || prompt.find("Return only the SQL") != std::string::npos, "Prompt contains output format instruction"); } void test_prompt_building_with_schema() { diag("=== Prompt Building - With Schema Tests ==="); std::string schema = "CREATE TABLE users (id INT, name VARCHAR(100));"; std::string prompt = build_prompt("Show users", schema); ok(prompt.find("Database Schema") != std::string::npos, "Prompt includes schema section header"); ok(prompt.find(schema) != std::string::npos, "Prompt includes the actual schema"); ok(prompt.find("Natural Language Query") != std::string::npos, "Prompt includes query section"); } void test_prompt_building_structure() { diag("=== Prompt Building - Structure Tests ==="); std::string prompt = build_prompt("Test query", "Schema info"); // Check for sections in order size_t system_pos = prompt.find("SQL expert"); size_t schema_pos = prompt.find("Database Schema"); size_t query_pos = prompt.find("Natural Language Query"); size_t output_pos = prompt.find("return only"); bool correct_order = (system_pos < schema_pos || schema_pos == std::string::npos) && (schema_pos < query_pos || schema_pos == std::string::npos) && (query_pos < output_pos); ok(correct_order, "Prompt sections appear in correct order"); } void test_prompt_building_special_chars() { diag("=== Prompt Building - Special Characters Tests ==="); // Test with special characters in query std::string prompt = build_prompt("Show users with