Phase 5: MCP Tool Implementation for NL2SQL

This commit implements the AI Tool Handler for the MCP (Model Context
Protocol) server, exposing NL2SQL functionality as an MCP tool.

**New Files:**

- include/AI_Tool_Handler.h: Header for the AI_Tool_Handler class
  - Provides the ai_nl2sql_convert tool via the MCP protocol
  - Wraps NL2SQL_Converter and Anomaly_Detector
  - Inherits from the MCP_Tool_Handler base class
- lib/AI_Tool_Handler.cpp: Implementation
  - Implements ai_nl2sql_convert tool execution
  - Accepts parameters: natural_language (required), schema,
    context_tables, max_latency_ms, allow_cache
  - Returns a JSON response with sql_query, confidence, explanation,
    cached, cache_id
- scripts/mcp/test_nl2sql_tools.sh: Test script for the NL2SQL MCP tool
  - Tests ai_nl2sql_convert via JSON-RPC over HTTPS
  - 10 test cases covering SELECT, WHERE, JOIN, aggregation, etc.
  - Includes an error-handling test for empty queries
  - Supports the --verbose and --quiet options

**Modified Files:**

- include/MCP_Thread.h: Add AI_Tool_Handler forward declaration and pointer
- lib/Makefile: Add AI_Tool_Handler.oo to the _OBJ_CXX list
- lib/ProxySQL_MCP_Server.cpp: Initialize and register the AI tool handler
  - Creates the AI_Tool_Handler with the GloAI components
  - Registers the /mcp/ai endpoint (see the discovery example below)
  - Adds cleanup in the destructor
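
To verify that the endpoint is registered, the tool can be discovered with a
JSON-RPC tools/list call. This is a minimal sketch using curl; the host, port,
and the -k flag (skip TLS certificate verification) mirror the defaults used
by the test script added in this commit:

```bash
# Discover the AI tools exposed at /mcp/ai (defaults from test_nl2sql_tools.sh)
curl -k -s -X POST "https://127.0.0.1:6071/mcp/ai" \
  -H "Content-Type: application/json" \
  -d '{"jsonrpc":"2.0","method":"tools/list","id":1}'
# Expected: a "result" object whose "tools" array lists "ai_nl2sql_convert"
```
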
**MCP Tool Details:**

- Endpoint: /mcp/ai
- Tool: ai_nl2sql_convert
- Parameters (see the example call below):
  - natural_language (string, required): Natural language query
  - schema (string, optional): Database schema name
  - context_tables (string, optional): Comma-separated table list
  - max_latency_ms (integer, optional): Maximum acceptable latency in milliseconds
  - allow_cache (boolean, optional): Check the semantic cache (default: true)
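
For illustration, a complete tools/call request against the defaults above
could look as follows. The payload shape mirrors the one built by the test
script; the question text and the "sales" schema are placeholders, and the
returned SQL and confidence values depend on the configured LLM:

```bash
# Convert a natural-language question to SQL via the ai_nl2sql_convert tool.
# Only natural_language is required; the other arguments are optional.
curl -k -s -X POST "https://127.0.0.1:6071/mcp/ai" \
  -H "Content-Type: application/json" \
  -d '{
        "jsonrpc": "2.0",
        "method": "tools/call",
        "params": {
          "name": "ai_nl2sql_convert",
          "arguments": {
            "natural_language": "Show all customers from USA",
            "schema": "sales",
            "allow_cache": true
          }
        },
        "id": 1
      }'
# The result data is expected to carry sql_query, confidence, explanation,
# cached, and cache_id (plus tables_used when available).
```
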
**Testing:**
Run the test script with:
./scripts/mcp/test_nl2sql_tools.sh [--verbose] [--quiet]
See scripts/mcp/test_nl2sql_tools.sh --help for usage.
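
The script's defaults can be overridden through the MCP_HOST, MCP_PORT, and
MCP_ENDPOINT environment variables, as documented in its --help output. For
example, to run the suite against a remote instance:

```bash
# Point the tests at a remote ProxySQL MCP server (address is an example)
MCP_HOST=192.168.1.100 MCP_PORT=6071 ./scripts/mcp/test_nl2sql_tools.sh --verbose
```
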
Related: Phase 1-4 (Documentation, Unit Tests, Integration Tests, E2E Tests)
Related: Phase 6-8 (User Docs, Developer Docs, Test Docs)

commit 3f44229e28 (parent 83c3983070), pull/5310/head

include/AI_Tool_Handler.h
@@ -0,0 +1,96 @@
/**
 * @file AI_Tool_Handler.h
 * @brief AI Tool Handler for the MCP protocol
 *
 * Provides AI-related tools via the MCP protocol, including:
 * - NL2SQL (Natural Language to SQL) conversion
 * - Anomaly detection queries
 * - Vector storage operations
 *
 * @date 2025-01-16
 */

#ifndef CLASS_AI_TOOL_HANDLER_H
#define CLASS_AI_TOOL_HANDLER_H

#include "MCP_Tool_Handler.h"
#include <string>
#include <vector>
#include <map>

// Forward declarations
class NL2SQL_Converter;
class Anomaly_Detector;

/**
 * @brief AI Tool Handler for MCP
 *
 * Provides AI-powered tools through the MCP protocol:
 * - ai_nl2sql_convert: Convert natural language to SQL
 * - Future: anomaly detection, vector operations
 */
class AI_Tool_Handler : public MCP_Tool_Handler {
private:
    NL2SQL_Converter* nl2sql_converter;
    Anomaly_Detector* anomaly_detector;
    bool owns_components;

    /**
     * @brief Helper to extract a string parameter from JSON
     */
    static std::string get_json_string(const json& j, const std::string& key,
                                       const std::string& default_val = "");

    /**
     * @brief Helper to extract an int parameter from JSON
     */
    static int get_json_int(const json& j, const std::string& key, int default_val = 0);

public:
    /**
     * @brief Constructor - uses existing AI components
     */
    AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly);

    /**
     * @brief Default constructor
     *
     * Note: does not create its own components; it wires up the global
     * instances managed by AI_Features_Manager (GloAI).
     */
    AI_Tool_Handler();

    /**
     * @brief Destructor
     */
    ~AI_Tool_Handler();

    /**
     * @brief Initialize the tool handler
     */
    int init() override;

    /**
     * @brief Close and cleanup
     */
    void close() override;

    /**
     * @brief Get handler name
     */
    std::string get_handler_name() const override { return "ai"; }

    /**
     * @brief Get list of available tools
     */
    json get_tool_list() override;

    /**
     * @brief Get description of a specific tool
     */
    json get_tool_description(const std::string& tool_name) override;

    /**
     * @brief Execute a tool with arguments
     */
    json execute_tool(const std::string& tool_name, const json& arguments) override;
};

#endif /* CLASS_AI_TOOL_HANDLER_H */

lib/AI_Tool_Handler.cpp
@@ -0,0 +1,275 @@
/**
 * @file AI_Tool_Handler.cpp
 * @brief Implementation of the AI Tool Handler for the MCP protocol
 *
 * Implements AI-powered tools through the MCP protocol, primarily
 * the ai_nl2sql_convert tool for natural language to SQL conversion.
 *
 * @see AI_Tool_Handler.h
 */

#include "AI_Tool_Handler.h"
#include "NL2SQL_Converter.h"
#include "Anomaly_Detector.h"
#include "AI_Features_Manager.h"
#include "proxysql_debug.h"
#include "cpp.h"
#include <sstream>
#include <algorithm>

// JSON library
#include "../deps/json/json.hpp"
using json = nlohmann::json;
#define PROXYJSON

// ============================================================================
// Constructor/Destructor
// ============================================================================

/**
 * @brief Constructor using existing AI components
 */
AI_Tool_Handler::AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly)
    : nl2sql_converter(nl2sql),
      anomaly_detector(anomaly),
      owns_components(false)
{
    proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler created (wrapping existing components)\n");
}

/**
 * @brief Default constructor
 *
 * Note: does not create its own components; it uses the global instances
 * managed by AI_Features_Manager (GloAI).
 */
AI_Tool_Handler::AI_Tool_Handler()
    : nl2sql_converter(NULL),
      anomaly_detector(NULL),
      owns_components(false)
{
    // Use global instances from AI_Features_Manager
    if (GloAI) {
        nl2sql_converter = GloAI->get_nl2sql();
        anomaly_detector = GloAI->get_anomaly_detector();
    }
    proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler created (using global instances)\n");
}

/**
 * @brief Destructor
 */
AI_Tool_Handler::~AI_Tool_Handler() {
    close();
    proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler destroyed\n");
}

// ============================================================================
// Lifecycle
// ============================================================================

/**
 * @brief Initialize the tool handler
 */
int AI_Tool_Handler::init() {
    if (!nl2sql_converter) {
        proxy_error("AI_Tool_Handler: NL2SQL converter not available\n");
        return -1;
    }
    proxy_info("AI_Tool_Handler initialized\n");
    return 0;
}

/**
 * @brief Close and cleanup
 */
void AI_Tool_Handler::close() {
    if (owns_components) {
        // Components would be cleaned up here
        // For now, we use global instances managed by AI_Features_Manager
    }
}

// ============================================================================
// Helper Functions
// ============================================================================

/**
 * @brief Extract string parameter from JSON
 */
std::string AI_Tool_Handler::get_json_string(const json& j, const std::string& key,
                                             const std::string& default_val) {
    if (j.contains(key) && !j[key].is_null()) {
        if (j[key].is_string()) {
            return j[key].get<std::string>();
        } else {
            // Convert to string if not already
            return j[key].dump();
        }
    }
    return default_val;
}

/**
 * @brief Extract int parameter from JSON
 */
int AI_Tool_Handler::get_json_int(const json& j, const std::string& key, int default_val) {
    if (j.contains(key) && !j[key].is_null()) {
        if (j[key].is_number()) {
            return j[key].get<int>();
        } else if (j[key].is_string()) {
            return std::stoi(j[key].get<std::string>());
        }
    }
    return default_val;
}

// ============================================================================
// Tool List
// ============================================================================

/**
 * @brief Get list of available AI tools
 */
json AI_Tool_Handler::get_tool_list() {
    json tools = json::array();

    // NL2SQL tool
    json nl2sql_params = json::object();
    nl2sql_params["type"] = "object";
    nl2sql_params["properties"] = json::object();
    nl2sql_params["properties"]["natural_language"] = {
        {"type", "string"},
        {"description", "Natural language query to convert to SQL"}
    };
    nl2sql_params["properties"]["schema"] = {
        {"type", "string"},
        {"description", "Database/schema name for context"}
    };
    nl2sql_params["properties"]["context_tables"] = {
        {"type", "string"},
        {"description", "Comma-separated list of relevant tables (optional)"}
    };
    nl2sql_params["properties"]["max_latency_ms"] = {
        {"type", "integer"},
        {"description", "Maximum acceptable latency in milliseconds (optional)"}
    };
    nl2sql_params["properties"]["allow_cache"] = {
        {"type", "boolean"},
        {"description", "Whether to check semantic cache (default: true)"}
    };
    nl2sql_params["required"] = json::array({"natural_language"});

    tools.push_back({
        {"name", "ai_nl2sql_convert"},
        {"description", "Convert natural language query to SQL using LLM"},
        {"inputSchema", nl2sql_params}
    });

    json result;
    result["tools"] = tools;
    return result;
}

/**
 * @brief Get description of a specific tool
 */
json AI_Tool_Handler::get_tool_description(const std::string& tool_name) {
    json tools_list = get_tool_list();
    for (const auto& tool : tools_list["tools"]) {
        if (tool["name"] == tool_name) {
            return tool;
        }
    }
    return create_error_response("Tool not found: " + tool_name);
}

// ============================================================================
// Tool Execution
// ============================================================================

/**
 * @brief Execute an AI tool
 */
json AI_Tool_Handler::execute_tool(const std::string& tool_name, const json& arguments) {
    proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler: execute_tool(%s)\n", tool_name.c_str());

    try {
        // NL2SQL conversion tool
        if (tool_name == "ai_nl2sql_convert") {
            if (!nl2sql_converter) {
                return create_error_response("NL2SQL converter not available");
            }

            // Extract parameters
            std::string natural_language = get_json_string(arguments, "natural_language");
            if (natural_language.empty()) {
                return create_error_response("Missing required parameter: natural_language");
            }

            std::string schema = get_json_string(arguments, "schema");
            int max_latency_ms = get_json_int(arguments, "max_latency_ms", 0);
            bool allow_cache = true;
            if (arguments.contains("allow_cache") && !arguments["allow_cache"].is_null()) {
                if (arguments["allow_cache"].is_boolean()) {
                    allow_cache = arguments["allow_cache"].get<bool>();
                } else if (arguments["allow_cache"].is_string()) {
                    std::string val = arguments["allow_cache"].get<std::string>();
                    allow_cache = (val == "true" || val == "1");
                }
            }

            // Parse context_tables
            std::vector<std::string> context_tables;
            std::string tables_str = get_json_string(arguments, "context_tables");
            if (!tables_str.empty()) {
                std::istringstream ts(tables_str);
                std::string table;
                while (std::getline(ts, table, ',')) {
                    table.erase(0, table.find_first_not_of(" \t"));
                    table.erase(table.find_last_not_of(" \t") + 1);
                    if (!table.empty()) {
                        context_tables.push_back(table);
                    }
                }
            }

            // Create NL2SQL request
            NL2SQLRequest req;
            req.natural_language = natural_language;
            req.schema_name = schema;
            req.max_latency_ms = max_latency_ms;
            req.allow_cache = allow_cache;
            req.context_tables = context_tables;

            // Call NL2SQL converter
            NL2SQLResult result = nl2sql_converter->convert(req);

            // Build response
            json response_data;
            response_data["sql_query"] = result.sql_query;
            response_data["confidence"] = result.confidence;
            response_data["explanation"] = result.explanation;
            response_data["cached"] = result.cached;
            response_data["cache_id"] = result.cache_id;

            // Add tables used if available
            if (!result.tables_used.empty()) {
                response_data["tables_used"] = result.tables_used;
            }

            proxy_info("AI_Tool_Handler: NL2SQL conversion complete. SQL: %s, Confidence: %.2f\n",
                       result.sql_query.c_str(), result.confidence);

            return create_success_response(response_data);
        }

        // Unknown tool
        return create_error_response("Unknown tool: " + tool_name);

    } catch (const std::exception& e) {
        proxy_error("AI_Tool_Handler: Exception in execute_tool: %s\n", e.what());
        return create_error_response(std::string("Exception: ") + e.what());
    } catch (...) {
        proxy_error("AI_Tool_Handler: Unknown exception in execute_tool\n");
        return create_error_response("Unknown exception");
    }
}

scripts/mcp/test_nl2sql_tools.sh
@@ -0,0 +1,441 @@
#!/bin/bash
#
# @file test_nl2sql_tools.sh
# @brief Test NL2SQL MCP tools via HTTPS/JSON-RPC
#
# Tests the ai_nl2sql_convert tool through the MCP protocol.
#
# Prerequisites:
#   - ProxySQL with MCP server running on https://127.0.0.1:6071
#   - AI features enabled (GloAI initialized)
#   - LLM configured (Ollama or cloud API with valid keys)
#
# Usage:
#   ./test_nl2sql_tools.sh [options]
#
# Options:
#   -v, --verbose   Show verbose output including HTTP requests/responses
#   -q, --quiet     Suppress progress messages
#   -h, --help      Show this help message
#
# @date 2025-01-16

# Note: 'set -e' is intentionally not used here; individual test failures are
# counted and reported in the summary instead of aborting the whole run.

# ============================================================================
# Configuration
# ============================================================================

MCP_HOST="${MCP_HOST:-127.0.0.1}"
MCP_PORT="${MCP_PORT:-6071}"
MCP_ENDPOINT="${MCP_ENDPOINT:-ai}"

# Test options
VERBOSE=false
QUIET=false

# Colors
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
NC='\033[0m'

# Statistics
TOTAL_TESTS=0
PASSED_TESTS=0
FAILED_TESTS=0
SKIPPED_TESTS=0

# ============================================================================
# Helper Functions
# ============================================================================

log_info() {
    if [ "${QUIET}" = "false" ]; then
        echo -e "${GREEN}[INFO]${NC} $1"
    fi
}

log_warn() {
    echo -e "${YELLOW}[WARN]${NC} $1"
}

log_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

log_verbose() {
    if [ "${VERBOSE}" = "true" ]; then
        echo -e "${BLUE}[DEBUG]${NC} $1"
    fi
}

log_test() {
    if [ "${QUIET}" = "false" ]; then
        echo -e "${CYAN}[TEST]${NC} $1"
    fi
}

# Get endpoint URL
get_endpoint_url() {
    echo "https://${MCP_HOST}:${MCP_PORT}/mcp/${MCP_ENDPOINT}"
}

# Execute MCP request
mcp_request() {
    local payload="$1"

    local response
    response=$(curl -k -s -w "\n%{http_code}" -X POST "$(get_endpoint_url)" \
        -H "Content-Type: application/json" \
        -d "${payload}" 2>/dev/null)

    local body
    body=$(echo "$response" | head -n -1)
    local code
    code=$(echo "$response" | tail -n 1)

    if [ "${VERBOSE}" = "true" ]; then
        echo "Request: ${payload}" >&2
        echo "Response (${code}): ${body}" >&2
    fi

    echo "${body}"
    return 0
}

# Check if MCP server is accessible
check_mcp_server() {
    log_test "Checking MCP server accessibility at $(get_endpoint_url)..."

    local response
    response=$(mcp_request '{"jsonrpc":"2.0","method":"tools/list","id":1}')

    if echo "${response}" | grep -q "result"; then
        log_info "MCP server is accessible"
        return 0
    else
        log_error "MCP server is not accessible"
        log_error "Response: ${response}"
        return 1
    fi
}

# List available tools
list_tools() {
    log_test "Listing available AI tools..."

    local payload='{"jsonrpc":"2.0","method":"tools/list","id":1}'
    local response
    response=$(mcp_request "${payload}")

    echo "${response}"
}

# Get tool description
describe_tool() {
    local tool_name="$1"

    log_verbose "Getting description for tool: ${tool_name}"

    local payload
    payload=$(cat <<EOF
{
  "jsonrpc": "2.0",
  "method": "tools/describe",
  "params": {
    "name": "${tool_name}"
  },
  "id": 1
}
EOF
)

    mcp_request "${payload}"
}

# Run a single test
run_test() {
    local test_name="$1"
    local nl_query="$2"
    local expected_pattern="$3"

    TOTAL_TESTS=$((TOTAL_TESTS + 1))

    log_test "Test ${TOTAL_TESTS}: ${test_name}"
    log_verbose "  Query: ${nl_query}"

    # Build the MCP request payload
    local payload
    payload=$(cat <<EOF
{
  "jsonrpc": "2.0",
  "method": "tools/call",
  "params": {
    "name": "ai_nl2sql_convert",
    "arguments": {
      "natural_language": "${nl_query}"
    }
  },
  "id": ${TOTAL_TESTS}
}
EOF
)

    # Execute the request
    local response
    response=$(mcp_request "${payload}")

    # Extract the result data
    local result_data
    if command -v jq >/dev/null 2>&1; then
        result_data=$(echo "${response}" | jq -r '.result.data' 2>/dev/null || echo "{}")
    else
        # Fallback: extract JSON between { and }
        result_data=$(echo "${response}" | grep -o '"data":{[^}]*}' | sed 's/"data"://')
    fi

    # Check for errors
    if echo "${response}" | grep -q '"error"'; then
        local error_msg
        if command -v jq >/dev/null 2>&1; then
            error_msg=$(echo "${response}" | jq -r '.error.message' 2>/dev/null || echo "Unknown error")
        else
            error_msg=$(echo "${response}" | grep -o '"message"[[:space:]]*:[[:space:]]*"[^"]*"' | sed 's/.*: "\(.*\)"/\1/')
        fi
        log_error "  FAILED: ${error_msg}"
        FAILED_TESTS=$((FAILED_TESTS + 1))
        return 1
    fi

    # Extract SQL query from result
    local sql_query
    if command -v jq >/dev/null 2>&1; then
        sql_query=$(echo "${response}" | jq -r '.result.data.sql_query' 2>/dev/null || echo "")
    else
        sql_query=$(echo "${response}" | grep -o '"sql_query"[[:space:]]*:[[:space:]]*"[^"]*"' | sed 's/.*: "\(.*\)"/\1/')
    fi

    log_verbose "  Generated SQL: ${sql_query}"

    # Check if expected pattern exists
    if [ -n "${expected_pattern}" ] && [ -n "${sql_query}" ]; then
        sql_upper=$(echo "${sql_query}" | tr '[:lower:]' '[:upper:]')
        pattern_upper=$(echo "${expected_pattern}" | tr '[:lower:]' '[:upper:]')

        if echo "${sql_upper}" | grep -qE "${pattern_upper}"; then
            log_info "  PASSED: Pattern '${expected_pattern}' found in SQL"
            PASSED_TESTS=$((PASSED_TESTS + 1))
            return 0
        else
            log_error "  FAILED: Pattern '${expected_pattern}' not found in SQL: ${sql_query}"
            FAILED_TESTS=$((FAILED_TESTS + 1))
            return 1
        fi
    elif [ -n "${sql_query}" ]; then
        # No pattern check, just verify SQL was generated
        log_info "  PASSED: SQL generated successfully"
        PASSED_TESTS=$((PASSED_TESTS + 1))
        return 0
    else
        log_error "  FAILED: No SQL query in response"
        FAILED_TESTS=$((FAILED_TESTS + 1))
        return 1
    fi
}

# ============================================================================
# Test Cases
# ============================================================================

run_all_tests() {
    log_info "Running NL2SQL MCP tool tests..."

    # Test 1: Simple SELECT
    run_test \
        "Simple SELECT all customers" \
        "Show all customers" \
        "SELECT.*customers"

    # Test 2: SELECT with WHERE clause
    run_test \
        "SELECT with WHERE clause" \
        "Find customers from USA" \
        "SELECT.*WHERE"

    # Test 3: JOIN query
    run_test \
        "JOIN customers and orders" \
        "Show customer names with their order amounts" \
        "JOIN"

    # Test 4: Aggregation (COUNT)
    run_test \
        "COUNT aggregation" \
        "Count customers by country" \
        "COUNT.*GROUP BY"

    # Test 5: Sorting
    run_test \
        "ORDER BY clause" \
        "Show orders sorted by total amount" \
        "ORDER BY"

    # Test 6: Limit
    run_test \
        "LIMIT clause" \
        "Show top 5 customers by revenue" \
        "SELECT.*customers"

    # Test 7: Complex aggregation
    run_test \
        "AVG aggregation" \
        "What is the average order total?" \
        "SELECT"

    # Test 8: Schema-specified query
    run_test \
        "Schema-specified query" \
        "List all users from the users table" \
        "SELECT.*users"

    # Test 9: Subquery hint
    run_test \
        "Subquery pattern" \
        "Find customers with orders above average" \
        "SELECT"

    # Test 10: Empty query (error handling)
    log_test "Test: Empty query (should handle gracefully)"
    TOTAL_TESTS=$((TOTAL_TESTS + 1))
    local payload='{"jsonrpc":"2.0","method":"tools/call","params":{"name":"ai_nl2sql_convert","arguments":{"natural_language":""}},"id":11}'
    local response
    response=$(mcp_request "${payload}")

    if echo "${response}" | grep -q '"error"'; then
        log_info "  PASSED: Empty query handled with error"
        PASSED_TESTS=$((PASSED_TESTS + 1))
    else
        log_warn "  SKIPPED: Error handling for empty query not as expected"
        SKIPPED_TESTS=$((SKIPPED_TESTS + 1))
    fi
}

# ============================================================================
# Results Summary
# ============================================================================

print_summary() {
    echo ""
    echo "========================================"
    echo "  Test Summary"
    echo "========================================"
    echo "Total tests: ${TOTAL_TESTS}"
    echo -e "Passed:      ${GREEN}${PASSED_TESTS}${NC}"
    echo -e "Failed:      ${RED}${FAILED_TESTS}${NC}"
    echo -e "Skipped:     ${YELLOW}${SKIPPED_TESTS:-0}${NC}"
    echo "========================================"

    if [ ${FAILED_TESTS} -eq 0 ]; then
        echo -e "\n${GREEN}All tests passed!${NC}\n"
        return 0
    else
        echo -e "\n${RED}Some tests failed${NC}\n"
        return 1
    fi
}

# ============================================================================
# Parse Arguments
# ============================================================================

parse_args() {
    while [ $# -gt 0 ]; do
        case "$1" in
            -v|--verbose)
                VERBOSE=true
                shift
                ;;
            -q|--quiet)
                QUIET=true
                shift
                ;;
            -h|--help)
                cat <<EOF
Usage: $0 [OPTIONS]

Test NL2SQL MCP tools via HTTPS/JSON-RPC.

Options:
  -v, --verbose   Show verbose output including HTTP requests/responses
  -q, --quiet     Suppress progress messages
  -h, --help      Show this help message

Environment Variables:
  MCP_HOST        MCP server host (default: 127.0.0.1)
  MCP_PORT        MCP server port (default: 6071)
  MCP_ENDPOINT    MCP endpoint name (default: ai)

Examples:
  # Run tests with verbose output
  $0 --verbose

  # Run tests against remote server
  MCP_HOST=192.168.1.100 MCP_PORT=6071 $0

EOF
                exit 0
                ;;
            *)
                log_error "Unknown option: $1"
                echo "Use -h or --help for usage information"
                exit 1
                ;;
        esac
    done
}

# ============================================================================
# Main
# ============================================================================

main() {
    # Parse arguments first so --verbose/--quiet/--help affect all output below
    parse_args "$@"

    echo "========================================"
    echo "  NL2SQL MCP Tool Testing"
    echo "========================================"
    echo ""
    echo "Configuration:"
    echo "  MCP Endpoint: $(get_endpoint_url)"
    echo "  Verbose:      ${VERBOSE}"
    echo ""

    # Check server accessibility
    if ! check_mcp_server; then
        log_error "Cannot connect to MCP server. Please ensure ProxySQL is running with MCP enabled."
        exit 1
    fi

    # List available tools
    echo ""
    log_info "Discovering available AI tools..."
    local tools
    tools=$(list_tools)
    if command -v jq >/dev/null 2>&1; then
        echo "${tools}" | jq -r '.result.tools[] | "  - \(.name): \(.description)"' 2>/dev/null || echo "${tools}"
    else
        echo "${tools}"
    fi
    echo ""

    # Run tests
    run_all_tests

    # Print summary
    print_summary
}

main "$@"