// Mirror of https://github.com/sysown/proxysql
/**
 * @file nl2sql_prompt_builder-t.cpp
 * @brief TAP unit tests for NL2SQL prompt building
 *
 * Test Categories:
 *   1. Basic prompt construction
 *   2. Schema context inclusion
 *   3. System instruction formatting
 *   4. Edge cases (empty query, special characters)
 *   5. Prompt structure validation
 *
 * Prerequisites:
 *   - ProxySQL with AI features enabled
 *   - Admin interface on localhost:6032
 *
 * Usage:
 *   make nl2sql_prompt_builder-t
 *   ./nl2sql_prompt_builder-t
 *
 * @date 2025-01-16
 */
|
|
|
|
#include <algorithm>
|
|
#include <string>
|
|
#include <string.h>
|
|
#include <stdio.h>
|
|
#include <unistd.h>
|
|
#include <vector>
|
|
|
|
#include "mysql.h"
|
|
#include "mysqld_error.h"
|
|
|
|
#include "tap.h"
|
|
#include "command_line.h"
|
|
#include "utils.h"
|
|
|
|
using std::string;
|
|
using std::vector;
|
|
|
|
// Global admin connection
|
|
MYSQL* g_admin = NULL;
|
|
|
|
// ============================================================================
// Helper Functions
// ============================================================================
|
|
|
|
/**
 * @brief Build a prompt using NL2SQL converter
 *
 * This is a placeholder that simulates the prompt building process.
 * In a full implementation, this would call NL2SQL_Converter::build_prompt().
 *
 * The returned text is laid out in this order:
 *   1. the system instruction line, followed by a blank line
 *   2. a "Database Schema:" section (only when @p schema_context is non-empty)
 *   3. "Question: <natural_language>", followed by a blank line
 *   4. the output-format requirement line
 *
 * Both arguments are inserted verbatim; no escaping or trimming is applied.
 *
 * @param natural_language The user's natural language query
 * @param schema_context Optional schema information; empty omits the schema section
 * @return The constructed prompt
 */
std::string build_test_prompt(const std::string& natural_language, const std::string& schema_context = "") {
    std::string prompt;

    // System instructions
    prompt += "You are a SQL expert. Convert the following natural language question to a SQL query.\n\n";

    // Add schema context if available
    if (!schema_context.empty()) {
        prompt += "Database Schema:\n";
        prompt += schema_context;
        prompt += "\n";
    }

    // User's question. Appended piecewise so no temporary strings are built.
    prompt += "Question: ";
    prompt += natural_language;
    prompt += "\n\n";

    // Output-format requirement always closes the prompt
    prompt += "Return ONLY the SQL query. No explanations, no markdown formatting.\n";

    return prompt;
}
|
|
|
|
/**
 * @brief Check if prompt contains required elements
 *
 * Each element is matched as a plain, case-sensitive substring of @p prompt.
 * An empty @p elements list yields true.
 *
 * @param prompt The prompt to check
 * @param elements Vector of required strings
 * @return true if all elements are present
 */
