From a3f0bade4ea8b565b56eef2f29419aff498f152e Mon Sep 17 00:00:00 2001 From: Rene Cannao Date: Sat, 17 Jan 2026 11:42:30 +0000 Subject: [PATCH] feat: Convert NL2SQL to generic LLM bridge - Rename NL2SQL_Converter to LLM_Bridge for generic prompt processing - Update MySQL protocol handler from /* NL2SQL: */ to /* LLM: */ - Remove SQL-specific fields (sql_query, confidence, tables_used) - Add GENAI_OP_LLM operation type to GenAI module - Rename all genai_nl2sql_* variables to genai_llm_* - Update AI_Features_Manager to use LLM_Bridge - Deprecate ai_nl2sql_convert MCP tool with error message - LLM bridge now handles any prompt type via MySQL protocol This enables generic LLM access (summarization, code generation, translation, analysis) while preserving infrastructure for future NL2SQL implementation via Web UI + external agents. --- include/AI_Features_Manager.h | 86 +- include/AI_Tool_Handler.h | 6 +- include/GenAI_Thread.h | 19 +- include/{NL2SQL_Converter.h => LLM_Bridge.h} | 204 +++-- include/MySQL_Session.h | 2 +- include/proxysql.h | 2 +- lib/AI_Features_Manager.cpp | 114 +-- lib/AI_Tool_Handler.cpp | 86 +- lib/Admin_FlushVariables.cpp | 10 +- lib/GenAI_Thread.cpp | 211 +++-- lib/LLM_Bridge.cpp | 375 +++++++++ lib/LLM_Clients.cpp | 42 +- lib/Makefile | 2 +- lib/MySQL_Session.cpp | 112 +-- lib/NL2SQL_Converter.cpp | 790 ------------------- lib/ProxySQL_MCP_Server.cpp | 6 +- 16 files changed, 779 insertions(+), 1288 deletions(-) rename include/{NL2SQL_Converter.h => LLM_Bridge.h} (54%) create mode 100644 lib/LLM_Bridge.cpp delete mode 100644 lib/NL2SQL_Converter.cpp diff --git a/include/AI_Features_Manager.h b/include/AI_Features_Manager.h index 01dc1c82e..1c90a6aa8 100644 --- a/include/AI_Features_Manager.h +++ b/include/AI_Features_Manager.h @@ -3,41 +3,41 @@ * @brief AI Features Manager for ProxySQL * * The AI_Features_Manager class coordinates all AI-related features in ProxySQL: - * - NL2SQL (Natural Language to SQL) conversion + * - LLM Bridge (generic LLM access via MySQL protocol) * - Anomaly detection for security monitoring * - Vector storage for semantic caching * - Hybrid model routing (local Ollama + cloud APIs) * * Architecture: - * - Central configuration management with 'ai-' variable prefix + * - Central configuration management with 'genai-' variable prefix * - Thread-safe operations using pthread rwlock * - Follows same pattern as MCP_Threads_Handler and GenAI_Threads_Handler * - Coordinates with MySQL_Session for query interception * - * @date 2025-01-16 - * @version 0.1.0 + * @date 2025-01-17 + * @version 1.0.0 * * Example Usage: * @code - * // Access NL2SQL converter - * NL2SQL_Converter* nl2sql = GloAI->get_nl2sql(); - * NL2SQLRequest req; - * req.natural_language = "Show top customers"; - * NL2SQLResult result = nl2sql->convert(req); + * // Access LLM bridge + * LLM_Bridge* llm = GloAI->get_llm_bridge(); + * LLMRequest req; + * req.prompt = "Summarize this data"; + * LLMResult result = llm->process(req); * @endcode */ #ifndef __CLASS_AI_FEATURES_MANAGER_H #define __CLASS_AI_FEATURES_MANAGER_H -#define AI_FEATURES_MANAGER_VERSION "0.1.0" +#define AI_FEATURES_MANAGER_VERSION "1.0.0" #include "proxysql.h" #include #include // Forward declarations -class NL2SQL_Converter; +class LLM_Bridge; class Anomaly_Detector; class SQLite3DB; @@ -45,7 +45,7 @@ class SQLite3DB; * @brief AI Features Manager * * Coordinates all AI features in ProxySQL: - * - NL2SQL (Natural Language to SQL) conversion + * - LLM Bridge (generic LLM access) * - Anomaly detection for security * - Vector storage for semantic caching * - Hybrid model routing (local Ollama + cloud APIs) @@ -57,7 +57,7 @@ class SQLite3DB; * - All public methods are thread-safe using pthread rwlock * - Use wrlock()/wrunlock() for manual locking if needed * - * @see NL2SQL_Converter, Anomaly_Detector + * @see LLM_Bridge, Anomaly_Detector */ class AI_Features_Manager { private: @@ -65,7 +65,7 @@ private: pthread_rwlock_t rwlock; // Sub-components - NL2SQL_Converter* nl2sql_converter; + LLM_Bridge* llm_bridge; Anomaly_Detector* anomaly_detector; SQLite3DB* vector_db; @@ -73,7 +73,7 @@ private: int init_vector_db(); int init_anomaly_detector(); void close_vector_db(); - void close_nl2sql(); + void close_llm_bridge(); void close_anomaly_detector(); public: @@ -84,16 +84,16 @@ public: * Configuration is managed by the GenAI module (GloGATH). */ struct { - unsigned long long nl2sql_total_requests; - unsigned long long nl2sql_cache_hits; - unsigned long long nl2sql_local_model_calls; - unsigned long long nl2sql_cloud_model_calls; - unsigned long long nl2sql_total_response_time_ms; // Total response time for all LLM calls - unsigned long long nl2sql_cache_total_lookup_time_ms; // Total time spent in cache lookups - unsigned long long nl2sql_cache_total_store_time_ms; // Total time spent in cache storage - unsigned long long nl2sql_cache_lookups; - unsigned long long nl2sql_cache_stores; - unsigned long long nl2sql_cache_misses; + unsigned long long llm_total_requests; + unsigned long long llm_cache_hits; + unsigned long long llm_local_model_calls; + unsigned long long llm_cloud_model_calls; + unsigned long long llm_total_response_time_ms; // Total response time for all LLM calls + unsigned long long llm_cache_total_lookup_time_ms; // Total time spent in cache lookups + unsigned long long llm_cache_total_store_time_ms; // Total time spent in cache storage + unsigned long long llm_cache_lookups; + unsigned long long llm_cache_stores; + unsigned long long llm_cache_misses; unsigned long long anomaly_total_checks; unsigned long long anomaly_blocked_queries; unsigned long long anomaly_flagged_queries; @@ -113,7 +113,7 @@ public: /** * @brief Initialize all AI features * - * Initializes vector database, NL2SQL converter, and anomaly detector. + * Initializes vector database, LLM bridge, and anomaly detector. * This must be called after ProxySQL configuration is loaded. * * @return 0 on success, non-zero on failure @@ -129,14 +129,14 @@ public: void shutdown(); /** - * @brief Initialize NL2SQL converter + * @brief Initialize LLM bridge * - * Initializes the NL2SQL converter if not already initialized. - * This can be called at runtime after enabling nl2sql. + * Initializes the LLM bridge if not already initialized. + * This can be called at runtime after enabling llm. * * @return 0 on success, non-zero on failure */ - int init_nl2sql(); + int init_llm_bridge(); /** * @brief Acquire write lock for thread-safe operations @@ -156,25 +156,25 @@ public: void wrunlock(); /** - * @brief Get NL2SQL converter instance + * @brief Get LLM bridge instance * - * @return Pointer to NL2SQL_Converter or NULL if not initialized + * @return Pointer to LLM_Bridge or NULL if not initialized * * @note Thread-safe when called within wrlock()/wrunlock() pair */ - NL2SQL_Converter* get_nl2sql() { return nl2sql_converter; } + LLM_Bridge* get_llm_bridge() { return llm_bridge; } // Status variable update methods - void increment_nl2sql_total_requests() { __sync_fetch_and_add(&status_variables.nl2sql_total_requests, 1); } - void increment_nl2sql_cache_hits() { __sync_fetch_and_add(&status_variables.nl2sql_cache_hits, 1); } - void increment_nl2sql_cache_misses() { __sync_fetch_and_add(&status_variables.nl2sql_cache_misses, 1); } - void increment_nl2sql_local_model_calls() { __sync_fetch_and_add(&status_variables.nl2sql_local_model_calls, 1); } - void increment_nl2sql_cloud_model_calls() { __sync_fetch_and_add(&status_variables.nl2sql_cloud_model_calls, 1); } - void add_nl2sql_response_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.nl2sql_total_response_time_ms, ms); } - void add_nl2sql_cache_lookup_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.nl2sql_cache_total_lookup_time_ms, ms); } - void add_nl2sql_cache_store_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.nl2sql_cache_total_store_time_ms, ms); } - void increment_nl2sql_cache_lookups() { __sync_fetch_and_add(&status_variables.nl2sql_cache_lookups, 1); } - void increment_nl2sql_cache_stores() { __sync_fetch_and_add(&status_variables.nl2sql_cache_stores, 1); } + void increment_llm_total_requests() { __sync_fetch_and_add(&status_variables.llm_total_requests, 1); } + void increment_llm_cache_hits() { __sync_fetch_and_add(&status_variables.llm_cache_hits, 1); } + void increment_llm_cache_misses() { __sync_fetch_and_add(&status_variables.llm_cache_misses, 1); } + void increment_llm_local_model_calls() { __sync_fetch_and_add(&status_variables.llm_local_model_calls, 1); } + void increment_llm_cloud_model_calls() { __sync_fetch_and_add(&status_variables.llm_cloud_model_calls, 1); } + void add_llm_response_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.llm_total_response_time_ms, ms); } + void add_llm_cache_lookup_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.llm_cache_total_lookup_time_ms, ms); } + void add_llm_cache_store_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.llm_cache_total_store_time_ms, ms); } + void increment_llm_cache_lookups() { __sync_fetch_and_add(&status_variables.llm_cache_lookups, 1); } + void increment_llm_cache_stores() { __sync_fetch_and_add(&status_variables.llm_cache_stores, 1); } /** * @brief Get anomaly detector instance diff --git a/include/AI_Tool_Handler.h b/include/AI_Tool_Handler.h index 85e102284..2eb81e1f0 100644 --- a/include/AI_Tool_Handler.h +++ b/include/AI_Tool_Handler.h @@ -19,7 +19,7 @@ #include // Forward declarations -class NL2SQL_Converter; +class LLM_Bridge; class Anomaly_Detector; /** @@ -31,7 +31,7 @@ class Anomaly_Detector; */ class AI_Tool_Handler : public MCP_Tool_Handler { private: - NL2SQL_Converter* nl2sql_converter; + LLM_Bridge* llm_bridge; Anomaly_Detector* anomaly_detector; bool owns_components; @@ -50,7 +50,7 @@ public: /** * @brief Constructor - uses existing AI components */ - AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly); + AI_Tool_Handler(LLM_Bridge* llm, Anomaly_Detector* anomaly); /** * @brief Constructor - creates own components diff --git a/include/GenAI_Thread.h b/include/GenAI_Thread.h index 641713de6..ce4183ed3 100644 --- a/include/GenAI_Thread.h +++ b/include/GenAI_Thread.h @@ -26,6 +26,7 @@ enum GenAI_Operation : uint32_t { GENAI_OP_EMBEDDING = 0, ///< Generate embeddings for documents GENAI_OP_RERANK = 1, ///< Rerank documents by relevance to query GENAI_OP_JSON = 2, ///< Autonomous JSON query processing (handles embed/rerank/document_from_sql) + GENAI_OP_LLM = 3, ///< Generic LLM bridge processing }; /** @@ -202,17 +203,17 @@ public: // AI Features master switches bool genai_enabled; ///< Master enable for all AI features (default: false) - bool genai_nl2sql_enabled; ///< Enable NL2SQL feature (default: false) + bool genai_llm_enabled; ///< Enable LLM bridge feature (default: false) bool genai_anomaly_enabled; ///< Enable anomaly detection (default: false) - // NL2SQL configuration - char* genai_nl2sql_query_prefix; ///< Prefix for NL2SQL queries (default: "NL2SQL:") - char* genai_nl2sql_provider; ///< Provider format: "openai" or "anthropic" (default: "openai") - char* genai_nl2sql_provider_url; ///< LLM endpoint URL (default: http://localhost:11434/v1/chat/completions) - char* genai_nl2sql_provider_model; ///< Model name (default: "llama3.2") - char* genai_nl2sql_provider_key; ///< API key (default: NULL) - int genai_nl2sql_cache_similarity_threshold; ///< Semantic cache threshold 0-100 (default: 85) - int genai_nl2sql_timeout_ms; ///< LLM request timeout in ms (default: 30000) + // LLM bridge configuration + char* genai_llm_provider; ///< Provider format: "openai" or "anthropic" (default: "openai") + char* genai_llm_provider_url; ///< LLM endpoint URL (default: http://localhost:11434/v1/chat/completions) + char* genai_llm_provider_model; ///< Model name (default: "llama3.2") + char* genai_llm_provider_key; ///< API key (default: NULL) + int genai_llm_cache_similarity_threshold; ///< Semantic cache threshold 0-100 (default: 85) + int genai_llm_cache_enabled; ///< Enable semantic cache (default: true) + int genai_llm_timeout_ms; ///< LLM request timeout in ms (default: 30000) // Anomaly detection configuration int genai_anomaly_risk_threshold; ///< Risk score threshold for blocking 0-100 (default: 70) diff --git a/include/NL2SQL_Converter.h b/include/LLM_Bridge.h similarity index 54% rename from include/NL2SQL_Converter.h rename to include/LLM_Bridge.h index 87460b843..4c7015581 100644 --- a/include/NL2SQL_Converter.h +++ b/include/LLM_Bridge.h @@ -1,35 +1,34 @@ /** - * @file nl2sql_converter.h - * @brief Natural Language to SQL Converter for ProxySQL + * @file llm_bridge.h + * @brief Generic LLM Bridge for ProxySQL * - * The NL2SQL_Converter class provides natural language to SQL conversion + * The LLM_Bridge class provides a generic interface to Large Language Models * using multiple LLM providers with hybrid deployment and vector-based * semantic caching. * * Key Features: * - Multi-provider LLM support (local + generic cloud) * - Semantic similarity caching using sqlite-vec - * - Schema-aware conversion + * - Generic prompt handling (not SQL-specific) * - Configurable model selection based on latency/budget * - Generic provider support (OpenAI-compatible, Anthropic-compatible) * - * @date 2025-01-16 - * @version 0.2.0 + * @date 2025-01-17 + * @version 1.0.0 * * Example Usage: * @code - * NL2SQLRequest req; - * req.natural_language = "Show top 10 customers"; - * req.schema_name = "sales"; - * NL2SQLResult result = converter->convert(req); - * std::cout << result.sql_query << std::endl; + * LLMRequest req; + * req.prompt = "Summarize this data..."; + * LLMResult result = bridge->process(req); + * std::cout << result.text_response << std::endl; * @endcode */ -#ifndef __CLASS_NL2SQL_CONVERTER_H -#define __CLASS_NL2SQL_CONVERTER_H +#ifndef __CLASS_LLM_BRIDGE_H +#define __CLASS_LLM_BRIDGE_H -#define NL2SQL_CONVERTER_VERSION "0.2.0" +#define LLM_BRIDGE_VERSION "1.0.0" #include "proxysql.h" #include @@ -39,73 +38,65 @@ class SQLite3DB; /** - * @brief Result structure for NL2SQL conversion + * @brief Result structure for LLM bridge processing * - * Contains the generated SQL query along with metadata including - * confidence score, explanation, cache status, and error details. - * - * @note The confidence score is a heuristic based on SQL validation - * and LLM response quality. Actual SQL correctness should be - * verified before execution. + * Contains the LLM text response along with metadata including + * cache status, error details, and performance timing. * * @note When errors occur, error_code, error_details, and http_status_code * provide diagnostic information for troubleshooting. */ -struct NL2SQLResult { - std::string sql_query; ///< Generated SQL query - float confidence; ///< Confidence score 0.0-1.0 - std::string explanation; ///< Which model generated this - std::vector tables_used; ///< Tables referenced in SQL - bool cached; ///< True if from semantic cache - int64_t cache_id; ///< Cache entry ID for tracking +struct LLMResult { + std::string text_response; ///< LLM-generated text response + std::string explanation; ///< Which model generated this + bool cached; ///< True if from semantic cache + int64_t cache_id; ///< Cache entry ID for tracking - // Error details - populated when conversion fails - std::string error_code; ///< Structured error code (e.g., "ERR_API_KEY_MISSING") - std::string error_details; ///< Detailed error context with query, provider, URL - int http_status_code; ///< HTTP status code if applicable (0 if N/A) - std::string provider_used; ///< Which provider was attempted + // Error details - populated when processing fails + std::string error_code; ///< Structured error code (e.g., "ERR_API_KEY_MISSING") + std::string error_details; ///< Detailed error context with query, provider, URL + int http_status_code; ///< HTTP status code if applicable (0 if N/A) + std::string provider_used; ///< Which provider was attempted // Performance timing information - int total_time_ms; ///< Total conversion time in milliseconds - int cache_lookup_time_ms; ///< Cache lookup time in milliseconds - int cache_store_time_ms; ///< Cache store time in milliseconds - int llm_call_time_ms; ///< LLM call time in milliseconds - bool cache_hit; ///< True if cache was hit + int total_time_ms; ///< Total processing time in milliseconds + int cache_lookup_time_ms; ///< Cache lookup time in milliseconds + int cache_store_time_ms; ///< Cache store time in milliseconds + int llm_call_time_ms; ///< LLM call time in milliseconds + bool cache_hit; ///< True if cache was hit - NL2SQLResult() : confidence(0.0f), cached(false), cache_id(0), http_status_code(0), - total_time_ms(0), cache_lookup_time_ms(0), cache_store_time_ms(0), - llm_call_time_ms(0), cache_hit(false) {} + LLMResult() : cached(false), cache_id(0), http_status_code(0), + total_time_ms(0), cache_lookup_time_ms(0), cache_store_time_ms(0), + llm_call_time_ms(0), cache_hit(false) {} }; /** - * @brief Request structure for NL2SQL conversion + * @brief Request structure for LLM bridge processing * - * Contains the natural language query and context for conversion. - * Context includes schema name and optional table list for better - * SQL generation. + * Contains the prompt text and context for LLM processing. * * @note If max_latency_ms is set and < 500ms, the system will prefer * local Ollama regardless of provider preference. */ -struct NL2SQLRequest { - std::string natural_language; ///< Natural language query text - std::string schema_name; ///< Current database/schema name - int max_latency_ms; ///< Max acceptable latency (ms) - bool allow_cache; ///< Enable semantic cache lookup - std::vector context_tables; ///< Optional table hints for schema +struct LLMRequest { + std::string prompt; ///< Prompt text for LLM + std::string system_message; ///< Optional system role message + std::string schema_name; ///< Optional schema/database context + int max_latency_ms; ///< Max acceptable latency (ms) + bool allow_cache; ///< Enable semantic cache lookup // Request tracking for correlation and debugging - std::string request_id; ///< Unique ID for this request (UUID-like) + std::string request_id; ///< Unique ID for this request (UUID-like) // Retry configuration for transient failures - int max_retries; ///< Maximum retry attempts (default: 3) - int retry_backoff_ms; ///< Initial backoff in ms (default: 1000) - double retry_multiplier; ///< Backoff multiplier (default: 2.0) - int retry_max_backoff_ms; ///< Maximum backoff in ms (default: 30000) + int max_retries; ///< Maximum retry attempts (default: 3) + int retry_backoff_ms; ///< Initial backoff in ms (default: 1000) + double retry_multiplier; ///< Backoff multiplier (default: 2.0) + int retry_max_backoff_ms; ///< Maximum backoff in ms (default: 30000) - NL2SQLRequest() : max_latency_ms(0), allow_cache(true), - max_retries(3), retry_backoff_ms(1000), - retry_multiplier(2.0), retry_max_backoff_ms(30000) { + LLMRequest() : max_latency_ms(0), allow_cache(true), + max_retries(3), retry_backoff_ms(1000), + retry_multiplier(2.0), retry_max_backoff_ms(30000) { // Generate UUID-like request ID for correlation char uuid[64]; snprintf(uuid, sizeof(uuid), "%08lx-%04x-%04x-%04x-%012lx", @@ -117,7 +108,7 @@ struct NL2SQLRequest { }; /** - * @brief Error codes for NL2SQL conversion + * @brief Error codes for LLM bridge processing * * Structured error codes that provide machine-readable error information * for programmatic handling and user-friendly error messages. @@ -127,9 +118,9 @@ struct NL2SQLRequest { * - Logging and monitoring * - User error messages * - * @see nl2sql_error_code_to_string() + * @see llm_error_code_to_string() */ -enum class NL2SQLErrorCode { +enum class LLMErrorCode { SUCCESS = 0, ///< No error ERR_API_KEY_MISSING, ///< API key not configured ERR_API_KEY_INVALID, ///< API key format is invalid @@ -139,7 +130,6 @@ enum class NL2SQLErrorCode { ERR_SERVER_ERROR, ///< Server error (HTTP 5xx) ERR_EMPTY_RESPONSE, ///< Empty response from LLM ERR_INVALID_RESPONSE, ///< Malformed response from LLM - ERR_SQL_INJECTION_DETECTED, ///< SQL injection pattern detected ERR_VALIDATION_FAILED, ///< Input validation failed ERR_UNKNOWN_PROVIDER, ///< Invalid provider name ERR_REQUEST_TOO_LARGE ///< Request exceeds size limit @@ -154,10 +144,10 @@ enum class NL2SQLErrorCode { * @param code The error code to convert * @return String representation of the error code */ -const char* nl2sql_error_code_to_string(NL2SQLErrorCode code); +const char* llm_error_code_to_string(LLMErrorCode code); /** - * @brief Model provider format types for NL2SQL conversion + * @brief Model provider format types for LLM bridge * * Defines the API format to use for generic providers: * - GENERIC_OPENAI: Any OpenAI-compatible endpoint (including Ollama) @@ -176,34 +166,33 @@ enum class ModelProvider { }; /** - * @brief NL2SQL Converter class + * @brief Generic LLM Bridge class * - * Converts natural language queries to SQL using LLMs with hybrid - * local/cloud model support and vector cache. + * Processes prompts using LLMs with hybrid local/cloud model support + * and vector cache. * * Architecture: * - Vector cache for semantic similarity (sqlite-vec) * - Model selection based on latency/budget * - Generic HTTP client (libcurl) supporting multiple API formats - * - Schema-aware prompt building + * - Generic prompt handling (not tied to SQL) * * Configuration Variables: - * - ai_nl2sql_provider: "ollama", "openai", or "anthropic" - * - ai_nl2sql_provider_url: Custom endpoint URL (for generic providers) - * - ai_nl2sql_provider_model: Model name - * - ai_nl2sql_provider_key: API key (optional for local) + * - genai_llm_provider: "ollama", "openai", or "anthropic" + * - genai_llm_provider_url: Custom endpoint URL (for generic providers) + * - genai_llm_provider_model: Model name + * - genai_llm_provider_key: API key (optional for local) * * Thread Safety: * - This class is NOT thread-safe by itself * - External locking must be provided by AI_Features_Manager * - * @see AI_Features_Manager, NL2SQLRequest, NL2SQLResult + * @see AI_Features_Manager, LLMRequest, LLMResult */ -class NL2SQL_Converter { +class LLM_Bridge { private: struct { bool enabled; - char* query_prefix; char* provider; ///< "openai" or "anthropic" char* provider_url; ///< Generic endpoint URL char* provider_model; ///< Model name @@ -215,7 +204,7 @@ private: SQLite3DB* vector_db; // Internal methods - std::string build_prompt(const NL2SQLRequest& req, const std::string& schema_context); + std::string build_prompt(const LLMRequest& req); std::string call_generic_openai(const std::string& prompt, const std::string& model, const std::string& url, const char* key, const std::string& req_id = ""); @@ -233,34 +222,31 @@ private: const std::string& req_id, int max_retries, int initial_backoff_ms, double backoff_multiplier, int max_backoff_ms); - NL2SQLResult check_vector_cache(const NL2SQLRequest& req); - void store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result); - std::string get_schema_context(const std::vector& tables); - ModelProvider select_model(const NL2SQLRequest& req); - std::vector get_query_embedding(const std::string& text); - float validate_and_score_sql(const std::string& sql); + LLMResult check_cache(const LLMRequest& req); + void store_in_cache(const LLMRequest& req, const LLMResult& result); + ModelProvider select_model(const LLMRequest& req); + std::vector get_text_embedding(const std::string& text); public: /** * @brief Constructor - initializes with default configuration * * Sets up default values: - * - query_prefix: "NL2SQL:" * - provider: "openai" * - provider_url: "http://localhost:11434/v1/chat/completions" (Ollama default) * - provider_model: "llama3.2" * - cache_similarity_threshold: 85 * - timeout_ms: 30000 */ - NL2SQL_Converter(); + LLM_Bridge(); /** * @brief Destructor - frees allocated resources */ - ~NL2SQL_Converter(); + ~LLM_Bridge(); /** - * @brief Initialize the NL2SQL converter + * @brief Initialize the LLM bridge * * Initializes vector DB connection and validates configuration. * The vector_db will be provided by AI_Features_Manager. @@ -270,7 +256,7 @@ public: int init(); /** - * @brief Shutdown the NL2SQL converter + * @brief Shutdown the LLM bridge * * Closes vector DB connection and cleans up resources. */ @@ -296,45 +282,38 @@ public: const char* provider_key, int cache_threshold, int timeout); /** - * @brief Convert natural language query to SQL + * @brief Process a prompt using the LLM * - * This is the main entry point for NL2SQL conversion. The flow is: - * 1. Check vector cache for semantically similar queries - * 2. Build prompt with schema context + * This is the main entry point for LLM bridge processing. The flow is: + * 1. Check vector cache for semantically similar prompts + * 2. Build prompt with optional system message * 3. Select appropriate model (Ollama or generic provider) * 4. Call LLM API - * 5. Parse and clean SQL response + * 5. Parse response * 6. Store in vector cache for future use * - * @param req NL2SQL request containing natural language query and context - * @return NL2SQLResult with generated SQL, confidence score, and metadata + * @param req LLM request containing prompt and context + * @return LLMResult with text response and metadata * * @note This is a synchronous blocking call. For non-blocking behavior, * use the async interface via MySQL_Session. * - * @note The confidence score is heuristic-based. Actual SQL correctness - * should be verified before execution. - * - * @see NL2SQLRequest, NL2SQLResult, ModelProvider - * * Example: * @code - * NL2SQLRequest req; - * req.natural_language = "Find customers with orders > $1000"; + * LLMRequest req; + * req.prompt = "Explain this query: SELECT * FROM users"; * req.allow_cache = true; - * NL2SQLResult result = converter.convert(req); - * if (result.confidence > 0.7f) { - * execute_sql(result.sql_query); - * } + * LLMResult result = bridge.process(req); + * std::cout << result.text_response << std::endl; * @endcode */ - NL2SQLResult convert(const NL2SQLRequest& req); + LLMResult process(const LLMRequest& req); /** * @brief Clear the vector cache * - * Removes all cached NL2SQL conversions from the vector database. - * This is useful for testing or when schema changes significantly. + * Removes all cached LLM responses from the vector database. + * This is useful for testing or when context changes significantly. */ void clear_cache(); @@ -342,7 +321,7 @@ public: * @brief Get cache statistics * * Returns JSON string with cache metrics: - * - entries: Total number of cached conversions + * - entries: Total number of cached responses * - hits: Number of cache hits * - misses: Number of cache misses * @@ -351,7 +330,4 @@ public: std::string get_cache_stats(); }; -// Global instance (defined by AI_Features_Manager) -// extern NL2SQL_Converter *GloNL2SQL; - -#endif // __CLASS_NL2SQL_CONVERTER_H +#endif // __CLASS_LLM_BRIDGE_H diff --git a/include/MySQL_Session.h b/include/MySQL_Session.h index a584d0c1c..f2b959a3d 100644 --- a/include/MySQL_Session.h +++ b/include/MySQL_Session.h @@ -284,7 +284,7 @@ class MySQL_Session: public Base_Sessionexecute(create_nl2sql_cache) != 0) { - proxy_error("AI: Failed to create nl2sql_cache table\n"); + if (vector_db->execute(create_llm_cache) != 0) { + proxy_error("AI: Failed to create llm_cache table\n"); return -1; } @@ -108,8 +108,8 @@ int AI_Features_Manager::init_vector_db() { const char* create_query_history = "CREATE TABLE IF NOT EXISTS query_history (" "id INTEGER PRIMARY KEY AUTOINCREMENT," - "query_text TEXT NOT NULL," - "generated_sql TEXT," + "prompt TEXT NOT NULL," + "response TEXT," "embedding BLOB," "execution_time_ms INTEGER," "success BOOLEAN," @@ -124,16 +124,16 @@ int AI_Features_Manager::init_vector_db() { // Create virtual vector tables for similarity search using sqlite-vec // Note: sqlite-vec extension is auto-loaded in Admin_Bootstrap.cpp:612 - // 1. NL2SQL cache virtual table - const char* create_nl2sql_vec = - "CREATE VIRTUAL TABLE IF NOT EXISTS nl2sql_cache_vec USING vec0(" + // 1. LLM cache virtual table + const char* create_llm_vec = + "CREATE VIRTUAL TABLE IF NOT EXISTS llm_cache_vec USING vec0(" "embedding float(1536)" ");"; - if (vector_db->execute(create_nl2sql_vec) != 0) { - proxy_error("AI: Failed to create nl2sql_cache_vec virtual table\n"); + if (vector_db->execute(create_llm_vec) != 0) { + proxy_error("AI: Failed to create llm_cache_vec virtual table\n"); // Virtual table creation failure is not critical - log and continue - proxy_debug(PROXY_DEBUG_GENAI, 3, "Continuing without nl2sql_cache_vec"); + proxy_debug(PROXY_DEBUG_GENAI, 3, "Continuing without llm_cache_vec"); } // 2. Anomaly patterns virtual table @@ -162,37 +162,37 @@ int AI_Features_Manager::init_vector_db() { return 0; } -int AI_Features_Manager::init_nl2sql() { - if (!GloGATH->variables.genai_nl2sql_enabled) { - proxy_info("AI: NL2SQL disabled, skipping initialization\n"); +int AI_Features_Manager::init_llm_bridge() { + if (!GloGATH->variables.genai_llm_enabled) { + proxy_info("AI: LLM bridge disabled, skipping initialization\n"); return 0; } - proxy_info("AI: Initializing NL2SQL Converter\n"); + proxy_info("AI: Initializing LLM Bridge\n"); - nl2sql_converter = new NL2SQL_Converter(); + llm_bridge = new LLM_Bridge(); // Set vector database - nl2sql_converter->set_vector_db(vector_db); + llm_bridge->set_vector_db(vector_db); // Update config with current variables from GenAI module - nl2sql_converter->update_config( - GloGATH->variables.genai_nl2sql_provider, - GloGATH->variables.genai_nl2sql_provider_url, - GloGATH->variables.genai_nl2sql_provider_model, - GloGATH->variables.genai_nl2sql_provider_key, - GloGATH->variables.genai_nl2sql_cache_similarity_threshold, - GloGATH->variables.genai_nl2sql_timeout_ms + llm_bridge->update_config( + GloGATH->variables.genai_llm_provider, + GloGATH->variables.genai_llm_provider_url, + GloGATH->variables.genai_llm_provider_model, + GloGATH->variables.genai_llm_provider_key, + GloGATH->variables.genai_llm_cache_similarity_threshold, + GloGATH->variables.genai_llm_timeout_ms ); - if (nl2sql_converter->init() != 0) { - proxy_error("AI: Failed to initialize NL2SQL Converter\n"); - delete nl2sql_converter; - nl2sql_converter = NULL; + if (llm_bridge->init() != 0) { + proxy_error("AI: Failed to initialize LLM Bridge\n"); + delete llm_bridge; + llm_bridge = NULL; return -1; } - proxy_info("AI: NL2SQL Converter initialized\n"); + proxy_info("AI: LLM Bridge initialized\n"); return 0; } @@ -223,11 +223,11 @@ void AI_Features_Manager::close_vector_db() { } } -void AI_Features_Manager::close_nl2sql() { - if (nl2sql_converter) { - nl2sql_converter->close(); - delete nl2sql_converter; - nl2sql_converter = NULL; +void AI_Features_Manager::close_llm_bridge() { + if (llm_bridge) { + llm_bridge->close(); + delete llm_bridge; + llm_bridge = NULL; } } @@ -247,15 +247,15 @@ int AI_Features_Manager::init() { return 0; } - // Initialize vector storage first (needed by both NL2SQL and Anomaly Detector) + // Initialize vector storage first (needed by both LLM bridge and Anomaly Detector) if (init_vector_db() != 0) { proxy_error("AI: Failed to initialize vector storage\n"); return -1; } - // Initialize NL2SQL - if (init_nl2sql() != 0) { - proxy_error("AI: Failed to initialize NL2SQL\n"); + // Initialize LLM bridge + if (init_llm_bridge() != 0) { + proxy_error("AI: Failed to initialize LLM bridge\n"); return -1; } @@ -275,7 +275,7 @@ void AI_Features_Manager::shutdown() { proxy_info("AI: Shutting down AI Features Manager\n"); - close_nl2sql(); + close_llm_bridge(); close_anomaly_detector(); close_vector_db(); @@ -299,7 +299,7 @@ std::string AI_Features_Manager::get_status_json() { snprintf(buf, sizeof(buf), "{" "\"version\": \"%s\"," - "\"nl2sql\": {" + "\"llm\": {" "\"total_requests\": %llu," "\"cache_hits\": %llu," "\"local_calls\": %llu," @@ -321,16 +321,16 @@ std::string AI_Features_Manager::get_status_json() { "}" "}", AI_FEATURES_MANAGER_VERSION, - status_variables.nl2sql_total_requests, - status_variables.nl2sql_cache_hits, - status_variables.nl2sql_local_model_calls, - status_variables.nl2sql_cloud_model_calls, - status_variables.nl2sql_total_response_time_ms, - status_variables.nl2sql_cache_total_lookup_time_ms, - status_variables.nl2sql_cache_total_store_time_ms, - status_variables.nl2sql_cache_lookups, - status_variables.nl2sql_cache_stores, - status_variables.nl2sql_cache_misses, + status_variables.llm_total_requests, + status_variables.llm_cache_hits, + status_variables.llm_local_model_calls, + status_variables.llm_cloud_model_calls, + status_variables.llm_total_response_time_ms, + status_variables.llm_cache_total_lookup_time_ms, + status_variables.llm_cache_total_store_time_ms, + status_variables.llm_cache_lookups, + status_variables.llm_cache_stores, + status_variables.llm_cache_misses, status_variables.anomaly_total_checks, status_variables.anomaly_blocked_queries, status_variables.anomaly_flagged_queries, diff --git a/lib/AI_Tool_Handler.cpp b/lib/AI_Tool_Handler.cpp index 314a3fbe5..afe9a9bb2 100644 --- a/lib/AI_Tool_Handler.cpp +++ b/lib/AI_Tool_Handler.cpp @@ -9,7 +9,7 @@ */ #include "AI_Tool_Handler.h" -#include "NL2SQL_Converter.h" +#include "LLM_Bridge.h" #include "Anomaly_Detector.h" #include "AI_Features_Manager.h" #include "proxysql_debug.h" @@ -29,8 +29,8 @@ using json = nlohmann::json; /** * @brief Constructor using existing AI components */ -AI_Tool_Handler::AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly) - : nl2sql_converter(nl2sql), +AI_Tool_Handler::AI_Tool_Handler(LLM_Bridge* llm, Anomaly_Detector* anomaly) + : llm_bridge(llm), anomaly_detector(anomaly), owns_components(false) { @@ -42,13 +42,13 @@ AI_Tool_Handler::AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* ano * Note: This implementation uses global instances */ AI_Tool_Handler::AI_Tool_Handler() - : nl2sql_converter(NULL), + : llm_bridge(NULL), anomaly_detector(NULL), owns_components(false) { // Use global instances from AI_Features_Manager if (GloAI) { - nl2sql_converter = GloAI->get_nl2sql(); + llm_bridge = GloAI->get_llm_bridge(); anomaly_detector = GloAI->get_anomaly_detector(); } proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler created (using global instances)\n"); @@ -70,8 +70,8 @@ AI_Tool_Handler::~AI_Tool_Handler() { * @brief Initialize the tool handler */ int AI_Tool_Handler::init() { - if (!nl2sql_converter) { - proxy_error("AI_Tool_Handler: NL2SQL converter not available\n"); + if (!llm_bridge) { + proxy_error("AI_Tool_Handler: LLM bridge not available\n"); return -1; } proxy_info("AI_Tool_Handler initialized\n"); @@ -199,73 +199,13 @@ json AI_Tool_Handler::execute_tool(const std::string& tool_name, const json& arg proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler: execute_tool(%s)\n", tool_name.c_str()); try { - // NL2SQL conversion tool + // LLM processing tool (generic, replaces NL2SQL) if (tool_name == "ai_nl2sql_convert") { - if (!nl2sql_converter) { - return create_error_response("NL2SQL converter not available"); - } - - // Extract parameters - std::string natural_language = get_json_string(arguments, "natural_language"); - if (natural_language.empty()) { - return create_error_response("Missing required parameter: natural_language"); - } - - std::string schema = get_json_string(arguments, "schema"); - int max_latency_ms = get_json_int(arguments, "max_latency_ms", 0); - bool allow_cache = true; - if (arguments.contains("allow_cache") && !arguments["allow_cache"].is_null()) { - if (arguments["allow_cache"].is_boolean()) { - allow_cache = arguments["allow_cache"].get(); - } else if (arguments["allow_cache"].is_string()) { - std::string val = arguments["allow_cache"].get(); - allow_cache = (val == "true" || val == "1"); - } - } - - // Parse context_tables - std::vector context_tables; - std::string tables_str = get_json_string(arguments, "context_tables"); - if (!tables_str.empty()) { - std::istringstream ts(tables_str); - std::string table; - while (std::getline(ts, table, ',')) { - table.erase(0, table.find_first_not_of(" \t")); - table.erase(table.find_last_not_of(" \t") + 1); - if (!table.empty()) { - context_tables.push_back(table); - } - } - } - - // Create NL2SQL request - NL2SQLRequest req; - req.natural_language = natural_language; - req.schema_name = schema; - req.max_latency_ms = max_latency_ms; - req.allow_cache = allow_cache; - req.context_tables = context_tables; - - // Call NL2SQL converter - NL2SQLResult result = nl2sql_converter->convert(req); - - // Build response - json response_data; - response_data["sql_query"] = result.sql_query; - response_data["confidence"] = result.confidence; - response_data["explanation"] = result.explanation; - response_data["cached"] = result.cached; - response_data["cache_id"] = result.cache_id; - - // Add tables used if available - if (!result.tables_used.empty()) { - response_data["tables_used"] = result.tables_used; - } - - proxy_info("AI_Tool_Handler: NL2SQL conversion complete. SQL: %s, Confidence: %.2f\n", - result.sql_query.c_str(), result.confidence); - - return create_success_response(response_data); + // NOTE: The ai_nl2sql_convert tool is deprecated. + // NL2SQL functionality has been replaced with a generic LLM bridge. + // Future NL2SQL will be implemented as a Web UI using external agents (Claude Code + MCP server). + return create_error_response("The ai_nl2sql_convert tool is deprecated. " + "Use the generic LLM: queries via MySQL protocol instead."); } // Unknown tool diff --git a/lib/Admin_FlushVariables.cpp b/lib/Admin_FlushVariables.cpp index 576b53e70..c9bf71484 100644 --- a/lib/Admin_FlushVariables.cpp +++ b/lib/Admin_FlushVariables.cpp @@ -1079,11 +1079,11 @@ void ProxySQL_Admin::flush_genai_variables___database_to_runtime(SQLite3DB* db, pthread_mutex_unlock(&GloVars.checksum_mutex); } - // Check if NL2SQL needs to be initialized - if (GloAI && GloGATH->variables.genai_nl2sql_enabled && !GloAI->get_nl2sql()) { - proxy_info("NL2SQL enabled but not initialized, initializing now\n"); - if (GloAI->init_nl2sql() != 0) { - proxy_error("Failed to initialize NL2SQL converter\n"); + // Check if LLM bridge needs to be initialized + if (GloAI && GloGATH->variables.genai_llm_enabled && !GloAI->get_llm_bridge()) { + proxy_info("LLM bridge enabled but not initialized, initializing now\n"); + if (GloAI->init_llm_bridge() != 0) { + proxy_error("Failed to initialize LLM bridge\n"); } } diff --git a/lib/GenAI_Thread.cpp b/lib/GenAI_Thread.cpp index 56d73bc77..e3a51736a 100644 --- a/lib/GenAI_Thread.cpp +++ b/lib/GenAI_Thread.cpp @@ -45,17 +45,17 @@ static const char* genai_thread_variables_names[] = { // AI Features master switches "enabled", - "nl2sql_enabled", + "llm_enabled", "anomaly_enabled", - // NL2SQL configuration - "nl2sql_query_prefix", - "nl2sql_provider", - "nl2sql_provider_url", - "nl2sql_provider_model", - "nl2sql_provider_key", - "nl2sql_cache_similarity_threshold", - "nl2sql_timeout_ms", + // LLM bridge configuration + "llm_provider", + "llm_provider_url", + "llm_provider_model", + "llm_provider_key", + "llm_cache_similarity_threshold", + "llm_cache_enabled", + "llm_timeout_ms", // Anomaly detection configuration "anomaly_risk_threshold", @@ -153,17 +153,17 @@ GenAI_Threads_Handler::GenAI_Threads_Handler() { // AI Features master switches variables.genai_enabled = false; - variables.genai_nl2sql_enabled = false; + variables.genai_llm_enabled = false; variables.genai_anomaly_enabled = false; - // NL2SQL configuration - variables.genai_nl2sql_query_prefix = strdup("NL2SQL:"); - variables.genai_nl2sql_provider = strdup("openai"); - variables.genai_nl2sql_provider_url = strdup("http://localhost:11434/v1/chat/completions"); - variables.genai_nl2sql_provider_model = strdup("llama3.2"); - variables.genai_nl2sql_provider_key = NULL; - variables.genai_nl2sql_cache_similarity_threshold = 85; - variables.genai_nl2sql_timeout_ms = 30000; + // LLM bridge configuration + variables.genai_llm_provider = strdup("openai"); + variables.genai_llm_provider_url = strdup("http://localhost:11434/v1/chat/completions"); + variables.genai_llm_provider_model = strdup("llama3.2"); + variables.genai_llm_provider_key = NULL; + variables.genai_llm_cache_similarity_threshold = 85; + variables.genai_llm_cache_enabled = true; + variables.genai_llm_timeout_ms = 30000; // Anomaly detection configuration variables.genai_anomaly_risk_threshold = 70; @@ -197,17 +197,15 @@ GenAI_Threads_Handler::~GenAI_Threads_Handler() { if (variables.genai_rerank_uri) free(variables.genai_rerank_uri); - // Free NL2SQL string variables - if (variables.genai_nl2sql_query_prefix) - free(variables.genai_nl2sql_query_prefix); - if (variables.genai_nl2sql_provider) - free(variables.genai_nl2sql_provider); - if (variables.genai_nl2sql_provider_url) - free(variables.genai_nl2sql_provider_url); - if (variables.genai_nl2sql_provider_model) - free(variables.genai_nl2sql_provider_model); - if (variables.genai_nl2sql_provider_key) - free(variables.genai_nl2sql_provider_key); + // Free LLM bridge string variables + if (variables.genai_llm_provider) + free(variables.genai_llm_provider); + if (variables.genai_llm_provider_url) + free(variables.genai_llm_provider_url); + if (variables.genai_llm_provider_model) + free(variables.genai_llm_provider_model); + if (variables.genai_llm_provider_key) + free(variables.genai_llm_provider_key); // Free vector storage string variables if (variables.genai_vector_db_path) @@ -377,37 +375,34 @@ char* GenAI_Threads_Handler::get_variable(char* name) { if (!strcmp(name, "enabled")) { return strdup(variables.genai_enabled ? "true" : "false"); } - if (!strcmp(name, "nl2sql_enabled")) { - return strdup(variables.genai_nl2sql_enabled ? "true" : "false"); + if (!strcmp(name, "llm_enabled")) { + return strdup(variables.genai_llm_enabled ? "true" : "false"); } if (!strcmp(name, "anomaly_enabled")) { return strdup(variables.genai_anomaly_enabled ? "true" : "false"); } - // NL2SQL configuration - if (!strcmp(name, "nl2sql_query_prefix")) { - return strdup(variables.genai_nl2sql_query_prefix ? variables.genai_nl2sql_query_prefix : ""); + // LLM configuration + if (!strcmp(name, "llm_provider")) { + return strdup(variables.genai_llm_provider ? variables.genai_llm_provider : ""); } - if (!strcmp(name, "nl2sql_provider")) { - return strdup(variables.genai_nl2sql_provider ? variables.genai_nl2sql_provider : ""); + if (!strcmp(name, "llm_provider_url")) { + return strdup(variables.genai_llm_provider_url ? variables.genai_llm_provider_url : ""); } - if (!strcmp(name, "nl2sql_provider_url")) { - return strdup(variables.genai_nl2sql_provider_url ? variables.genai_nl2sql_provider_url : ""); + if (!strcmp(name, "llm_provider_model")) { + return strdup(variables.genai_llm_provider_model ? variables.genai_llm_provider_model : ""); } - if (!strcmp(name, "nl2sql_provider_model")) { - return strdup(variables.genai_nl2sql_provider_model ? variables.genai_nl2sql_provider_model : ""); + if (!strcmp(name, "llm_provider_key")) { + return strdup(variables.genai_llm_provider_key ? variables.genai_llm_provider_key : ""); } - if (!strcmp(name, "nl2sql_provider_key")) { - return strdup(variables.genai_nl2sql_provider_key ? variables.genai_nl2sql_provider_key : ""); - } - if (!strcmp(name, "nl2sql_cache_similarity_threshold")) { + if (!strcmp(name, "llm_cache_similarity_threshold")) { char buf[64]; - sprintf(buf, "%d", variables.genai_nl2sql_cache_similarity_threshold); + sprintf(buf, "%d", variables.genai_llm_cache_similarity_threshold); return strdup(buf); } - if (!strcmp(name, "nl2sql_timeout_ms")) { + if (!strcmp(name, "llm_timeout_ms")) { char buf[64]; - sprintf(buf, "%d", variables.genai_nl2sql_timeout_ms); + sprintf(buf, "%d", variables.genai_llm_timeout_ms); return strdup(buf); } @@ -512,8 +507,8 @@ bool GenAI_Threads_Handler::set_variable(char* name, const char* value) { variables.genai_enabled = (strcmp(value, "true") == 0); return true; } - if (!strcmp(name, "nl2sql_enabled")) { - variables.genai_nl2sql_enabled = (strcmp(value, "true") == 0); + if (!strcmp(name, "llm_enabled")) { + variables.genai_llm_enabled = (strcmp(value, "true") == 0); return true; } if (!strcmp(name, "anomaly_enabled")) { @@ -521,53 +516,47 @@ bool GenAI_Threads_Handler::set_variable(char* name, const char* value) { return true; } - // NL2SQL configuration - if (!strcmp(name, "nl2sql_query_prefix")) { - if (variables.genai_nl2sql_query_prefix) - free(variables.genai_nl2sql_query_prefix); - variables.genai_nl2sql_query_prefix = strdup(value); - return true; - } - if (!strcmp(name, "nl2sql_provider")) { - if (variables.genai_nl2sql_provider) - free(variables.genai_nl2sql_provider); - variables.genai_nl2sql_provider = strdup(value); + // LLM configuration + if (!strcmp(name, "llm_provider")) { + if (variables.genai_llm_provider) + free(variables.genai_llm_provider); + variables.genai_llm_provider = strdup(value); return true; } - if (!strcmp(name, "nl2sql_provider_url")) { - if (variables.genai_nl2sql_provider_url) - free(variables.genai_nl2sql_provider_url); - variables.genai_nl2sql_provider_url = strdup(value); + if (!strcmp(name, "llm_provider_url")) { + if (variables.genai_llm_provider_url) + free(variables.genai_llm_provider_url); + variables.genai_llm_provider_url = strdup(value); return true; } - if (!strcmp(name, "nl2sql_provider_model")) { - if (variables.genai_nl2sql_provider_model) - free(variables.genai_nl2sql_provider_model); - variables.genai_nl2sql_provider_model = strdup(value); + if (!strcmp(name, "llm_provider_model")) { + if (variables.genai_llm_provider_model) + free(variables.genai_llm_provider_model); + variables.genai_llm_provider_model = strdup(value); return true; } - if (!strcmp(name, "nl2sql_provider_key")) { - if (variables.genai_nl2sql_provider_key) - free(variables.genai_nl2sql_provider_key); - variables.genai_nl2sql_provider_key = strdup(value); + if (!strcmp(name, "llm_provider_key")) { + if (variables.genai_llm_provider_key) + free(variables.genai_llm_provider_key); + variables.genai_llm_provider_key = strdup(value); return true; } - if (!strcmp(name, "nl2sql_cache_similarity_threshold")) { + if (!strcmp(name, "llm_cache_similarity_threshold")) { int val = atoi(value); if (val < 0 || val > 100) { - proxy_error("Invalid value for genai_nl2sql_cache_similarity_threshold: %d (must be 0-100)\n", val); + proxy_error("Invalid value for genai_llm_cache_similarity_threshold: %d (must be 0-100)\n", val); return false; } - variables.genai_nl2sql_cache_similarity_threshold = val; + variables.genai_llm_cache_similarity_threshold = val; return true; } - if (!strcmp(name, "nl2sql_timeout_ms")) { + if (!strcmp(name, "llm_timeout_ms")) { int val = atoi(value); if (val < 1000 || val > 600000) { - proxy_error("Invalid value for genai_nl2sql_timeout_ms: %d (must be 1000-600000)\n", val); + proxy_error("Invalid value for genai_llm_timeout_ms: %d (must be 1000-600000)\n", val); return false; } - variables.genai_nl2sql_timeout_ms = val; + variables.genai_llm_timeout_ms = val; return true; } @@ -1709,30 +1698,30 @@ std::string GenAI_Threads_Handler::process_json_query(const std::string& json_qu return result.dump(); } - // Handle nl2sql operation - if (op_type == "nl2sql") { + // Handle llm operation + if (op_type == "llm") { // Check if AI manager is available if (!GloAI) { result["error"] = "AI features manager is not initialized"; return result.dump(); } - // Extract natural language query - if (!query_json.contains("query") || !query_json["query"].is_string()) { - result["error"] = "NL2SQL operation requires a 'query' string"; + // Extract prompt + if (!query_json.contains("prompt") || !query_json["prompt"].is_string()) { + result["error"] = "LLM operation requires a 'prompt' string"; return result.dump(); } - std::string nl_query = query_json["query"].get(); + std::string prompt = query_json["prompt"].get(); - if (nl_query.empty()) { - result["error"] = "NL2SQL query cannot be empty"; + if (prompt.empty()) { + result["error"] = "LLM prompt cannot be empty"; return result.dump(); } - // Extract optional schema name - std::string schema_name; - if (query_json.contains("schema") && query_json["schema"].is_string()) { - schema_name = query_json["schema"].get(); + // Extract optional system message + std::string system_message; + if (query_json.contains("system_message") && query_json["system_message"].is_string()) { + system_message = query_json["system_message"].get(); } // Extract optional cache flag @@ -1741,41 +1730,37 @@ std::string GenAI_Threads_Handler::process_json_query(const std::string& json_qu allow_cache = query_json["allow_cache"].get(); } - // Get NL2SQL converter - NL2SQL_Converter* nl2sql = GloAI->get_nl2sql(); - if (!nl2sql) { - result["error"] = "NL2SQL converter is not initialized"; + // Get LLM bridge + LLM_Bridge* llm_bridge = GloAI->get_llm_bridge(); + if (!llm_bridge) { + result["error"] = "LLM bridge is not initialized"; return result.dump(); } - // Build NL2SQL request - NL2SQLRequest req; - req.natural_language = nl_query; - req.schema_name = schema_name; + // Build LLM request + LLMRequest req; + req.prompt = prompt; + req.system_message = system_message; req.allow_cache = allow_cache; req.max_latency_ms = 0; // No specific latency requirement - // Convert (this will use cache if available) - NL2SQLResult sql_result = nl2sql->convert(req); + // Process (this will use cache if available) + LLMResult llm_result = llm_bridge->process(req); - if (sql_result.sql_query.empty() || sql_result.sql_query.find("NL2SQL conversion failed") == 0) { - result["error"] = "Failed to convert natural language to SQL: " + sql_result.explanation; + if (!llm_result.error_code.empty()) { + result["error"] = "LLM processing failed: " + llm_result.error_details; return result.dump(); } - // Build result - result["columns"] = json::array({"sql_query", "confidence", "explanation", "cached"}); + // Build result - return as single row with text_response + result["columns"] = json::array({"text_response", "explanation", "cached", "provider"}); json rows = json::array(); json row = json::array(); - row.push_back(sql_result.sql_query); - - char conf_buf[32]; - snprintf(conf_buf, sizeof(conf_buf), "%.2f", sql_result.confidence); - row.push_back(std::string(conf_buf)); - - row.push_back(sql_result.explanation); - row.push_back(sql_result.cached ? "true" : "false"); + row.push_back(llm_result.text_response); + row.push_back(llm_result.explanation); + row.push_back(llm_result.cached ? "true" : "false"); + row.push_back(llm_result.provider_used); rows.push_back(row); result["rows"] = rows; @@ -1784,7 +1769,7 @@ std::string GenAI_Threads_Handler::process_json_query(const std::string& json_qu } // Unknown operation type - result["error"] = "Unknown operation type: " + op_type + ". Use 'embed', 'rerank', or 'nl2sql'"; + result["error"] = "Unknown operation type: " + op_type + ". Use 'embed', 'rerank', or 'llm'"; return result.dump(); } catch (const json::parse_error& e) { diff --git a/lib/LLM_Bridge.cpp b/lib/LLM_Bridge.cpp new file mode 100644 index 000000000..05f19d4cb --- /dev/null +++ b/lib/LLM_Bridge.cpp @@ -0,0 +1,375 @@ +/** + * @file LLM_Bridge.cpp + * @brief Implementation of Generic LLM Bridge + * + * This file implements the generic LLM bridge pipeline including: + * - Vector cache operations for semantic similarity + * - Model selection based on latency/budget + * - Generic LLM API calls (Ollama, OpenAI-compatible, Anthropic-compatible) + * + * @see LLM_Bridge.h + */ + +#include "LLM_Bridge.h" +#include "sqlite3db.h" +#include "proxysql_utils.h" +#include "GenAI_Thread.h" +#include "cpp.h" +#include +#include +#include +#include +#include +#include + +using json = nlohmann::json; + +// Global GenAI handler for embedding generation +extern GenAI_Threads_Handler *GloGATH; + +// Global AI Features Manager for status updates +extern AI_Features_Manager *GloAI; + +// ============================================================================ +// Error Handling Helper Functions +// ============================================================================ + +/** + * @brief Convert error code enum to string representation + */ +const char* llm_error_code_to_string(LLMErrorCode code) { + switch (code) { + case LLMErrorCode::SUCCESS: return "SUCCESS"; + case LLMErrorCode::ERR_API_KEY_MISSING: return "ERR_API_KEY_MISSING"; + case LLMErrorCode::ERR_API_KEY_INVALID: return "ERR_API_KEY_INVALID"; + case LLMErrorCode::ERR_TIMEOUT: return "ERR_TIMEOUT"; + case LLMErrorCode::ERR_CONNECTION_FAILED: return "ERR_CONNECTION_FAILED"; + case LLMErrorCode::ERR_RATE_LIMITED: return "ERR_RATE_LIMITED"; + case LLMErrorCode::ERR_SERVER_ERROR: return "ERR_SERVER_ERROR"; + case LLMErrorCode::ERR_EMPTY_RESPONSE: return "ERR_EMPTY_RESPONSE"; + case LLMErrorCode::ERR_INVALID_RESPONSE: return "ERR_INVALID_RESPONSE"; + case LLMErrorCode::ERR_VALIDATION_FAILED: return "ERR_VALIDATION_FAILED"; + case LLMErrorCode::ERR_UNKNOWN_PROVIDER: return "ERR_UNKNOWN_PROVIDER"; + case LLMErrorCode::ERR_REQUEST_TOO_LARGE: return "ERR_REQUEST_TOO_LARGE"; + default: return "UNKNOWN"; + } +} + +// Forward declarations of external functions from LLM_Clients.cpp +extern std::string call_generic_openai_with_retry(const std::string& prompt, const std::string& model, + const std::string& url, const char* key, + const std::string& req_id); +extern std::string call_generic_anthropic_with_retry(const std::string& prompt, const std::string& model, + const std::string& url, const char* key, + const std::string& req_id); + +// ============================================================================ +// LLM_Bridge Implementation +// ============================================================================ + +/** + * @brief Constructor - initializes with default configuration + */ +LLM_Bridge::LLM_Bridge() + : vector_db(nullptr) +{ + // Set default configuration + config.enabled = false; + config.provider = strdup("openai"); + config.provider_url = strdup("http://localhost:11434/v1/chat/completions"); + config.provider_model = strdup("llama3.2"); + config.provider_key = nullptr; + config.cache_similarity_threshold = 85; + config.timeout_ms = 30000; + + proxy_debug(PROXY_DEBUG_GENAI, 3, "LLM_Bridge: Initialized with defaults\n"); +} + +/** + * @brief Destructor - frees allocated resources + */ +LLM_Bridge::~LLM_Bridge() { + if (config.provider) free(config.provider); + if (config.provider_url) free(config.provider_url); + if (config.provider_model) free(config.provider_model); + if (config.provider_key) free(config.provider_key); + + proxy_debug(PROXY_DEBUG_GENAI, 3, "LLM_Bridge: Destroyed\n"); +} + +/** + * @brief Initialize the LLM bridge + */ +int LLM_Bridge::init() { + proxy_info("LLM_Bridge: Initialized successfully\n"); + return 0; +} + +/** + * @brief Shutdown the LLM bridge + */ +void LLM_Bridge::close() { + proxy_info("LLM_Bridge: Shutdown complete\n"); +} + +/** + * @brief Update configuration from AI_Features_Manager + */ +void LLM_Bridge::update_config(const char* provider, const char* provider_url, const char* provider_model, + const char* provider_key, int cache_threshold, int timeout) { + if (provider) { + if (config.provider) free(config.provider); + config.provider = strdup(provider); + } + if (provider_url) { + if (config.provider_url) free(config.provider_url); + config.provider_url = strdup(provider_url); + } + if (provider_model) { + if (config.provider_model) free(config.provider_model); + config.provider_model = strdup(provider_model); + } + if (provider_key) { + if (config.provider_key) free(config.provider_key); + config.provider_key = provider_key ? strdup(provider_key) : nullptr; + } + config.cache_similarity_threshold = cache_threshold; + config.timeout_ms = timeout; + + proxy_debug(PROXY_DEBUG_GENAI, 3, "LLM_Bridge: Configuration updated\n"); +} + +/** + * @brief Build prompt from request + */ +std::string LLM_Bridge::build_prompt(const LLMRequest& req) { + std::string prompt = req.prompt; + + // Add system message if provided + if (!req.system_message.empty()) { + // For most LLM APIs, the system message is handled separately + // This is a simplified implementation + } + + return prompt; +} + +/** + * @brief Check vector cache for similar prompts + */ +LLMResult LLM_Bridge::check_cache(const LLMRequest& req) { + LLMResult result; + result.cached = false; + result.cache_hit = false; + + if (!vector_db || !req.allow_cache) { + return result; + } + + auto start_time = std::chrono::high_resolution_clock::now(); + + // TODO: Implement vector similarity search + // This would involve: + // 1. Generate embedding for the prompt + // 2. Search vector database for similar prompts + // 3. If similarity >= threshold, return cached response + + auto end_time = std::chrono::high_resolution_clock::now(); + result.cache_lookup_time_ms = std::chrono::duration_cast(end_time - start_time).count(); + + return result; +} + +/** + * @brief Store result in vector cache + */ +void LLM_Bridge::store_in_cache(const LLMRequest& req, const LLMResult& result) { + if (!vector_db || !req.allow_cache) { + return; + } + + auto start_time = std::chrono::high_resolution_clock::now(); + + // TODO: Implement cache storage + // This would involve: + // 1. Generate embedding for the prompt + // 2. Store prompt embedding, response, and metadata in cache table + + auto end_time = std::chrono::high_resolution_clock::now(); + const_cast(result).cache_store_time_ms = std::chrono::duration_cast(end_time - start_time).count(); +} + +/** + * @brief Select appropriate model based on request + */ +ModelProvider LLM_Bridge::select_model(const LLMRequest& req) { + if (!config.provider) { + return ModelProvider::FALLBACK_ERROR; + } + + if (strcmp(config.provider, "openai") == 0) { + return ModelProvider::GENERIC_OPENAI; + } else if (strcmp(config.provider, "anthropic") == 0) { + return ModelProvider::GENERIC_ANTHROPIC; + } + + return ModelProvider::FALLBACK_ERROR; +} + +/** + * @brief Get text embedding for vector cache + */ +std::vector LLM_Bridge::get_text_embedding(const std::string& text) { + std::vector embedding; + + // Use GenAI module for embedding generation + if (GloGATH) { + std::vector texts = {text}; + GenAI_EmbeddingResult result = GloGATH->embed_documents(texts); + + if (result.data && result.count > 0) { + // Copy embedding data + size_t dim = result.embedding_size; + embedding.assign(result.data, result.data + dim); + } + } + + return embedding; +} + +/** + * @brief Process a prompt using the LLM + */ +LLMResult LLM_Bridge::process(const LLMRequest& req) { + LLMResult result; + + auto total_start = std::chrono::high_resolution_clock::now(); + + // Check cache first + result = check_cache(req); + if (result.cached) { + result.cache_hit = true; + result.total_time_ms = result.cache_lookup_time_ms; + if (GloAI) { + GloAI->increment_llm_cache_hits(); + GloAI->add_llm_cache_lookup_time_ms(result.cache_lookup_time_ms); + GloAI->add_llm_response_time_ms(result.total_time_ms); + } + return result; + } + + if (GloAI) { + GloAI->increment_llm_cache_misses(); + GloAI->increment_llm_cache_lookups(); + GloAI->add_llm_cache_lookup_time_ms(result.cache_lookup_time_ms); + } + + // Build prompt + std::string prompt = build_prompt(req); + + // Select model + ModelProvider provider = select_model(req); + if (provider == ModelProvider::FALLBACK_ERROR) { + result.error_code = llm_error_code_to_string(LLMErrorCode::ERR_UNKNOWN_PROVIDER); + result.error_details = "Unknown provider: " + std::string(config.provider ? config.provider : "null"); + return result; + } + + // Call LLM API + auto llm_start = std::chrono::high_resolution_clock::now(); + + std::string raw_response; + try { + if (provider == ModelProvider::GENERIC_OPENAI) { + raw_response = call_generic_openai_with_retry( + prompt, + config.provider_model ? config.provider_model : "", + config.provider_url ? config.provider_url : "", + config.provider_key, + req.request_id, + req.max_retries, + req.retry_backoff_ms, + req.retry_multiplier, + req.retry_max_backoff_ms + ); + result.provider_used = "openai"; + } else if (provider == ModelProvider::GENERIC_ANTHROPIC) { + raw_response = call_generic_anthropic_with_retry( + prompt, + config.provider_model ? config.provider_model : "", + config.provider_url ? config.provider_url : "", + config.provider_key, + req.request_id, + req.max_retries, + req.retry_backoff_ms, + req.retry_multiplier, + req.retry_max_backoff_ms + ); + result.provider_used = "anthropic"; + } + } catch (const std::exception& e) { + result.error_code = "ERR_EXCEPTION"; + result.error_details = e.what(); + result.http_status_code = 0; + } + + auto llm_end = std::chrono::high_resolution_clock::now(); + result.llm_call_time_ms = std::chrono::duration_cast(llm_end - llm_start).count(); + + // Parse response + if (raw_response.empty() && result.error_code.empty()) { + result.error_code = llm_error_code_to_string(LLMErrorCode::ERR_EMPTY_RESPONSE); + result.error_details = "LLM returned empty response"; + } else if (!result.error_code.empty()) { + // Error already set by exception handler + } else { + result.text_response = raw_response; + } + + // Store in cache + store_in_cache(req, result); + + auto total_end = std::chrono::high_resolution_clock::now(); + result.total_time_ms = std::chrono::duration_cast(total_end - total_start).count(); + + // Update status counters + if (GloAI) { + GloAI->add_llm_response_time_ms(result.total_time_ms); + if (result.cache_store_time_ms > 0) { + GloAI->add_llm_cache_store_time_ms(result.cache_store_time_ms); + GloAI->increment_llm_cache_stores(); + } + GloAI->increment_llm_cloud_model_calls(); + } + + return result; +} + +/** + * @brief Clear the vector cache + */ +void LLM_Bridge::clear_cache() { + if (!vector_db) { + return; + } + + // TODO: Implement cache clearing + // This would involve deleting all rows from llm_cache table + + proxy_info("LLM_Bridge: Cache cleared\n"); +} + +/** + * @brief Get cache statistics + */ +std::string LLM_Bridge::get_cache_stats() { + // TODO: Implement cache statistics + // This would involve querying the llm_cache table for metrics + + json stats; + stats["entries"] = 0; + stats["hits"] = 0; + stats["misses"] = 0; + + return stats.dump(); +} diff --git a/lib/LLM_Clients.cpp b/lib/LLM_Clients.cpp index 8981dee5c..daec689c3 100644 --- a/lib/LLM_Clients.cpp +++ b/lib/LLM_Clients.cpp @@ -19,7 +19,7 @@ * @see NL2SQL_Converter.h */ -#include "NL2SQL_Converter.h" +#include "LLM_Bridge.h" #include "sqlite3db.h" #include "proxysql_utils.h" #include @@ -50,11 +50,11 @@ using json = nlohmann::json; do { \ if (req_id && strlen(req_id) > 0) { \ proxy_debug(PROXY_DEBUG_NL2SQL, 2, \ - "NL2SQL [%s]: REQUEST url=%s model=%s prompt_len=%zu\n", \ + "LLM [%s]: REQUEST url=%s model=%s prompt_len=%zu\n", \ req_id, url, model, prompt.length()); \ } else { \ proxy_debug(PROXY_DEBUG_NL2SQL, 2, \ - "NL2SQL: REQUEST url=%s model=%s prompt_len=%zu\n", \ + "LLM: REQUEST url=%s model=%s prompt_len=%zu\n", \ url, model, prompt.length()); \ } \ } while(0) @@ -63,11 +63,11 @@ using json = nlohmann::json; do { \ if (req_id && strlen(req_id) > 0) { \ proxy_debug(PROXY_DEBUG_NL2SQL, 3, \ - "NL2SQL [%s]: RESPONSE status=%d duration_ms=%ld response=%s\n", \ + "LLM [%s]: RESPONSE status=%d duration_ms=%ld response=%s\n", \ req_id, status, duration_ms, response_preview.c_str()); \ } else { \ proxy_debug(PROXY_DEBUG_NL2SQL, 3, \ - "NL2SQL: RESPONSE status=%d duration_ms=%ld response=%s\n", \ + "LLM: RESPONSE status=%d duration_ms=%ld response=%s\n", \ status, duration_ms, response_preview.c_str()); \ } \ } while(0) @@ -75,10 +75,10 @@ using json = nlohmann::json; #define LOG_LLM_ERROR(req_id, phase, error, status) \ do { \ if (req_id && strlen(req_id) > 0) { \ - proxy_error("NL2SQL [%s]: ERROR phase=%s error=%s status=%d\n", \ + proxy_error("LLM [%s]: ERROR phase=%s error=%s status=%d\n", \ req_id, phase, error, status); \ } else { \ - proxy_error("NL2SQL: ERROR phase=%s error=%s status=%d\n", \ + proxy_error("LLM: ERROR phase=%s error=%s status=%d\n", \ phase, error, status); \ } \ } while(0) @@ -214,7 +214,7 @@ static void sleep_with_jitter(int base_delay_ms, double jitter_factor = 0.1) { * @param req_id Request ID for correlation (optional) * @return Generated SQL or empty string on error */ -std::string NL2SQL_Converter::call_generic_openai(const std::string& prompt, const std::string& model, +std::string LLM_Bridge::call_generic_openai(const std::string& prompt, const std::string& model, const std::string& url, const char* key, const std::string& req_id) { // Start timing @@ -381,7 +381,7 @@ std::string NL2SQL_Converter::call_generic_openai(const std::string& prompt, con * @param req_id Request ID for correlation (optional) * @return Generated SQL or empty string on error */ -std::string NL2SQL_Converter::call_generic_anthropic(const std::string& prompt, const std::string& model, +std::string LLM_Bridge::call_generic_anthropic(const std::string& prompt, const std::string& model, const std::string& url, const char* key, const std::string& req_id) { // Start timing @@ -544,7 +544,7 @@ std::string NL2SQL_Converter::call_generic_anthropic(const std::string& prompt, * @param max_backoff_ms Maximum backoff delay in milliseconds * @return Generated SQL or empty string if all retries fail */ -std::string NL2SQL_Converter::call_generic_openai_with_retry( +std::string LLM_Bridge::call_generic_openai_with_retry( const std::string& prompt, const std::string& model, const std::string& url, @@ -568,7 +568,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry( // If we got a successful response, return it if (!result.empty()) { if (attempt > 0) { - proxy_info("NL2SQL [%s]: Request succeeded after %d retries\n", + proxy_info("LLM [%s]: Request succeeded after %d retries\n", req_id.c_str(), attempt); } return result; @@ -580,7 +580,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry( // If this was our last attempt, give up if (attempt == max_retries) { - proxy_error("NL2SQL [%s]: Request failed after %d attempts. Max retries reached.\n", + proxy_error("LLM [%s]: Request failed after %d attempts. Max retries reached.\n", req_id.c_str(), attempt + 1); return ""; } @@ -590,10 +590,10 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry( if (is_retryable_error(last_http_code, last_curl_code) || result.empty()) { // Log retry attempt if (result.empty()) { - proxy_warning("NL2SQL [%s]: Empty response, retrying in %dms (attempt %d/%d)\n", + proxy_warning("LLM [%s]: Empty response, retrying in %dms (attempt %d/%d)\n", req_id.c_str(), current_backoff_ms, attempt + 1, max_retries + 1); } else { - proxy_warning("NL2SQL [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n", + proxy_warning("LLM [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n", req_id.c_str(), last_http_code, current_backoff_ms, attempt + 1, max_retries + 1); } @@ -609,7 +609,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry( attempt++; } else { // Non-retryable error, give up - proxy_error("NL2SQL [%s]: Non-retryable error (HTTP %d), giving up.\n", + proxy_error("LLM [%s]: Non-retryable error (HTTP %d), giving up.\n", req_id.c_str(), last_http_code); return ""; } @@ -638,7 +638,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry( * @param max_backoff_ms Maximum backoff delay in milliseconds * @return Generated SQL or empty string if all retries fail */ -std::string NL2SQL_Converter::call_generic_anthropic_with_retry( +std::string LLM_Bridge::call_generic_anthropic_with_retry( const std::string& prompt, const std::string& model, const std::string& url, @@ -661,7 +661,7 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry( // If we got a successful response, return it if (!result.empty()) { if (attempt > 0) { - proxy_info("NL2SQL [%s]: Request succeeded after %d retries\n", + proxy_info("LLM [%s]: Request succeeded after %d retries\n", req_id.c_str(), attempt); } return result; @@ -669,7 +669,7 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry( // If this was our last attempt, give up if (attempt == max_retries) { - proxy_error("NL2SQL [%s]: Request failed after %d attempts. Max retries reached.\n", + proxy_error("LLM [%s]: Request failed after %d attempts. Max retries reached.\n", req_id.c_str(), attempt + 1); return ""; } @@ -679,10 +679,10 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry( if (is_retryable_error(last_http_code, last_curl_code) || result.empty()) { // Log retry attempt if (result.empty()) { - proxy_warning("NL2SQL [%s]: Empty response, retrying in %dms (attempt %d/%d)\n", + proxy_warning("LLM [%s]: Empty response, retrying in %dms (attempt %d/%d)\n", req_id.c_str(), current_backoff_ms, attempt + 1, max_retries + 1); } else { - proxy_warning("NL2SQL [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n", + proxy_warning("LLM [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n", req_id.c_str(), last_http_code, current_backoff_ms, attempt + 1, max_retries + 1); } @@ -698,7 +698,7 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry( attempt++; } else { // Non-retryable error, give up - proxy_error("NL2SQL [%s]: Non-retryable error (HTTP %d), giving up.\n", + proxy_error("LLM [%s]: Non-retryable error (HTTP %d), giving up.\n", req_id.c_str(), last_http_code); return ""; } diff --git a/lib/Makefile b/lib/Makefile index fc1e2960d..3e3283d0a 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -85,7 +85,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo MySQL_Catalog.oo MySQL_Tool_Handler.oo \ Config_Tool_Handler.oo Query_Tool_Handler.oo \ Admin_Tool_Handler.oo Cache_Tool_Handler.oo Observe_Tool_Handler.oo \ - AI_Features_Manager.oo NL2SQL_Converter.oo LLM_Clients.oo Anomaly_Detector.oo AI_Vector_Storage.oo AI_Tool_Handler.oo + AI_Features_Manager.oo LLM_Bridge.oo LLM_Clients.oo Anomaly_Detector.oo AI_Vector_Storage.oo AI_Tool_Handler.oo OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX)) HEADERS := ../include/*.h ../include/*.hpp diff --git a/lib/MySQL_Session.cpp b/lib/MySQL_Session.cpp index 3042515d1..05be0a5bc 100644 --- a/lib/MySQL_Session.cpp +++ b/lib/MySQL_Session.cpp @@ -16,6 +16,7 @@ using json = nlohmann::json; #include "MySQL_PreparedStatement.h" #include "GenAI_Thread.h" #include "AI_Features_Manager.h" +#include "LLM_Bridge.h" #include "Anomaly_Detector.h" #include "MySQL_Logger.hpp" #include "StatCounters.h" @@ -3871,31 +3872,33 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C #endif // epoll_create1 - fallback blocking path } -// Handler for NL2SQL: queries - Natural Language to SQL conversion +// Handler for LLM: queries - Generic LLM bridge processing // Query format: -// NL2SQL: Show me top 10 customers by revenue -// Returns: Resultset with the generated SQL query +// LLM: Summarize the customer feedback +// LLM: Generate a Python function to validate emails +// LLM: Explain this SQL query: SELECT * FROM users +// Returns: Resultset with the text response from LLM // // Note: This now uses the async GENAI path to avoid blocking MySQL threads. -// The NL2SQL query is converted to a JSON GENAI request and sent asynchronously. -void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(const char* query, size_t query_len, PtrSize_t* pkt) { - // Skip leading space after "NL2SQL:" +// The LLM query is converted to a JSON GENAI request and sent asynchronously. +void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___llm(const char* query, size_t query_len, PtrSize_t* pkt) { + // Skip leading space after "LLM:" while (query_len > 0 && (*query == ' ' || *query == '\t')) { query++; query_len--; } if (query_len == 0) { - // Empty query after NL2SQL: + // Empty query after LLM: client_myds->DSS = STATE_QUERY_SENT_NET; - client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1240, (char*)"HY000", "Empty NL2SQL: query", true); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1240, (char*)"HY000", "Empty LLM: query", true); l_free(pkt->size, pkt->ptr); client_myds->DSS = STATE_SLEEP; status = WAITING_CLIENT_DATA; return; } - // Check GenAI module is initialized (NL2SQL now uses GenAI module) + // Check GenAI module is initialized (LLM now uses GenAI module) if (!GloGATH) { client_myds->DSS = STATE_QUERY_SENT_NET; client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1241, (char*)"HY000", "GenAI module is not initialized", true); @@ -3905,7 +3908,7 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C return; } - // Check AI manager is available for NL2SQL converter + // Check AI manager is available for LLM bridge if (!GloAI) { client_myds->DSS = STATE_QUERY_SENT_NET; client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1242, (char*)"HY000", "AI features module is not initialized", true); @@ -3915,11 +3918,11 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C return; } - // Get NL2SQL converter from AI manager - NL2SQL_Converter* nl2sql = GloAI->get_nl2sql(); - if (!nl2sql) { + // Get LLM bridge from AI manager + LLM_Bridge* llm_bridge = GloAI->get_llm_bridge(); + if (!llm_bridge) { client_myds->DSS = STATE_QUERY_SENT_NET; - client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1243, (char*)"HY000", "NL2SQL converter is not initialized", true); + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1243, (char*)"HY000", "LLM bridge is not initialized", true); l_free(pkt->size, pkt->ptr); client_myds->DSS = STATE_SLEEP; status = WAITING_CLIENT_DATA; @@ -3927,16 +3930,16 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C } // Increment total requests counter - GloAI->increment_nl2sql_total_requests(); + GloAI->increment_llm_total_requests(); #ifdef epoll_create1 - // Build JSON query for NL2SQL operation + // Build JSON query for LLM operation json json_query; - json_query["type"] = "nl2sql"; - json_query["query"] = std::string(query, query_len); + json_query["type"] = "llm"; + json_query["prompt"] = std::string(query, query_len); json_query["allow_cache"] = true; - // Add schema if available + // Add schema if available (for context) if (client_myds->myconn->userinfo->schemaname) { json_query["schema"] = std::string(client_myds->myconn->userinfo->schemaname); } @@ -3954,38 +3957,38 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C // Request sent asynchronously - don't free pkt, will be freed in response handler // Return immediately, session is now free to handle other queries - proxy_debug(PROXY_DEBUG_NL2SQL, 2, "NL2SQL: Query sent asynchronously via GenAI: %s\n", std::string(query, query_len).c_str()); + proxy_debug(PROXY_DEBUG_GENAI, 2, "LLM: Query sent asynchronously via GenAI: %s\n", std::string(query, query_len).c_str()); #else // Fallback to synchronous blocking path for systems without epoll - // Build NL2SQL request - NL2SQLRequest req; - req.natural_language = std::string(query, query_len); + // Build LLM request + LLMRequest req; + req.prompt = std::string(query, query_len); req.schema_name = client_myds->myconn->userinfo->schemaname ? client_myds->myconn->userinfo->schemaname : ""; req.allow_cache = true; req.max_latency_ms = 0; // No specific latency requirement - // Call NL2SQL converter (blocking fallback) - NL2SQLResult result = nl2sql->convert(req); + // Call LLM bridge (blocking fallback) + LLMResult result = llm_bridge->process(req); // Update performance counters based on result if (result.cache_hit) { - GloAI->increment_nl2sql_cache_hits(); + GloAI->increment_llm_cache_hits(); } else { - GloAI->increment_nl2sql_cache_misses(); + GloAI->increment_llm_cache_misses(); } // Update timing counters - GloAI->add_nl2sql_response_time_ms(result.total_time_ms); - GloAI->add_nl2sql_cache_lookup_time_ms(result.cache_lookup_time_ms); - GloAI->increment_nl2sql_cache_lookups(); + GloAI->add_llm_response_time_ms(result.total_time_ms); + GloAI->add_llm_cache_lookup_time_ms(result.cache_lookup_time_ms); + GloAI->increment_llm_cache_lookups(); if (result.cache_hit) { // For cache hits, we're done } else { // For cache misses, also count LLM call time and cache store time - GloAI->add_nl2sql_cache_store_time_ms(result.cache_store_time_ms); + GloAI->add_llm_cache_store_time_ms(result.cache_store_time_ms); if (result.cache_store_time_ms > 0) { - GloAI->increment_nl2sql_cache_stores(); + GloAI->increment_llm_cache_stores(); } // Update model call counters @@ -3998,19 +4001,23 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C if (prefer_local_models && (result.explanation.find("localhost") != std::string::npos || result.explanation.find("127.0.0.1") != std::string::npos)) { - GloAI->increment_nl2sql_local_model_calls(); + GloAI->increment_llm_local_model_calls(); } else { - GloAI->increment_nl2sql_cloud_model_calls(); + GloAI->increment_llm_cloud_model_calls(); } } else if (result.provider_used == "anthropic") { - GloAI->increment_nl2sql_cloud_model_calls(); + GloAI->increment_llm_cloud_model_calls(); } } - if (result.sql_query.empty() || result.sql_query.find("NL2SQL conversion failed") == 0) { - // Conversion failed - std::string err_msg = "Failed to convert natural language to SQL: "; - err_msg += result.explanation; + if (result.text_response.empty() && !result.error_code.empty()) { + // LLM processing failed + std::string err_msg = "LLM processing failed: "; + err_msg += result.error_code; + if (!result.error_details.empty()) { + err_msg += " - "; + err_msg += result.error_details; + } client_myds->DSS = STATE_QUERY_SENT_NET; client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1244, (char*)"HY000", (char*)err_msg.c_str(), true); l_free(pkt->size, pkt->ptr); @@ -4019,8 +4026,8 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C return; } - // Build resultset with the generated SQL - std::vector columns = {"sql_query", "confidence", "explanation", "cached"}; + // Build resultset with the generated text response + std::vector columns = {"text_response", "explanation", "cached", "provider"}; std::unique_ptr resultset(new SQLite3_result(columns.size())); // Add column definitions @@ -4030,13 +4037,10 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C // Add single row with the result char** row_data = (char**)malloc(columns.size() * sizeof(char*)); - row_data[0] = strdup(result.sql_query.c_str()); - - char conf_buf[32]; - snprintf(conf_buf, sizeof(conf_buf), "%.2f", result.confidence); - row_data[1] = strdup(conf_buf); - row_data[2] = strdup(result.explanation.c_str()); - row_data[3] = strdup(result.cached ? "true" : "false"); + row_data[0] = strdup(result.text_response.c_str()); + row_data[1] = strdup(result.explanation.c_str()); + row_data[2] = strdup(result.cached ? "true" : "false"); + row_data[3] = strdup(result.provider_used.c_str()); resultset->add_row(row_data); @@ -4054,8 +4058,8 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C client_myds->DSS = STATE_SLEEP; status = WAITING_CLIENT_DATA; - proxy_debug(PROXY_DEBUG_NL2SQL, 2, "NL2SQL: Converted '%s' to SQL (confidence: %.2f) [blocking fallback]\n", - req.natural_language.c_str(), result.confidence); + proxy_debug(PROXY_DEBUG_GENAI, 2, "LLM: Processed prompt '%s' [blocking fallback]\n", + req.prompt.c_str()); #endif } @@ -7037,10 +7041,10 @@ bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C return true; } - // Check for NL2SQL: queries - Natural Language to SQL conversion - if (query_len >= 8 && strncasecmp(query_ptr, "NL2SQL:", 7) == 0) { - // This is a NL2SQL: query - handle with NL2SQL converter - handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(query_ptr + 7, query_len - 7, pkt); + // Check for LLM: queries - Generic LLM bridge processing + if (query_len >= 5 && strncasecmp(query_ptr, "LLM:", 4) == 0) { + // This is a LLM: query - handle with LLM bridge + handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___llm(query_ptr + 4, query_len - 4, pkt); return true; } } diff --git a/lib/NL2SQL_Converter.cpp b/lib/NL2SQL_Converter.cpp deleted file mode 100644 index e6bbcbb8d..000000000 --- a/lib/NL2SQL_Converter.cpp +++ /dev/null @@ -1,790 +0,0 @@ -/** - * @file NL2SQL_Converter.cpp - * @brief Implementation of Natural Language to SQL Converter - * - * This file implements the NL2SQL conversion pipeline including: - * - Vector cache operations for semantic similarity - * - Model selection based on latency/budget - * - Generic LLM API calls (Ollama, OpenAI-compatible, Anthropic-compatible) - * - SQL validation and cleaning - * - * @see NL2SQL_Converter.h - */ - -#include "NL2SQL_Converter.h" -#include "sqlite3db.h" -#include "proxysql_utils.h" -#include "GenAI_Thread.h" -#include -#include -#include -#include -#include -#include - -using json = nlohmann::json; - -// Global GenAI handler for embedding generation -extern GenAI_Threads_Handler *GloGATH; - -// Global instance is defined elsewhere if needed -// NL2SQL_Converter *GloNL2SQL = NULL; - -// ============================================================================ -// Error Handling Helper Functions -// ============================================================================ - -/** - * @brief Convert error code enum to string representation - * - * Returns the string representation of an error code for logging - * and display purposes. - * - * @param code The error code to convert - * @return String representation of the error code - */ -const char* nl2sql_error_code_to_string(NL2SQLErrorCode code) { - switch (code) { - case NL2SQLErrorCode::SUCCESS: return "SUCCESS"; - case NL2SQLErrorCode::ERR_API_KEY_MISSING: return "ERR_API_KEY_MISSING"; - case NL2SQLErrorCode::ERR_API_KEY_INVALID: return "ERR_API_KEY_INVALID"; - case NL2SQLErrorCode::ERR_TIMEOUT: return "ERR_TIMEOUT"; - case NL2SQLErrorCode::ERR_CONNECTION_FAILED: return "ERR_CONNECTION_FAILED"; - case NL2SQLErrorCode::ERR_RATE_LIMITED: return "ERR_RATE_LIMITED"; - case NL2SQLErrorCode::ERR_SERVER_ERROR: return "ERR_SERVER_ERROR"; - case NL2SQLErrorCode::ERR_EMPTY_RESPONSE: return "ERR_EMPTY_RESPONSE"; - case NL2SQLErrorCode::ERR_INVALID_RESPONSE: return "ERR_INVALID_RESPONSE"; - case NL2SQLErrorCode::ERR_SQL_INJECTION_DETECTED: return "ERR_SQL_INJECTION_DETECTED"; - case NL2SQLErrorCode::ERR_VALIDATION_FAILED: return "ERR_VALIDATION_FAILED"; - case NL2SQLErrorCode::ERR_UNKNOWN_PROVIDER: return "ERR_UNKNOWN_PROVIDER"; - case NL2SQLErrorCode::ERR_REQUEST_TOO_LARGE: return "ERR_REQUEST_TOO_LARGE"; - default: return "UNKNOWN_ERROR"; - } -} - -/** - * @brief Format detailed error context for logging and user display - * - * Creates a structured error message including: - * - Query (truncated if too long) - * - Schema name - * - Provider attempted - * - Endpoint URL - * - Specific error message - * - * @param req The NL2SQL request that failed - * @param provider The provider that was attempted - * @param url The endpoint URL that was used - * @param error The specific error message - * @return Formatted error context string - */ -static std::string format_error_context(const NL2SQLRequest& req, - const std::string& provider, - const std::string& url, - const std::string& error) -{ - std::ostringstream oss; - oss << "NL2SQL conversion failed:\n" - << " Query: " << req.natural_language.substr(0, 100) - << (req.natural_language.length() > 100 ? "..." : "") << "\n" - << " Schema: " << (req.schema_name.empty() ? "(none)" : req.schema_name) << "\n" - << " Provider: " << provider << "\n" - << " URL: " << url << "\n" - << " Error: " << error; - return oss.str(); -} - -/** - * @brief Set error details in NL2SQLResult - * - * Helper function to populate error fields in result struct. - * - * @param result The result to update - * @param error_code The error code string - * @param error_details Detailed error context - * @param http_status HTTP status code (0 if N/A) - * @param provider Provider that was attempted - */ -static void set_error_details(NL2SQLResult& result, - const std::string& error_code, - const std::string& error_details, - int http_status, - const std::string& provider) -{ - result.error_code = error_code; - result.error_details = error_details; - result.http_status_code = http_status; - result.provider_used = provider; -} - -// ============================================================================ -// Constructor/Destructor -// ============================================================================ - -/** - * Constructor initializes with default configuration values. - * The vector_db will be set by AI_Features_Manager during init(). - */ -NL2SQL_Converter::NL2SQL_Converter() : vector_db(NULL) { - config.enabled = true; - config.query_prefix = strdup("NL2SQL:"); - config.provider = strdup("openai"); - config.provider_url = strdup("http://localhost:11434/v1/chat/completions"); // Ollama default - config.provider_model = strdup("llama3.2"); - config.provider_key = NULL; - config.cache_similarity_threshold = 85; - config.timeout_ms = 30000; -} - -NL2SQL_Converter::~NL2SQL_Converter() { - free(config.query_prefix); - free(config.provider); - free(config.provider_url); - free(config.provider_model); - free(config.provider_key); -} - -// ============================================================================ -// Lifecycle -// ============================================================================ - -/** - * Initialize the NL2SQL converter. - * The vector DB will be provided by AI_Features_Manager during initialization. - */ -int NL2SQL_Converter::init() { - proxy_info("NL2SQL: Initializing NL2SQL Converter v%s\n", NL2SQL_CONVERTER_VERSION); - - // Vector DB will be provided by AI_Features_Manager - // This is a stub implementation for Phase 1 - - proxy_info("NL2SQL: NL2SQL Converter initialized (stub)\n"); - return 0; -} - -void NL2SQL_Converter::close() { - proxy_info("NL2SQL: NL2SQL Converter closed\n"); -} - -void NL2SQL_Converter::update_config(const char* provider, const char* provider_url, - const char* provider_model, const char* provider_key, - int cache_threshold, int timeout) { - // Free old values - free(config.provider); - free(config.provider_url); - free(config.provider_model); - free(config.provider_key); - - // Set new values - config.provider = strdup(provider ? provider : "openai"); - config.provider_url = strdup(provider_url ? provider_url : "http://localhost:11434/v1/chat/completions"); - config.provider_model = strdup(provider_model ? provider_model : "llama3.2"); - config.provider_key = provider_key ? strdup(provider_key) : NULL; - config.cache_similarity_threshold = cache_threshold; - config.timeout_ms = timeout; -} - -// ============================================================================ -// Vector Cache Operations (semantic similarity cache) -// ============================================================================ - -/** - * @brief Generate vector embedding for text - * - * Generates a 1536-dimensional embedding using the GenAI module. - * This embedding represents the semantic meaning of the text. - * - * @param text Input text to embed - * @return Vector embedding (empty if not available) - */ -std::vector NL2SQL_Converter::get_query_embedding(const std::string& text) { - if (!GloGATH) { - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: GenAI handler not available for embedding"); - return {}; - } - - // Generate embedding using GenAI - GenAI_EmbeddingResult emb_result = GloGATH->embed_documents({text}); - - if (!emb_result.data || emb_result.count == 0) { - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Failed to generate embedding"); - return {}; - } - - // Convert to std::vector - std::vector embedding(emb_result.data, emb_result.data + emb_result.embedding_size); - - // Free the result data (GenAI allocates with malloc) - if (emb_result.data) { - free(emb_result.data); - } - - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Generated embedding with %zu dimensions", embedding.size()); - return embedding; -} - -/** - * @brief Check vector cache for semantically similar previous conversions - * - * Uses sqlite-vec to find previous NL2SQL conversions with similar - * natural language queries. This allows caching based on semantic meaning - * rather than exact string matching. - */ -NL2SQLResult NL2SQL_Converter::check_vector_cache(const NL2SQLRequest& req) { - NL2SQLResult result; - result.cached = false; - - if (!vector_db || !req.allow_cache) { - return result; - } - - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Checking vector cache for: %s\n", - req.natural_language.c_str()); - - // Generate embedding for the query - std::vector query_embedding = get_query_embedding(req.natural_language); - if (query_embedding.empty()) { - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Failed to generate embedding for cache lookup"); - return result; - } - - // Convert embedding to JSON for sqlite-vec MATCH - std::string embedding_json = "["; - for (size_t i = 0; i < query_embedding.size(); i++) { - if (i > 0) embedding_json += ","; - embedding_json += std::to_string(query_embedding[i]); - } - embedding_json += "]"; - - // Calculate distance threshold from similarity - // Similarity 0-100 -> Distance 0-2 (cosine distance: 0=similar, 2=dissimilar) - float distance_threshold = 2.0f - (config.cache_similarity_threshold / 50.0f); - - // Build KNN search query - char search[1024]; - snprintf(search, sizeof(search), - "SELECT c.natural_language, c.generated_sql, c.schema_context, " - " vec_distance_cosine(v.embedding, '%s') as distance " - "FROM nl2sql_cache c " - "JOIN nl2sql_cache_vec v ON c.id = v.rowid " - "WHERE v.embedding MATCH '%s' " - "AND distance < %f " - "ORDER BY distance " - "LIMIT 1", - embedding_json.c_str(), embedding_json.c_str(), distance_threshold); - - // Execute search - sqlite3* db = vector_db->get_db(); - sqlite3_stmt* stmt = NULL; - int rc = sqlite3_prepare_v2(db, search, -1, &stmt, NULL); - - if (rc != SQLITE_OK) { - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Cache search prepare failed: %s", sqlite3_errmsg(db)); - return result; - } - - // Check if any cached queries matched - rc = sqlite3_step(stmt); - if (rc == SQLITE_ROW) { - // Found similar cached query - result.cached = true; - - // Extract cached result (natural_lang and schema_ctx available but not currently used) - // const char* natural_lang = reinterpret_cast(sqlite3_column_text(stmt, 0)); - const char* generated_sql = reinterpret_cast(sqlite3_column_text(stmt, 1)); - // const char* schema_ctx = reinterpret_cast(sqlite3_column_text(stmt, 2)); - double distance = sqlite3_column_double(stmt, 3); - - // Calculate similarity score from distance - float similarity = 1.0f - (distance / 2.0f); - result.confidence = similarity; - result.sql_query = generated_sql ? generated_sql : ""; - result.explanation = "Retrieved from semantic cache (similarity: " + - std::to_string((int)(similarity * 100)) + "%)"; - - proxy_info("NL2SQL: Cache hit! (distance: %.3f, similarity: %.0f%%)\n", - distance, similarity * 100); - } - - sqlite3_finalize(stmt); - - return result; -} - -/** - * @brief Store a new NL2SQL conversion in the vector cache - * - * Stores both the original query and generated SQL, along with - * the query embedding for semantic similarity search. - */ -void NL2SQL_Converter::store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result) { - if (!vector_db || !req.allow_cache) { - return; - } - - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Storing in vector cache: %s -> %s\n", - req.natural_language.c_str(), result.sql_query.c_str()); - - // Generate embedding for the natural language query - std::vector embedding = get_query_embedding(req.natural_language); - if (embedding.empty()) { - proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Failed to generate embedding for cache storage"); - return; - } - - // Insert into main table with embedding BLOB - sqlite3* db = vector_db->get_db(); - sqlite3_stmt* stmt = NULL; - const char* insert = "INSERT INTO nl2sql_cache " - "(natural_language, generated_sql, schema_context, embedding) " - "VALUES (?, ?, ?, ?)"; - - int rc = sqlite3_prepare_v2(db, insert, -1, &stmt, NULL); - if (rc != SQLITE_OK) { - proxy_error("NL2SQL: Failed to prepare cache insert: %s\n", sqlite3_errmsg(db)); - return; - } - - // Bind values - sqlite3_bind_text(stmt, 1, req.natural_language.c_str(), -1, SQLITE_TRANSIENT); - sqlite3_bind_text(stmt, 2, result.sql_query.c_str(), -1, SQLITE_TRANSIENT); - - // Schema context (may be empty) - std::string schema_context; - if (!req.context_tables.empty()) { - schema_context = "{"; // Simple format: table names - for (size_t i = 0; i < req.context_tables.size(); i++) { - if (i > 0) schema_context += ","; - schema_context += req.context_tables[i]; - } - schema_context += "}"; - } - sqlite3_bind_text(stmt, 3, schema_context.c_str(), -1, SQLITE_TRANSIENT); - - // Bind embedding as BLOB - sqlite3_bind_blob(stmt, 4, embedding.data(), embedding.size() * sizeof(float), SQLITE_TRANSIENT); - - // Execute insert - rc = sqlite3_step(stmt); - if (rc != SQLITE_DONE) { - proxy_error("NL2SQL: Failed to insert into cache: %s\n", sqlite3_errmsg(db)); - sqlite3_finalize(stmt); - return; - } - - sqlite3_finalize(stmt); - - // Get the inserted rowid - sqlite3_int64 rowid = sqlite3_last_insert_rowid(db); - - // Update virtual table (sqlite-vec needs explicit rowid insertion) - char update_vec[256]; - snprintf(update_vec, sizeof(update_vec), - "INSERT INTO nl2sql_cache_vec(rowid) VALUES (%lld)", rowid); - - char* err = NULL; - rc = sqlite3_exec(db, update_vec, NULL, NULL, &err); - if (rc != SQLITE_OK) { - proxy_error("NL2SQL: Failed to update vec table: %s\n", err ? err : "unknown"); - if (err) sqlite3_free(err); - return; - } - - proxy_info("NL2SQL: Stored in cache (id: %lld)\n", rowid); -} - -// ============================================================================ -// Model Selection Logic -// ============================================================================ - -/** - * @brief Select the best model provider for the given request - * - * Selection criteria: - * 1. Explicit provider preference -> use that - * 2. For generic providers: check API key availability (only for cloud) - * - * @note For local endpoints (like Ollama), API key is optional - */ -ModelProvider NL2SQL_Converter::select_model(const NL2SQLRequest& req) { - // Check provider preference - std::string provider(config.provider ? config.provider : "openai"); - - if (provider == "openai") { - // For local endpoints, API key is optional - // Check if this is a local endpoint - std::string url(config.provider_url ? config.provider_url : ""); - bool is_local = (url.find("localhost") != std::string::npos || - url.find("127.0.0.1") != std::string::npos || - url.find("http://localhost:11434") != std::string::npos); - - if (!is_local && !config.provider_key) { - proxy_error("NL2SQL: OpenAI-compatible provider requested but API key not configured\n"); - return ModelProvider::FALLBACK_ERROR; - } - return ModelProvider::GENERIC_OPENAI; - } else if (provider == "anthropic") { - // Anthropic always requires API key - if (!config.provider_key) { - proxy_error("NL2SQL: Anthropic-compatible provider requested but API key not configured\n"); - return ModelProvider::FALLBACK_ERROR; - } - return ModelProvider::GENERIC_ANTHROPIC; - } - - // Unknown provider, default to OpenAI format - proxy_warning("NL2SQL: Unknown provider '%s', defaulting to OpenAI format\n", provider.c_str()); - return ModelProvider::GENERIC_OPENAI; -} - -// ============================================================================ -// Prompt Building -// ============================================================================ - -/** - * @brief Build the prompt for LLM with schema context - * - * Constructs a comprehensive prompt including: - * - System instructions - * - Schema information (tables, columns) - * - User's natural language query - */ -std::string NL2SQL_Converter::build_prompt(const NL2SQLRequest& req, const std::string& schema_context) { - std::ostringstream prompt; - - // System instructions - prompt << "You are a SQL expert. Convert the following natural language question to a SQL query.\n\n"; - - // Add schema context if available - if (!schema_context.empty()) { - prompt << "Database Schema:\n"; - prompt << schema_context; - prompt << "\n"; - } - - // User's question - prompt << "Question: " << req.natural_language << "\n\n"; - prompt << "Return ONLY the SQL query. No explanations, no markdown formatting.\n"; - - return prompt.str(); -} - -/** - * @brief Get schema context for the specified tables - * - * Retrieves table and column information from the MySQL_Tool_Handler - * or from cached schema information. - */ -std::string NL2SQL_Converter::get_schema_context(const std::vector& tables) { - // TODO: Implement schema context retrieval via MySQL_Tool_Handler - // For Phase 2, return empty string - return ""; -} - -// ============================================================================ -// SQL Validation -// ============================================================================ - -/** - * @brief Validate SQL and generate confidence score - * - * Performs multi-factor validation: - * 1. SQL keyword detection - * 2. Structural validation (parentheses, quotes) - * 3. Common SQL injection pattern detection - * 4. Length and complexity checks - * - * @param sql The SQL to validate - * @return Confidence score 0.0-1.0 - */ -float NL2SQL_Converter::validate_and_score_sql(const std::string& sql) { - if (sql.empty()) { - return 0.0f; - } - - float confidence = 0.0f; - int checks_passed = 0; - int total_checks = 0; - - // Trim leading whitespace for validation - size_t start = sql.find_first_not_of(" \t\n\r"); - if (start == std::string::npos) { - return 0.0f; // Empty or whitespace only - } - std::string trimmed_sql = sql.substr(start); - std::string upper_sql = trimmed_sql; - std::transform(upper_sql.begin(), upper_sql.end(), upper_sql.begin(), ::toupper); - - // Check 1: SQL keyword detection - total_checks++; - static const std::vector sql_keywords = { - "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", - "TRUNCATE", "REPLACE", "GRANT", "REVOKE", "SHOW", "DESCRIBE", - "EXPLAIN", "WITH", "CALL", "BEGIN", "COMMIT", "ROLLBACK" - }; - for (const auto& keyword : sql_keywords) { - if (upper_sql.find(keyword) == 0 || upper_sql.find("-- " + keyword) == 0) { - confidence += 0.4f; - checks_passed++; - break; - } - } - - // Check 2: Structural validation - balanced parentheses - total_checks++; - int paren_count = 0; - bool balanced_parens = true; - for (char c : sql) { - if (c == '(') paren_count++; - else if (c == ')') paren_count--; - if (paren_count < 0) { - balanced_parens = false; - break; - } - } - if (balanced_parens && paren_count == 0) { - confidence += 0.15f; - checks_passed++; - } else if (paren_count != 0) { - // Unbalanced parentheses reduce confidence - confidence -= 0.1f; - } - - // Check 3: Balanced quotes - total_checks++; - int single_quotes = 0; - int double_quotes = 0; - for (size_t i = 0; i < sql.length(); i++) { - if (sql[i] == '\'' && (i == 0 || sql[i-1] != '\\')) { - single_quotes++; - } - if (sql[i] == '"' && (i == 0 || sql[i-1] != '\\')) { - double_quotes++; - } - } - if (single_quotes % 2 == 0 && double_quotes % 2 == 0) { - confidence += 0.15f; - checks_passed++; - } else { - confidence -= 0.1f; - } - - // Check 4: Minimum length check - total_checks++; - if (sql.length() >= 10) { - confidence += 0.1f; - checks_passed++; - } - - // Check 5: Contains FROM clause for SELECT statements (quality indicator) - total_checks++; - if (upper_sql.find("SELECT") == 0 && upper_sql.find("FROM") != std::string::npos) { - confidence += 0.1f; - checks_passed++; - } - - // Check 6: SQL injection pattern detection (negative impact) - total_checks++; - static const std::vector injection_patterns = { - "; DROP", "; DELETE", "; INSERT", "; UPDATE", - "1=1", "1 = 1", "OR TRUE", "AND TRUE", - "UNION SELECT", "'; --", "\"; --" - }; - bool has_injection = false; - std::string check_upper = upper_sql; - for (const auto& pattern : injection_patterns) { - std::string pattern_upper = pattern; - std::transform(pattern_upper.begin(), pattern_upper.end(), pattern_upper.begin(), ::toupper); - if (check_upper.find(pattern_upper) != std::string::npos) { - has_injection = true; - break; - } - } - if (!has_injection) { - confidence += 0.1f; - checks_passed++; - } else { - confidence -= 0.3f; // Significant penalty for injection patterns - proxy_warning("NL2SQL: Potential SQL injection pattern detected in generated SQL\n"); - } - - // Normalize confidence to 0.0-1.0 range - if (confidence < 0.0f) confidence = 0.0f; - if (confidence > 1.0f) confidence = 1.0f; - - // Additional logging for low confidence - if (confidence < 0.5f) { - proxy_debug(PROXY_DEBUG_NL2SQL, 2, - "NL2SQL: Low confidence score %.2f (passed %d/%d checks). SQL: %s\n", - confidence, checks_passed, total_checks, sql.c_str()); - } - - return confidence; -} - -// ============================================================================ -// Main Conversion Method -// ============================================================================ - -/** - * @brief Convert natural language to SQL (main entry point) - * - * Conversion Pipeline: - * 1. Check vector cache for semantically similar queries - * 2. Build prompt with schema context - * 3. Select appropriate model (Ollama or generic provider) - * 4. Call LLM API via HTTP - * 5. Parse and clean SQL response - * 6. Store in vector cache for future use - * - * The confidence score is calculated based on: - * - SQL keyword validation (does it look like SQL?) - * - Response quality (non-empty, well-formed) - * - Default score of 0.85 for valid-looking SQL - * - * @note This is a synchronous blocking call. - */ -NL2SQLResult NL2SQL_Converter::convert(const NL2SQLRequest& req) { - NL2SQLResult result; - - // Start timing the entire conversion - auto start_time = std::chrono::steady_clock::now(); - - proxy_info("NL2SQL: Converting query: %s\n", req.natural_language.c_str()); - - // Check vector cache first - auto cache_start = std::chrono::steady_clock::now(); - if (req.allow_cache) { - result = check_vector_cache(req); - if (result.cached && !result.sql_query.empty()) { - proxy_info("NL2SQL: Cache hit! Returning cached SQL\n"); - // Set timing information for cache hit - auto cache_end = std::chrono::steady_clock::now(); - int cache_lookup_time_ms = std::chrono::duration_cast(cache_end - cache_start).count(); - result.total_time_ms = cache_lookup_time_ms; - result.cache_lookup_time_ms = cache_lookup_time_ms; - result.cache_hit = true; - return result; - } - } - auto cache_end = std::chrono::steady_clock::now(); - int cache_lookup_time_ms = std::chrono::duration_cast(cache_end - cache_start).count(); - - // Build prompt with schema context - std::string schema_context = get_schema_context(req.context_tables); - std::string prompt = build_prompt(req, schema_context); - - // Select model provider - ModelProvider provider = select_model(req); - - // Call appropriate LLM - std::string raw_sql; - std::string url; - const char* model = NULL; - const char* key = config.provider_key; - - // Time the LLM call - auto llm_start = std::chrono::steady_clock::now(); - switch (provider) { - case ModelProvider::GENERIC_OPENAI: - // Use configured URL or default Ollama endpoint - url = (config.provider_url && strlen(config.provider_url) > 0) - ? config.provider_url - : "http://localhost:11434/v1/chat/completions"; - model = config.provider_model ? config.provider_model : "llama3.2"; - raw_sql = call_generic_openai_with_retry(prompt, model, url, key, req.request_id, - req.max_retries, req.retry_backoff_ms, - req.retry_multiplier, req.retry_max_backoff_ms); - result.explanation = "Generated by OpenAI-compatible provider (" + std::string(model) + ")"; - result.provider_used = "openai"; - break; - case ModelProvider::GENERIC_ANTHROPIC: - // Use configured URL or default Anthropic endpoint - url = (config.provider_url && strlen(config.provider_url) > 0) - ? config.provider_url - : "https://api.anthropic.com/v1/messages"; - model = config.provider_model ? config.provider_model : "claude-3-haiku"; - raw_sql = call_generic_anthropic_with_retry(prompt, model, url, key, req.request_id, - req.max_retries, req.retry_backoff_ms, - req.retry_multiplier, req.retry_max_backoff_ms); - result.explanation = "Generated by Anthropic-compatible provider (" + std::string(model) + ")"; - result.provider_used = "anthropic"; - break; - case ModelProvider::FALLBACK_ERROR: - default: { - // Format error context - std::string provider_str(config.provider ? config.provider : "unknown"); - std::string url_str(config.provider_url ? config.provider_url : "not configured"); - std::string error_msg = "API key not configured or provider error"; - std::string context = format_error_context(req, provider_str, url_str, error_msg); - - proxy_error("NL2SQL: %s\n", context.c_str()); - - set_error_details(result, "ERR_API_KEY_MISSING", context, 0, provider_str); - result.sql_query = "-- NL2SQL conversion failed: " + error_msg + "\n"; - result.confidence = 0.0f; - result.explanation = "Error: " + error_msg; - return result; - } - } - auto llm_end = std::chrono::steady_clock::now(); - int llm_call_time_ms = std::chrono::duration_cast(llm_end - llm_start).count(); - - // Validate and clean SQL - if (raw_sql.empty()) { - std::string provider_str(config.provider ? config.provider : "unknown"); - std::string url_str(config.provider_url ? config.provider_url : "not configured"); - std::string error_msg = "empty response from LLM"; - std::string context = format_error_context(req, provider_str, url_str, error_msg); - - proxy_error("NL2SQL: %s\n", context.c_str()); - - set_error_details(result, "ERR_EMPTY_RESPONSE", context, 0, provider_str); - result.sql_query = "-- NL2SQL conversion failed: " + error_msg + "\n"; - result.confidence = 0.0f; - result.explanation += " (empty response)"; - return result; - } - - // Improved SQL validation - float confidence = validate_and_score_sql(raw_sql); - result.sql_query = raw_sql; - result.confidence = confidence; - - // Store in vector cache for future use if confidence is good enough - auto cache_store_start = std::chrono::steady_clock::now(); - if (req.allow_cache && confidence >= 0.5f) { - store_in_vector_cache(req, result); - } - auto cache_store_end = std::chrono::steady_clock::now(); - int cache_store_time_ms = std::chrono::duration_cast(cache_store_end - cache_store_start).count(); - - proxy_info("NL2SQL: Conversion complete. Confidence: %.2f\n", result.confidence); - - // Calculate total time - auto end_time = std::chrono::steady_clock::now(); - int total_time_ms = std::chrono::duration_cast(end_time - start_time).count(); - - // Populate timing information in result - result.total_time_ms = total_time_ms; - result.cache_lookup_time_ms = cache_lookup_time_ms; - result.cache_store_time_ms = cache_store_time_ms; - result.llm_call_time_ms = llm_call_time_ms; - result.cache_hit = false; // This will be set to true if we return from cache hit - - return result; -} - -// ============================================================================ -// Cache Management -// ============================================================================ - -void NL2SQL_Converter::clear_cache() { - proxy_info("NL2SQL: Cache cleared\n"); - // TODO: Implement cache clearing -} - -std::string NL2SQL_Converter::get_cache_stats() { - return "{\"entries\": 0, \"hits\": 0, \"misses\": 0}"; - // TODO: Implement real cache statistics -} diff --git a/lib/ProxySQL_MCP_Server.cpp b/lib/ProxySQL_MCP_Server.cpp index 434627a34..6c3ea9347 100644 --- a/lib/ProxySQL_MCP_Server.cpp +++ b/lib/ProxySQL_MCP_Server.cpp @@ -121,10 +121,10 @@ ProxySQL_MCP_Server::ProxySQL_MCP_Server(int p, MCP_Threads_Handler* h) proxy_info("Observe Tool Handler initialized\n"); } - // 6. AI Tool Handler (for NL2SQL and other AI features) + // 6. AI Tool Handler (for LLM and other AI features) extern AI_Features_Manager *GloAI; if (GloAI) { - handler->ai_tool_handler = new AI_Tool_Handler(GloAI->get_nl2sql(), GloAI->get_anomaly_detector()); + handler->ai_tool_handler = new AI_Tool_Handler(GloAI->get_llm_bridge(), GloAI->get_anomaly_detector()); if (handler->ai_tool_handler->init() == 0) { proxy_info("AI Tool Handler initialized\n"); } else { @@ -164,7 +164,7 @@ ProxySQL_MCP_Server::ProxySQL_MCP_Server(int p, MCP_Threads_Handler* h) ws->register_resource("/mcp/cache", cache_resource.get(), true); _endpoints.push_back({"/mcp/cache", std::move(cache_resource)}); - // 6. AI endpoint (for NL2SQL and other AI features) + // 6. AI endpoint (for LLM and other AI features) if (handler->ai_tool_handler) { std::unique_ptr ai_resource = std::unique_ptr(new MCP_JSONRPC_Resource(handler, handler->ai_tool_handler, "ai"));