From 3f44229e2835516c77700b5cc84dc9dd460643dd Mon Sep 17 00:00:00 2001 From: Rene Cannao Date: Fri, 16 Jan 2026 13:32:06 +0000 Subject: [PATCH] feat: Add MCP AI Tool Handler for NL2SQL with test script Phase 5: MCP Tool Implementation for NL2SQL This commit implements the AI Tool Handler for the MCP (Model Context Protocol) server, exposing NL2SQL functionality as an MCP tool. **New Files:** - include/AI_Tool_Handler.h: Header for AI_Tool_Handler class - Provides ai_nl2sql_convert tool via MCP protocol - Wraps NL2SQL_Converter and Anomaly_Detector - Inherits from MCP_Tool_Handler base class - lib/AI_Tool_Handler.cpp: Implementation - Implements ai_nl2sql_convert tool execution - Accepts parameters: natural_language (required), schema, context_tables, max_latency_ms, allow_cache - Returns JSON response with sql_query, confidence, explanation, cached, cache_id - scripts/mcp/test_nl2sql_tools.sh: Test script for NL2SQL MCP tool - Tests ai_nl2sql_convert via JSON-RPC over HTTPS - 10 test cases covering SELECT, WHERE, JOIN, aggregation, etc. 
- Includes error handling test for empty queries - Supports --verbose, --quiet options **Modified Files:** - include/MCP_Thread.h: Add AI_Tool_Handler forward declaration and pointer - lib/Makefile: Add AI_Tool_Handler.oo to _OBJ_CXX list - lib/ProxySQL_MCP_Server.cpp: Initialize and register AI tool handler - Creates AI_Tool_Handler with GloAI components - Registers /mcp/ai endpoint - Adds cleanup in destructor **MCP Tool Details:** - Endpoint: /mcp/ai - Tool: ai_nl2sql_convert - Parameters: - natural_language (string, required): Natural language query - schema (string, optional): Database schema name - context_tables (string, optional): Comma-separated table list - max_latency_ms (integer, optional): Max acceptable latency - allow_cache (boolean, optional): Check semantic cache (default: true) **Testing:** Run the test script with: ./scripts/mcp/test_nl2sql_tools.sh [--verbose] [--quiet] See scripts/mcp/test_nl2sql_tools.sh --help for usage. Related: Phase 1-4 (Documentation, Unit Tests, Integration Tests, E2E Tests) Related: Phase 6-8 (User Docs, Developer Docs, Test Docs) --- include/AI_Tool_Handler.h | 96 +++++++ include/MCP_Thread.h | 2 + lib/AI_Tool_Handler.cpp | 275 +++++++++++++++++++ lib/Makefile | 2 +- lib/ProxySQL_MCP_Server.cpp | 49 +++- scripts/mcp/test_nl2sql_tools.sh | 441 +++++++++++++++++++++++++++++++ 6 files changed, 858 insertions(+), 7 deletions(-) create mode 100644 include/AI_Tool_Handler.h create mode 100644 lib/AI_Tool_Handler.cpp create mode 100755 scripts/mcp/test_nl2sql_tools.sh diff --git a/include/AI_Tool_Handler.h b/include/AI_Tool_Handler.h new file mode 100644 index 000000000..85e102284 --- /dev/null +++ b/include/AI_Tool_Handler.h @@ -0,0 +1,96 @@ +/** + * @file ai_tool_handler.h + * @brief AI Tool Handler for MCP protocol + * + * Provides AI-related tools via MCP protocol including: + * - NL2SQL (Natural Language to SQL) conversion + * - Anomaly detection queries + * - Vector storage operations + * + * @date 2025-01-16 + */ + 
+#ifndef CLASS_AI_TOOL_HANDLER_H +#define CLASS_AI_TOOL_HANDLER_H + +#include "MCP_Tool_Handler.h" +#include +#include +#include + +// Forward declarations +class NL2SQL_Converter; +class Anomaly_Detector; + +/** + * @brief AI Tool Handler for MCP + * + * Provides AI-powered tools through the MCP protocol: + * - ai_nl2sql_convert: Convert natural language to SQL + * - Future: anomaly detection, vector operations + */ +class AI_Tool_Handler : public MCP_Tool_Handler { +private: + NL2SQL_Converter* nl2sql_converter; + Anomaly_Detector* anomaly_detector; + bool owns_components; + + /** + * @brief Helper to extract string parameter from JSON + */ + static std::string get_json_string(const json& j, const std::string& key, + const std::string& default_val = ""); + + /** + * @brief Helper to extract int parameter from JSON + */ + static int get_json_int(const json& j, const std::string& key, int default_val = 0); + +public: + /** + * @brief Constructor - uses existing AI components + */ + AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly); + + /** + * @brief Constructor - creates own components + */ + AI_Tool_Handler(); + + /** + * @brief Destructor + */ + ~AI_Tool_Handler(); + + /** + * @brief Initialize the tool handler + */ + int init() override; + + /** + * @brief Close and cleanup + */ + void close() override; + + /** + * @brief Get handler name + */ + std::string get_handler_name() const override { return "ai"; } + + /** + * @brief Get list of available tools + */ + json get_tool_list() override; + + /** + * @brief Get description of a specific tool + */ + json get_tool_description(const std::string& tool_name) override; + + /** + * @brief Execute a tool with arguments + */ + json execute_tool(const std::string& tool_name, const json& arguments) override; +}; + +#endif /* CLASS_AI_TOOL_HANDLER_H */ diff --git a/include/MCP_Thread.h b/include/MCP_Thread.h index acf68dfb4..bae5585f0 100644 --- a/include/MCP_Thread.h +++ b/include/MCP_Thread.h @@ 
-16,6 +16,7 @@ class Query_Tool_Handler; class Admin_Tool_Handler; class Cache_Tool_Handler; class Observe_Tool_Handler; +class AI_Tool_Handler; /** * @brief MCP Threads Handler class for managing MCP module configuration @@ -100,6 +101,7 @@ public: Admin_Tool_Handler* admin_tool_handler; Cache_Tool_Handler* cache_tool_handler; Observe_Tool_Handler* observe_tool_handler; + AI_Tool_Handler* ai_tool_handler; /** diff --git a/lib/AI_Tool_Handler.cpp b/lib/AI_Tool_Handler.cpp new file mode 100644 index 000000000..3bc1c45d1 --- /dev/null +++ b/lib/AI_Tool_Handler.cpp @@ -0,0 +1,275 @@ +/** + * @file AI_Tool_Handler.cpp + * @brief Implementation of AI Tool Handler for MCP protocol + * + * Implements AI-powered tools through MCP protocol, primarily + * the ai_nl2sql_convert tool for natural language to SQL conversion. + * + * @see AI_Tool_Handler.h + */ + +#include "AI_Tool_Handler.h" +#include "NL2SQL_Converter.h" +#include "Anomaly_Detector.h" +#include "AI_Features_Manager.h" +#include "proxysql_debug.h" +#include "cpp.h" +#include +#include + +// JSON library +#include "../deps/json/json.hpp" +using json = nlohmann::json; +#define PROXYJSON + +// ============================================================================ +// Constructor/Destructor +// ============================================================================ + +/** + * @brief Constructor using existing AI components + */ +AI_Tool_Handler::AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly) + : nl2sql_converter(nl2sql), + anomaly_detector(anomaly), + owns_components(false) +{ + proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler created (wrapping existing components)\n"); +} + +/** + * @brief Constructor - creates own components + * Note: This implementation uses global instances + */ +AI_Tool_Handler::AI_Tool_Handler() + : nl2sql_converter(NULL), + anomaly_detector(NULL), + owns_components(false) +{ + // Use global instances from AI_Features_Manager + if (GloAI) { + nl2sql_converter 
= GloAI->get_nl2sql(); + anomaly_detector = GloAI->get_anomaly_detector(); + } + proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler created (using global instances)\n"); +} + +/** + * @brief Destructor + */ +AI_Tool_Handler::~AI_Tool_Handler() { + close(); + proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler destroyed\n"); +} + +// ============================================================================ +// Lifecycle +// ============================================================================ + +/** + * @brief Initialize the tool handler + */ +int AI_Tool_Handler::init() { + if (!nl2sql_converter) { + proxy_error("AI_Tool_Handler: NL2SQL converter not available\n"); + return -1; + } + proxy_info("AI_Tool_Handler initialized\n"); + return 0; +} + +/** + * @brief Close and cleanup + */ +void AI_Tool_Handler::close() { + if (owns_components) { + // Components would be cleaned up here + // For now, we use global instances managed by AI_Features_Manager + } +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/** + * @brief Extract string parameter from JSON + */ +std::string AI_Tool_Handler::get_json_string(const json& j, const std::string& key, + const std::string& default_val) { + if (j.contains(key) && !j[key].is_null()) { + if (j[key].is_string()) { + return j[key].get<std::string>(); + } else { + // Convert to string if not already + return j[key].dump(); + } + } + return default_val; +} + +/** + * @brief Extract int parameter from JSON + */ +int AI_Tool_Handler::get_json_int(const json& j, const std::string& key, int default_val) { + if (j.contains(key) && !j[key].is_null()) { + if (j[key].is_number()) { + return j[key].get<int>(); + } else if (j[key].is_string()) { + return std::stoi(j[key].get<std::string>()); + } + } + return default_val; +} + +// ============================================================================ +// Tool List +// 
============================================================================ + +/** + * @brief Get list of available AI tools + */ +json AI_Tool_Handler::get_tool_list() { + json tools = json::array(); + + // NL2SQL tool + json nl2sql_params = json::object(); + nl2sql_params["type"] = "object"; + nl2sql_params["properties"] = json::object(); + nl2sql_params["properties"]["natural_language"] = { + {"type", "string"}, + {"description", "Natural language query to convert to SQL"} + }; + nl2sql_params["properties"]["schema"] = { + {"type", "string"}, + {"description", "Database/schema name for context"} + }; + nl2sql_params["properties"]["context_tables"] = { + {"type", "string"}, + {"description", "Comma-separated list of relevant tables (optional)"} + }; + nl2sql_params["properties"]["max_latency_ms"] = { + {"type", "integer"}, + {"description", "Maximum acceptable latency in milliseconds (optional)"} + }; + nl2sql_params["properties"]["allow_cache"] = { + {"type", "boolean"}, + {"description", "Whether to check semantic cache (default: true)"} + }; + nl2sql_params["required"] = json::array({"natural_language"}); + + tools.push_back({ + {"name", "ai_nl2sql_convert"}, + {"description", "Convert natural language query to SQL using LLM"}, + {"inputSchema", nl2sql_params} + }); + + json result; + result["tools"] = tools; + return result; +} + +/** + * @brief Get description of a specific tool + */ +json AI_Tool_Handler::get_tool_description(const std::string& tool_name) { + json tools_list = get_tool_list(); + for (const auto& tool : tools_list["tools"]) { + if (tool["name"] == tool_name) { + return tool; + } + } + return create_error_response("Tool not found: " + tool_name); +} + +// ============================================================================ +// Tool Execution +// ============================================================================ + +/** + * @brief Execute an AI tool + */ +json AI_Tool_Handler::execute_tool(const std::string& tool_name, const 
json& arguments) { + proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler: execute_tool(%s)\n", tool_name.c_str()); + + try { + // NL2SQL conversion tool + if (tool_name == "ai_nl2sql_convert") { + if (!nl2sql_converter) { + return create_error_response("NL2SQL converter not available"); + } + + // Extract parameters + std::string natural_language = get_json_string(arguments, "natural_language"); + if (natural_language.empty()) { + return create_error_response("Missing required parameter: natural_language"); + } + + std::string schema = get_json_string(arguments, "schema"); + int max_latency_ms = get_json_int(arguments, "max_latency_ms", 0); + bool allow_cache = true; + if (arguments.contains("allow_cache") && !arguments["allow_cache"].is_null()) { + if (arguments["allow_cache"].is_boolean()) { + allow_cache = arguments["allow_cache"].get<bool>(); + } else if (arguments["allow_cache"].is_string()) { + std::string val = arguments["allow_cache"].get<std::string>(); + allow_cache = (val == "true" || val == "1"); + } + } + + // Parse context_tables + std::vector<std::string> context_tables; + std::string tables_str = get_json_string(arguments, "context_tables"); + if (!tables_str.empty()) { + std::istringstream ts(tables_str); + std::string table; + while (std::getline(ts, table, ',')) { + table.erase(0, table.find_first_not_of(" \t")); + table.erase(table.find_last_not_of(" \t") + 1); + if (!table.empty()) { + context_tables.push_back(table); + } + } + } + + // Create NL2SQL request + NL2SQLRequest req; + req.natural_language = natural_language; + req.schema_name = schema; + req.max_latency_ms = max_latency_ms; + req.allow_cache = allow_cache; + req.context_tables = context_tables; + + // Call NL2SQL converter + NL2SQLResult result = nl2sql_converter->convert(req); + + // Build response + json response_data; + response_data["sql_query"] = result.sql_query; + response_data["confidence"] = result.confidence; + response_data["explanation"] = result.explanation; + response_data["cached"] = result.cached; 
+ response_data["cache_id"] = result.cache_id; + + // Add tables used if available + if (!result.tables_used.empty()) { + response_data["tables_used"] = result.tables_used; + } + + proxy_info("AI_Tool_Handler: NL2SQL conversion complete. SQL: %s, Confidence: %.2f\n", + result.sql_query.c_str(), result.confidence); + + return create_success_response(response_data); + } + + // Unknown tool + return create_error_response("Unknown tool: " + tool_name); + + } catch (const std::exception& e) { + proxy_error("AI_Tool_Handler: Exception in execute_tool: %s\n", e.what()); + return create_error_response(std::string("Exception: ") + e.what()); + } catch (...) { + proxy_error("AI_Tool_Handler: Unknown exception in execute_tool\n"); + return create_error_response("Unknown exception"); + } +} diff --git a/lib/Makefile b/lib/Makefile index 251b7c0a8..fc1e2960d 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -85,7 +85,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo MySQL_Catalog.oo MySQL_Tool_Handler.oo \ Config_Tool_Handler.oo Query_Tool_Handler.oo \ Admin_Tool_Handler.oo Cache_Tool_Handler.oo Observe_Tool_Handler.oo \ - AI_Features_Manager.oo NL2SQL_Converter.oo LLM_Clients.oo Anomaly_Detector.oo AI_Vector_Storage.oo + AI_Features_Manager.oo NL2SQL_Converter.oo LLM_Clients.oo Anomaly_Detector.oo AI_Vector_Storage.oo AI_Tool_Handler.oo OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX)) HEADERS := ../include/*.h ../include/*.hpp diff --git a/lib/ProxySQL_MCP_Server.cpp b/lib/ProxySQL_MCP_Server.cpp index fc58f6405..434627a34 100644 --- a/lib/ProxySQL_MCP_Server.cpp +++ b/lib/ProxySQL_MCP_Server.cpp @@ -12,6 +12,8 @@ using json = nlohmann::json; #include "Admin_Tool_Handler.h" #include "Cache_Tool_Handler.h" #include "Observe_Tool_Handler.h" +#include "AI_Tool_Handler.h" +#include "AI_Features_Manager.h" #include "proxysql_utils.h" using namespace httpserver; @@ -119,6 +121,22 @@ ProxySQL_MCP_Server::ProxySQL_MCP_Server(int p, MCP_Threads_Handler* 
h) proxy_info("Observe Tool Handler initialized\n"); } + // 6. AI Tool Handler (for NL2SQL and other AI features) + extern AI_Features_Manager *GloAI; + if (GloAI) { + handler->ai_tool_handler = new AI_Tool_Handler(GloAI->get_nl2sql(), GloAI->get_anomaly_detector()); + if (handler->ai_tool_handler->init() == 0) { + proxy_info("AI Tool Handler initialized\n"); + } else { + proxy_error("Failed to initialize AI Tool Handler\n"); + delete handler->ai_tool_handler; + handler->ai_tool_handler = NULL; + } + } else { + proxy_warning("AI_Features_Manager not available, AI Tool Handler not initialized\n"); + handler->ai_tool_handler = NULL; + } + // Register MCP endpoints // Each endpoint gets its own dedicated tool handler std::unique_ptr config_resource = @@ -146,17 +164,36 @@ ProxySQL_MCP_Server::ProxySQL_MCP_Server(int p, MCP_Threads_Handler* h) ws->register_resource("/mcp/cache", cache_resource.get(), true); _endpoints.push_back({"/mcp/cache", std::move(cache_resource)}); - proxy_info("Registered 5 MCP endpoints with dedicated tool handlers: /mcp/config, /mcp/observe, /mcp/query, /mcp/admin, /mcp/cache\n"); + // 6. AI endpoint (for NL2SQL and other AI features) + if (handler->ai_tool_handler) { + std::unique_ptr ai_resource = + std::unique_ptr(new MCP_JSONRPC_Resource(handler, handler->ai_tool_handler, "ai")); + ws->register_resource("/mcp/ai", ai_resource.get(), true); + _endpoints.push_back({"/mcp/ai", std::move(ai_resource)}); + } + + proxy_info("Registered %d MCP endpoints with dedicated tool handlers: /mcp/config, /mcp/observe, /mcp/query, /mcp/admin, /mcp/cache%s/mcp/ai\n", + handler->ai_tool_handler ? 6 : 5, handler->ai_tool_handler ? 
", " : ""); } ProxySQL_MCP_Server::~ProxySQL_MCP_Server() { stop(); - // Clean up MySQL Tool Handler - if (handler && handler->mysql_tool_handler) { - proxy_info("Cleaning up MySQL Tool Handler...\n"); - delete handler->mysql_tool_handler; - handler->mysql_tool_handler = NULL; + // Clean up tool handlers + if (handler) { + // Clean up AI Tool Handler (uses shared components, don't delete them) + if (handler->ai_tool_handler) { + proxy_info("Cleaning up AI Tool Handler...\n"); + delete handler->ai_tool_handler; + handler->ai_tool_handler = NULL; + } + + // Clean up MySQL Tool Handler + if (handler->mysql_tool_handler) { + proxy_info("Cleaning up MySQL Tool Handler...\n"); + delete handler->mysql_tool_handler; + handler->mysql_tool_handler = NULL; + } } } diff --git a/scripts/mcp/test_nl2sql_tools.sh b/scripts/mcp/test_nl2sql_tools.sh new file mode 100755 index 000000000..b8dfeec2c --- /dev/null +++ b/scripts/mcp/test_nl2sql_tools.sh @@ -0,0 +1,441 @@ +#!/bin/bash +# +# @file test_nl2sql_tools.sh +# @brief Test NL2SQL MCP tools via HTTPS/JSON-RPC +# +# Tests the ai_nl2sql_convert tool through the MCP protocol. 
+# +# Prerequisites: +# - ProxySQL with MCP server running on https://127.0.0.1:6071 +# - AI features enabled (GloAI initialized) +# - LLM configured (Ollama or cloud API with valid keys) +# +# Usage: +# ./test_nl2sql_tools.sh [options] +# +# Options: +# -v, --verbose Show verbose output including HTTP requests/responses +# -q, --quiet Suppress progress messages +# -h, --help Show this help message +# +# @date 2025-01-16 + +set -e + +# ============================================================================ +# Configuration +# ============================================================================ + +MCP_HOST="${MCP_HOST:-127.0.0.1}" +MCP_PORT="${MCP_PORT:-6071}" +MCP_ENDPOINT="${MCP_ENDPOINT:-ai}" + +# Test options +VERBOSE=false +QUIET=false + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +CYAN='\033[0;36m' +NC='\033[0m' + +# Statistics +TOTAL_TESTS=0 +PASSED_TESTS=0 +FAILED_TESTS=0 + +# ============================================================================ +# Helper Functions +# ============================================================================ + +log_info() { + if [ "${QUIET}" = "false" ]; then + echo -e "${GREEN}[INFO]${NC} $1" + fi +} + +log_warn() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +log_verbose() { + if [ "${VERBOSE}" = "true" ]; then + echo -e "${BLUE}[DEBUG]${NC} $1" + fi +} + +log_test() { + if [ "${QUIET}" = "false" ]; then + echo -e "${CYAN}[TEST]${NC} $1" + fi +} + +# Get endpoint URL +get_endpoint_url() { + echo "https://${MCP_HOST}:${MCP_PORT}/mcp/${MCP_ENDPOINT}" +} + +# Execute MCP request +mcp_request() { + local payload="$1" + + local response + response=$(curl -k -s -w "\n%{http_code}" -X POST "$(get_endpoint_url)" \ + -H "Content-Type: application/json" \ + -d "${payload}" 2>/dev/null) + + local body + body=$(echo "$response" | head -n -1) + local code + code=$(echo "$response" | tail -n 1) + + if [ "${VERBOSE}" = 
"true" ]; then + echo "Request: ${payload}" >&2 + echo "Response (${code}): ${body}" >&2 + fi + + echo "${body}" + return 0 +} + +# Check if MCP server is accessible +check_mcp_server() { + log_test "Checking MCP server accessibility at $(get_endpoint_url)..." + + local response + response=$(mcp_request '{"jsonrpc":"2.0","method":"tools/list","id":1}') + + if echo "${response}" | grep -q "result"; then + log_info "MCP server is accessible" + return 0 + else + log_error "MCP server is not accessible" + log_error "Response: ${response}" + return 1 + fi +} + +# List available tools +list_tools() { + log_test "Listing available AI tools..." + + local payload='{"jsonrpc":"2.0","method":"tools/list","id":1}' + local response + response=$(mcp_request "${payload}") + + echo "${response}" +} + +# Get tool description +describe_tool() { + local tool_name="$1" + + log_verbose "Getting description for tool: ${tool_name}" + + local payload + payload=$(cat </dev/null 2>&1; then + result_data=$(echo "${response}" | jq -r '.result.data' 2>/dev/null || echo "{}") + else + # Fallback: extract JSON between { and } + result_data=$(echo "${response}" | grep -o '"data":{[^}]*}' | sed 's/"data"://') + fi + + # Check for errors + if echo "${response}" | grep -q '"error"'; then + local error_msg + if command -v jq >/dev/null 2>&1; then + error_msg=$(echo "${response}" | jq -r '.error.message' 2>/dev/null || echo "Unknown error") + else + error_msg=$(echo "${response}" | grep -o '"message"[[:space:]]*:[[:space:]]*"[^"]*"' | sed 's/.*: "\(.*\)"/\1/') + fi + log_error " FAILED: ${error_msg}" + FAILED_TESTS=$((FAILED_TESTS + 1)) + return 1 + fi + + # Extract SQL query from result + local sql_query + if command -v jq >/dev/null 2>&1; then + sql_query=$(echo "${response}" | jq -r '.result.data.sql_query' 2>/dev/null || echo "") + else + sql_query=$(echo "${response}" | grep -o '"sql_query"[[:space:]]*:[[:space:]]*"[^"]*"' | sed 's/.*: "\(.*\)"/\1/') + fi + + log_verbose " Generated SQL: 
${sql_query}" + + # Check if expected pattern exists + if [ -n "${expected_pattern}" ] && [ -n "${sql_query}" ]; then + sql_upper=$(echo "${sql_query}" | tr '[:lower:]' '[:upper:]') + pattern_upper=$(echo "${expected_pattern}" | tr '[:lower:]' '[:upper:]') + + if echo "${sql_upper}" | grep -qE "${pattern_upper}"; then + log_info " PASSED: Pattern '${expected_pattern}' found in SQL" + PASSED_TESTS=$((PASSED_TESTS + 1)) + return 0 + else + log_error " FAILED: Pattern '${expected_pattern}' not found in SQL: ${sql_query}" + FAILED_TESTS=$((FAILED_TESTS + 1)) + return 1 + fi + elif [ -n "${sql_query}" ]; then + # No pattern check, just verify SQL was generated + log_info " PASSED: SQL generated successfully" + PASSED_TESTS=$((PASSED_TESTS + 1)) + return 0 + else + log_error " FAILED: No SQL query in response" + FAILED_TESTS=$((FAILED_TESTS + 1)) + return 1 + fi +} + +# ============================================================================ +# Test Cases +# ============================================================================ + +run_all_tests() { + log_info "Running NL2SQL MCP tool tests..." + + # Test 1: Simple SELECT + run_test \ + "Simple SELECT all customers" \ + "Show all customers" \ + "SELECT.*customers" + + # Test 2: SELECT with WHERE clause + run_test \ + "SELECT with WHERE clause" \ + "Find customers from USA" \ + "SELECT.*WHERE" + + # Test 3: JOIN query + run_test \ + "JOIN customers and orders" \ + "Show customer names with their order amounts" \ + "JOIN" + + # Test 4: Aggregation (COUNT) + run_test \ + "COUNT aggregation" \ + "Count customers by country" \ + "COUNT.*GROUP BY" + + # Test 5: Sorting + run_test \ + "ORDER BY clause" \ + "Show orders sorted by total amount" \ + "ORDER BY" + + # Test 6: Limit + run_test \ + "LIMIT clause" \ + "Show top 5 customers by revenue" \ + "SELECT.*customers" + + # Test 7: Complex aggregation + run_test \ + "AVG aggregation" \ + "What is the average order total?" 
\ + "SELECT" + + # Test 8: Schema-specified query + run_test \ + "Schema-specified query" \ + "List all users from the users table" \ + "SELECT.*users" + + # Test 9: Subquery hint + run_test \ + "Subquery pattern" \ + "Find customers with orders above average" \ + "SELECT" + + # Test 10: Empty query (error handling) + log_test "Test: Empty query (should handle gracefully)" + TOTAL_TESTS=$((TOTAL_TESTS + 1)) + local payload='{"jsonrpc":"2.0","method":"tools/call","params":{"name":"ai_nl2sql_convert","arguments":{"natural_language":""}},"id":11}' + local response + response=$(mcp_request "${payload}") + + if echo "${response}" | grep -q '"error"'; then + log_info " PASSED: Empty query handled with error" + PASSED_TESTS=$((PASSED_TESTS + 1)) + else + log_warn " SKIPPED: Error handling for empty query not as expected" + SKIPPED_TESTS=$((SKIPPED_TESTS + 1)) + fi +} + +# ============================================================================ +# Results Summary +# ============================================================================ + +print_summary() { + echo "" + echo "========================================" + echo " Test Summary" + echo "========================================" + echo "Total tests: ${TOTAL_TESTS}" + echo -e "Passed: ${GREEN}${PASSED_TESTS}${NC}" + echo -e "Failed: ${RED}${FAILED_TESTS}${NC}" + echo -e "Skipped: ${YELLOW}${SKIPPED_TESTS:-0}${NC}" + echo "========================================" + + if [ ${FAILED_TESTS} -eq 0 ]; then + echo -e "\n${GREEN}All tests passed!${NC}\n" + return 0 + else + echo -e "\n${RED}Some tests failed${NC}\n" + return 1 + fi +} + +# ============================================================================ +# Parse Arguments +# ============================================================================ + +parse_args() { + while [ $# -gt 0 ]; do + case "$1" in + -v|--verbose) + VERBOSE=true + shift + ;; + -q|--quiet) + QUIET=true + shift + ;; + -h|--help) + cat </dev/null 2>&1; then + echo "${tools}" | 
jq -r '.result.tools[] | " - \(.name): \(.description)"' 2>/dev/null || echo "${tools}" + else + echo "${tools}" + fi + echo "" + + # Run tests + run_all_tests + + # Print summary + print_summary +} + +main "$@"