bool prompt_contains_elements(const std::string& prompt, const std::vector<std::string>& elements) {
    return std::all_of(elements.begin(), elements.end(), [&prompt](const std::string& elem) {
        return prompt.find(elem) != std::string::npos;
    });
}
|
|
|
|
// ============================================================================
// Test: Basic Prompt Construction
// ============================================================================
|
|
|
|
/**
|
|
* @test Basic prompt construction
|
|
* @description Verify that basic prompts are constructed correctly
|
|
* @expected Prompt should contain system instructions and user query
|
|
*/
|
|
void test_basic_prompt_construction() {
|
|
diag("=== Basic Prompt Construction Tests ===");
|
|
|
|
// Test 1: Simple query
|
|
string prompt = build_test_prompt("Show all users");
|
|
vector<string> required = {"You are a SQL expert", "Show all users", "Return ONLY the SQL query"};
|
|
ok(prompt_contains_elements(prompt, required), "Simple query prompt contains all required elements");
|
|
|
|
// Test 2: Query with conditions
|
|
prompt = build_test_prompt("Find customers where age > 25");
|
|
required = {"You are a SQL expert", "Find customers where age > 25", "SQL query"};
|
|
ok(prompt_contains_elements(prompt, required), "Query with conditions prompt is correct");
|
|
|
|
// Test 3: Aggregation query
|
|
prompt = build_test_prompt("Count users by country");
|
|
required = {"You are a SQL expert", "Count users by country"};
|
|
ok(prompt_contains_elements(prompt, required), "Aggregation query prompt is correct");
|
|
|
|
// Test 4: Query with JOIN
|
|
prompt = build_test_prompt("Show orders with customer names");
|
|
required = {"You are a SQL expert", "Show orders with customer names"};
|
|
ok(prompt_contains_elements(prompt, required), "JOIN query prompt is correct");
|
|
|
|
// Test 5: Complex query
|
|
prompt = build_test_prompt("Find the top 10 customers by total order amount in the last 30 days");
|
|
required = {"You are a SQL expert", "Find the top 10 customers", "last 30 days"};
|
|
ok(prompt_contains_elements(prompt, required), "Complex query prompt is correct");
|
|
}
|
|
|
|
// ============================================================================
// Test: Schema Context Inclusion
// ============================================================================
|
|
|
|
/**
|
|
* @test Schema context inclusion
|
|
* @description Verify that schema context is properly included in prompts
|
|
* @expected Prompt should contain schema information when provided
|
|
*/
|
|
void test_schema_context_inclusion() {
|
|
diag("=== Schema Context Inclusion Tests ===");
|
|
|
|
// Test 1: Empty schema context
|
|
string prompt = build_test_prompt("Show all users", "");
|
|
ok(prompt.find("Database Schema:") == string::npos, "Empty schema context doesn't add schema section");
|
|
|
|
// Test 2: Simple schema context
|
|
string schema = "Table: users (id INT, name VARCHAR(100))";
|
|
prompt = build_test_prompt("Show all users", schema);
|
|
ok(prompt.find("Database Schema:") != string::npos && prompt.find("users") != string::npos,
|
|
"Simple schema context is included");
|
|
|
|
// Test 3: Multi-table schema context
|
|
schema = "Table: users (id INT, name VARCHAR(100))\nTable: orders (id INT, user_id INT, amount DECIMAL)";
|
|
prompt = build_test_prompt("Show orders with user names", schema);
|
|
ok(prompt.find("users") != string::npos && prompt.find("orders") != string::npos,
|
|
"Multi-table schema context is included");
|
|
|
|
// Test 4: Schema with foreign keys
|
|
schema = "users.id <- orders.user_id (FOREIGN KEY)";
|
|
prompt = build_test_prompt("Show all orders with user info", schema);
|
|
ok(prompt.find("FOREIGN KEY") != string::npos, "Schema with foreign keys is included");
|
|
|
|
// Test 5: Large schema context
|
|
schema.clear();
|
|
for (int i = 0; i < 20; i++) {
|
|
char table_name[64];
|
|
snprintf(table_name, sizeof(table_name), "Table: table%d (id INT, data VARCHAR)", i);
|
|
schema += table_name;
|
|
schema += "\n";
|
|
}
|
|
prompt = build_test_prompt("Show data from table5", schema);
|
|
ok(prompt.find("table5") != string::npos, "Large schema context includes relevant table");
|
|
}
|
|
|
|
// ============================================================================
// Test: System Instruction Formatting
// ============================================================================
|
|
|
|
/**
|
|
* @test System instruction formatting
|
|
* @description Verify that system instructions are properly formatted
|
|
* @expected Prompt should have proper system instruction section
|
|
*/
|
|
void test_system_instruction_formatting() {
|
|
diag("=== System Instruction Formatting Tests ===");
|
|
|
|
// Test 1: System instruction presence
|
|
string prompt = build_test_prompt("Any query");
|
|
ok(prompt.find("You are a SQL expert") != string::npos, "System instruction contains role definition");
|
|
|
|
// Test 2: Task description
|
|
ok(prompt.find("Convert the following natural language question") != string::npos,
|
|
"System instruction contains task description");
|
|
|
|
// Test 3: Output format requirement
|
|
ok(prompt.find("Return ONLY the SQL query") != string::npos,
|
|
"System instruction specifies output format");
|
|
|
|
// Test 4: No explanations requirement
|
|
ok(prompt.find("No explanations") != string::npos,
|
|
"System instruction specifies no explanations");
|
|
|
|
// Test 5: No markdown requirement
|
|
ok(prompt.find("no markdown formatting") != string::npos,
|
|
"System instruction specifies no markdown");
|
|
}
|
|
|
|
// ============================================================================
// Test: Edge Cases
// ============================================================================
|
|
|
|
/**
|
|
* @test Edge cases
|
|
* @description Verify proper handling of edge cases
|
|
* @expected Edge cases should be handled gracefully
|
|
*/
|
|
void test_edge_cases() {
|
|
diag("=== Edge Case Tests ===");
|
|
|
|
// Test 1: Empty query
|
|
string prompt = build_test_prompt("");
|
|
ok(prompt.find("Question: ") != string::npos, "Empty query is handled");
|
|
|
|
// Test 2: Very long query
|
|
string long_query(10000, 'a');
|
|
prompt = build_test_prompt(long_query);
|
|
ok(prompt.length() > 10000, "Very long query is included");
|
|
|
|
// Test 3: Query with special characters
|
|
string special_query = "Find users with émojis 🎉 and quotes \"'";
|
|
prompt = build_test_prompt(special_query);
|
|
ok(prompt.find("émojis") != string::npos, "Special characters are preserved");
|
|
|
|
// Test 4: Query with newlines
|
|
string newline_query = "Show users\nwhere\nage > 25";
|
|
prompt = build_test_prompt(newline_query);
|
|
ok(prompt.find("age > 25") != string::npos, "Query with newlines is handled");
|
|
|
|
// Test 5: Query with SQL injection attempt (should be safe)
|
|
string injection_query = "'; DROP TABLE users; --";
|
|
prompt = build_test_prompt(injection_query);
|
|
ok(prompt.find("DROP TABLE") != string::npos,
|
|
"SQL injection text is included in prompt (LLM must handle safety)");
|
|
}
|
|
|
|
// ============================================================================
// Test: Prompt Structure Validation
// ============================================================================
|
|
|
|
/**
|
|
* @test Prompt structure validation
|
|
* @description Verify that prompts follow the expected structure
|
|
* @expected Prompts should have proper sections in correct order
|
|
*/
|
|
void test_prompt_structure_validation() {
|
|
diag("=== Prompt Structure Validation Tests ===");
|
|
|
|
string prompt = build_test_prompt("Show users", "Table: users (id INT, name VARCHAR)");
|
|
|
|
// Test 1: System instructions come first
|
|
size_t system_pos = prompt.find("You are a SQL expert");
|
|
ok(system_pos == 0, "System instructions are at the beginning");
|
|
|
|
// Test 2: Schema section comes before question
|
|
size_t schema_pos = prompt.find("Database Schema:");
|
|
size_t question_pos = prompt.find("Question:");
|
|
if (schema_pos != string::npos) {
|
|
ok(schema_pos < question_pos, "Schema section comes before question");
|
|
} else {
|
|
skip(1, "No schema section present");
|
|
}
|
|
|
|
// Test 3: Question section contains the original query
|
|
ok(question_pos != string::npos, "Question section exists");
|
|
|
|
// Test 4: Output requirements come at the end
|
|
size_t output_pos = prompt.find("Return ONLY the SQL query");
|
|
ok(output_pos != string::npos && output_pos > question_pos,
|
|
"Output requirements come after question");
|
|
|
|
// Test 5: Proper line breaks between sections
|
|
size_t newline_count = 0;
|
|
for (char c : prompt) {
|
|
if (c == '\n') newline_count++;
|
|
}
|
|
ok(newline_count >= 3, "Prompt has proper line breaks between sections");
|
|
}
|
|
|
|
// ============================================================================
// Main
// ============================================================================
|
|
|
|
int main(int argc, char** argv) {
|
|
// Parse command line
|
|
CommandLine cl;
|
|
if (cl.getEnv()) {
|
|
diag("Error getting environment variables");
|
|
return exit_status();
|
|
}
|
|
|
|
// Connect to admin interface (for config checks)
|
|
g_admin = mysql_init(NULL);
|
|
if (!g_admin) {
|
|
diag("Failed to initialize MySQL connection");
|
|
return exit_status();
|
|
}
|
|
|
|
if (!mysql_real_connect(g_admin, cl.host, cl.admin_username, cl.admin_password,
|
|
NULL, cl.admin_port, NULL, 0)) {
|
|
diag("Failed to connect to admin interface: %s", mysql_error(g_admin));
|
|
mysql_close(g_admin);
|
|
return exit_status();
|
|
}
|
|
|
|
// Plan tests: 6 categories with 5 tests each
|
|
plan(30);
|
|
|
|
// Run test categories
|
|
test_basic_prompt_construction();
|
|
test_schema_context_inclusion();
|
|
test_system_instruction_formatting();
|
|
test_edge_cases();
|
|
test_prompt_structure_validation();
|
|
|
|
mysql_close(g_admin);
|
|
return exit_status();
|
|
}
|