feat: Convert NL2SQL to generic LLM bridge

- Rename NL2SQL_Converter to LLM_Bridge for generic prompt processing
- Update MySQL protocol handler from /* NL2SQL: */ to /* LLM: */
- Remove SQL-specific fields (sql_query, confidence, tables_used)
- Add GENAI_OP_LLM operation type to GenAI module
- Rename all genai_nl2sql_* variables to genai_llm_*
- Update AI_Features_Manager to use LLM_Bridge
- Deprecate ai_nl2sql_convert MCP tool with error message
- LLM bridge now handles any prompt type via MySQL protocol

This enables generic LLM access (summarization, code generation,
translation, analysis) while preserving infrastructure for future
NL2SQL implementation via Web UI + external agents.
pull/5310/head
Rene Cannao 5 months ago
parent 3fe8a48f70
commit a3f0bade4e

@ -3,41 +3,41 @@
* @brief AI Features Manager for ProxySQL
*
* The AI_Features_Manager class coordinates all AI-related features in ProxySQL:
* - NL2SQL (Natural Language to SQL) conversion
* - LLM Bridge (generic LLM access via MySQL protocol)
* - Anomaly detection for security monitoring
* - Vector storage for semantic caching
* - Hybrid model routing (local Ollama + cloud APIs)
*
* Architecture:
* - Central configuration management with 'ai-' variable prefix
* - Central configuration management with 'genai-' variable prefix
* - Thread-safe operations using pthread rwlock
* - Follows same pattern as MCP_Threads_Handler and GenAI_Threads_Handler
* - Coordinates with MySQL_Session for query interception
*
* @date 2025-01-16
* @version 0.1.0
* @date 2025-01-17
* @version 1.0.0
*
* Example Usage:
* @code
* // Access NL2SQL converter
* NL2SQL_Converter* nl2sql = GloAI->get_nl2sql();
* NL2SQLRequest req;
* req.natural_language = "Show top customers";
* NL2SQLResult result = nl2sql->convert(req);
* // Access LLM bridge
* LLM_Bridge* llm = GloAI->get_llm_bridge();
* LLMRequest req;
* req.prompt = "Summarize this data";
* LLMResult result = llm->process(req);
* @endcode
*/
#ifndef __CLASS_AI_FEATURES_MANAGER_H
#define __CLASS_AI_FEATURES_MANAGER_H
#define AI_FEATURES_MANAGER_VERSION "0.1.0"
#define AI_FEATURES_MANAGER_VERSION "1.0.0"
#include "proxysql.h"
#include <pthread.h>
#include <string>
// Forward declarations
class NL2SQL_Converter;
class LLM_Bridge;
class Anomaly_Detector;
class SQLite3DB;
@ -45,7 +45,7 @@ class SQLite3DB;
* @brief AI Features Manager
*
* Coordinates all AI features in ProxySQL:
* - NL2SQL (Natural Language to SQL) conversion
* - LLM Bridge (generic LLM access)
* - Anomaly detection for security
* - Vector storage for semantic caching
* - Hybrid model routing (local Ollama + cloud APIs)
@ -57,7 +57,7 @@ class SQLite3DB;
* - All public methods are thread-safe using pthread rwlock
* - Use wrlock()/wrunlock() for manual locking if needed
*
* @see NL2SQL_Converter, Anomaly_Detector
* @see LLM_Bridge, Anomaly_Detector
*/
class AI_Features_Manager {
private:
@ -65,7 +65,7 @@ private:
pthread_rwlock_t rwlock;
// Sub-components
NL2SQL_Converter* nl2sql_converter;
LLM_Bridge* llm_bridge;
Anomaly_Detector* anomaly_detector;
SQLite3DB* vector_db;
@ -73,7 +73,7 @@ private:
int init_vector_db();
int init_anomaly_detector();
void close_vector_db();
void close_nl2sql();
void close_llm_bridge();
void close_anomaly_detector();
public:
@ -84,16 +84,16 @@ public:
* Configuration is managed by the GenAI module (GloGATH).
*/
struct {
unsigned long long nl2sql_total_requests;
unsigned long long nl2sql_cache_hits;
unsigned long long nl2sql_local_model_calls;
unsigned long long nl2sql_cloud_model_calls;
unsigned long long nl2sql_total_response_time_ms; // Total response time for all LLM calls
unsigned long long nl2sql_cache_total_lookup_time_ms; // Total time spent in cache lookups
unsigned long long nl2sql_cache_total_store_time_ms; // Total time spent in cache storage
unsigned long long nl2sql_cache_lookups;
unsigned long long nl2sql_cache_stores;
unsigned long long nl2sql_cache_misses;
unsigned long long llm_total_requests;
unsigned long long llm_cache_hits;
unsigned long long llm_local_model_calls;
unsigned long long llm_cloud_model_calls;
unsigned long long llm_total_response_time_ms; // Total response time for all LLM calls
unsigned long long llm_cache_total_lookup_time_ms; // Total time spent in cache lookups
unsigned long long llm_cache_total_store_time_ms; // Total time spent in cache storage
unsigned long long llm_cache_lookups;
unsigned long long llm_cache_stores;
unsigned long long llm_cache_misses;
unsigned long long anomaly_total_checks;
unsigned long long anomaly_blocked_queries;
unsigned long long anomaly_flagged_queries;
@ -113,7 +113,7 @@ public:
/**
* @brief Initialize all AI features
*
* Initializes vector database, NL2SQL converter, and anomaly detector.
* Initializes vector database, LLM bridge, and anomaly detector.
* This must be called after ProxySQL configuration is loaded.
*
* @return 0 on success, non-zero on failure
@ -129,14 +129,14 @@ public:
void shutdown();
/**
* @brief Initialize NL2SQL converter
* @brief Initialize LLM bridge
*
* Initializes the NL2SQL converter if not already initialized.
* This can be called at runtime after enabling nl2sql.
* Initializes the LLM bridge if not already initialized.
* This can be called at runtime after enabling llm.
*
* @return 0 on success, non-zero on failure
*/
int init_nl2sql();
int init_llm_bridge();
/**
* @brief Acquire write lock for thread-safe operations
@ -156,25 +156,25 @@ public:
void wrunlock();
/**
* @brief Get NL2SQL converter instance
* @brief Get LLM bridge instance
*
* @return Pointer to NL2SQL_Converter or NULL if not initialized
* @return Pointer to LLM_Bridge or NULL if not initialized
*
* @note Thread-safe when called within wrlock()/wrunlock() pair
*/
NL2SQL_Converter* get_nl2sql() { return nl2sql_converter; }
LLM_Bridge* get_llm_bridge() { return llm_bridge; }
// Status variable update methods
void increment_nl2sql_total_requests() { __sync_fetch_and_add(&status_variables.nl2sql_total_requests, 1); }
void increment_nl2sql_cache_hits() { __sync_fetch_and_add(&status_variables.nl2sql_cache_hits, 1); }
void increment_nl2sql_cache_misses() { __sync_fetch_and_add(&status_variables.nl2sql_cache_misses, 1); }
void increment_nl2sql_local_model_calls() { __sync_fetch_and_add(&status_variables.nl2sql_local_model_calls, 1); }
void increment_nl2sql_cloud_model_calls() { __sync_fetch_and_add(&status_variables.nl2sql_cloud_model_calls, 1); }
void add_nl2sql_response_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.nl2sql_total_response_time_ms, ms); }
void add_nl2sql_cache_lookup_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.nl2sql_cache_total_lookup_time_ms, ms); }
void add_nl2sql_cache_store_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.nl2sql_cache_total_store_time_ms, ms); }
void increment_nl2sql_cache_lookups() { __sync_fetch_and_add(&status_variables.nl2sql_cache_lookups, 1); }
void increment_nl2sql_cache_stores() { __sync_fetch_and_add(&status_variables.nl2sql_cache_stores, 1); }
void increment_llm_total_requests() { __sync_fetch_and_add(&status_variables.llm_total_requests, 1); }
void increment_llm_cache_hits() { __sync_fetch_and_add(&status_variables.llm_cache_hits, 1); }
void increment_llm_cache_misses() { __sync_fetch_and_add(&status_variables.llm_cache_misses, 1); }
void increment_llm_local_model_calls() { __sync_fetch_and_add(&status_variables.llm_local_model_calls, 1); }
void increment_llm_cloud_model_calls() { __sync_fetch_and_add(&status_variables.llm_cloud_model_calls, 1); }
void add_llm_response_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.llm_total_response_time_ms, ms); }
void add_llm_cache_lookup_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.llm_cache_total_lookup_time_ms, ms); }
void add_llm_cache_store_time_ms(unsigned long long ms) { __sync_fetch_and_add(&status_variables.llm_cache_total_store_time_ms, ms); }
void increment_llm_cache_lookups() { __sync_fetch_and_add(&status_variables.llm_cache_lookups, 1); }
void increment_llm_cache_stores() { __sync_fetch_and_add(&status_variables.llm_cache_stores, 1); }
/**
* @brief Get anomaly detector instance

@ -19,7 +19,7 @@
#include <map>
// Forward declarations
class NL2SQL_Converter;
class LLM_Bridge;
class Anomaly_Detector;
/**
@ -31,7 +31,7 @@ class Anomaly_Detector;
*/
class AI_Tool_Handler : public MCP_Tool_Handler {
private:
NL2SQL_Converter* nl2sql_converter;
LLM_Bridge* llm_bridge;
Anomaly_Detector* anomaly_detector;
bool owns_components;
@ -50,7 +50,7 @@ public:
/**
* @brief Constructor - uses existing AI components
*/
AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly);
AI_Tool_Handler(LLM_Bridge* llm, Anomaly_Detector* anomaly);
/**
* @brief Constructor - creates own components

@ -26,6 +26,7 @@ enum GenAI_Operation : uint32_t {
GENAI_OP_EMBEDDING = 0, ///< Generate embeddings for documents
GENAI_OP_RERANK = 1, ///< Rerank documents by relevance to query
GENAI_OP_JSON = 2, ///< Autonomous JSON query processing (handles embed/rerank/document_from_sql)
GENAI_OP_LLM = 3, ///< Generic LLM bridge processing
};
/**
@ -202,17 +203,17 @@ public:
// AI Features master switches
bool genai_enabled; ///< Master enable for all AI features (default: false)
bool genai_nl2sql_enabled; ///< Enable NL2SQL feature (default: false)
bool genai_llm_enabled; ///< Enable LLM bridge feature (default: false)
bool genai_anomaly_enabled; ///< Enable anomaly detection (default: false)
// NL2SQL configuration
char* genai_nl2sql_query_prefix; ///< Prefix for NL2SQL queries (default: "NL2SQL:")
char* genai_nl2sql_provider; ///< Provider format: "openai" or "anthropic" (default: "openai")
char* genai_nl2sql_provider_url; ///< LLM endpoint URL (default: http://localhost:11434/v1/chat/completions)
char* genai_nl2sql_provider_model; ///< Model name (default: "llama3.2")
char* genai_nl2sql_provider_key; ///< API key (default: NULL)
int genai_nl2sql_cache_similarity_threshold; ///< Semantic cache threshold 0-100 (default: 85)
int genai_nl2sql_timeout_ms; ///< LLM request timeout in ms (default: 30000)
// LLM bridge configuration
char* genai_llm_provider; ///< Provider format: "openai" or "anthropic" (default: "openai")
char* genai_llm_provider_url; ///< LLM endpoint URL (default: http://localhost:11434/v1/chat/completions)
char* genai_llm_provider_model; ///< Model name (default: "llama3.2")
char* genai_llm_provider_key; ///< API key (default: NULL)
int genai_llm_cache_similarity_threshold; ///< Semantic cache threshold 0-100 (default: 85)
int genai_llm_cache_enabled; ///< Enable semantic cache (default: true)
int genai_llm_timeout_ms; ///< LLM request timeout in ms (default: 30000)
// Anomaly detection configuration
int genai_anomaly_risk_threshold; ///< Risk score threshold for blocking 0-100 (default: 70)

@ -1,35 +1,34 @@
/**
* @file nl2sql_converter.h
* @brief Natural Language to SQL Converter for ProxySQL
* @file llm_bridge.h
* @brief Generic LLM Bridge for ProxySQL
*
* The NL2SQL_Converter class provides natural language to SQL conversion
* The LLM_Bridge class provides a generic interface to Large Language Models
* using multiple LLM providers with hybrid deployment and vector-based
* semantic caching.
*
* Key Features:
* - Multi-provider LLM support (local + generic cloud)
* - Semantic similarity caching using sqlite-vec
* - Schema-aware conversion
* - Generic prompt handling (not SQL-specific)
* - Configurable model selection based on latency/budget
* - Generic provider support (OpenAI-compatible, Anthropic-compatible)
*
* @date 2025-01-16
* @version 0.2.0
* @date 2025-01-17
* @version 1.0.0
*
* Example Usage:
* @code
* NL2SQLRequest req;
* req.natural_language = "Show top 10 customers";
* req.schema_name = "sales";
* NL2SQLResult result = converter->convert(req);
* std::cout << result.sql_query << std::endl;
* LLMRequest req;
* req.prompt = "Summarize this data...";
* LLMResult result = bridge->process(req);
* std::cout << result.text_response << std::endl;
* @endcode
*/
#ifndef __CLASS_NL2SQL_CONVERTER_H
#define __CLASS_NL2SQL_CONVERTER_H
#ifndef __CLASS_LLM_BRIDGE_H
#define __CLASS_LLM_BRIDGE_H
#define NL2SQL_CONVERTER_VERSION "0.2.0"
#define LLM_BRIDGE_VERSION "1.0.0"
#include "proxysql.h"
#include <string>
@ -39,73 +38,65 @@
class SQLite3DB;
/**
* @brief Result structure for NL2SQL conversion
* @brief Result structure for LLM bridge processing
*
* Contains the generated SQL query along with metadata including
* confidence score, explanation, cache status, and error details.
*
* @note The confidence score is a heuristic based on SQL validation
* and LLM response quality. Actual SQL correctness should be
* verified before execution.
* Contains the LLM text response along with metadata including
* cache status, error details, and performance timing.
*
* @note When errors occur, error_code, error_details, and http_status_code
* provide diagnostic information for troubleshooting.
*/
struct NL2SQLResult {
std::string sql_query; ///< Generated SQL query
float confidence; ///< Confidence score 0.0-1.0
std::string explanation; ///< Which model generated this
std::vector<std::string> tables_used; ///< Tables referenced in SQL
bool cached; ///< True if from semantic cache
int64_t cache_id; ///< Cache entry ID for tracking
struct LLMResult {
std::string text_response; ///< LLM-generated text response
std::string explanation; ///< Which model generated this
bool cached; ///< True if from semantic cache
int64_t cache_id; ///< Cache entry ID for tracking
// Error details - populated when conversion fails
std::string error_code; ///< Structured error code (e.g., "ERR_API_KEY_MISSING")
std::string error_details; ///< Detailed error context with query, provider, URL
int http_status_code; ///< HTTP status code if applicable (0 if N/A)
std::string provider_used; ///< Which provider was attempted
// Error details - populated when processing fails
std::string error_code; ///< Structured error code (e.g., "ERR_API_KEY_MISSING")
std::string error_details; ///< Detailed error context with query, provider, URL
int http_status_code; ///< HTTP status code if applicable (0 if N/A)
std::string provider_used; ///< Which provider was attempted
// Performance timing information
int total_time_ms; ///< Total conversion time in milliseconds
int cache_lookup_time_ms; ///< Cache lookup time in milliseconds
int cache_store_time_ms; ///< Cache store time in milliseconds
int llm_call_time_ms; ///< LLM call time in milliseconds
bool cache_hit; ///< True if cache was hit
int total_time_ms; ///< Total processing time in milliseconds
int cache_lookup_time_ms; ///< Cache lookup time in milliseconds
int cache_store_time_ms; ///< Cache store time in milliseconds
int llm_call_time_ms; ///< LLM call time in milliseconds
bool cache_hit; ///< True if cache was hit
NL2SQLResult() : confidence(0.0f), cached(false), cache_id(0), http_status_code(0),
total_time_ms(0), cache_lookup_time_ms(0), cache_store_time_ms(0),
llm_call_time_ms(0), cache_hit(false) {}
LLMResult() : cached(false), cache_id(0), http_status_code(0),
total_time_ms(0), cache_lookup_time_ms(0), cache_store_time_ms(0),
llm_call_time_ms(0), cache_hit(false) {}
};
/**
* @brief Request structure for NL2SQL conversion
* @brief Request structure for LLM bridge processing
*
* Contains the natural language query and context for conversion.
* Context includes schema name and optional table list for better
* SQL generation.
* Contains the prompt text and context for LLM processing.
*
* @note If max_latency_ms is set and < 500ms, the system will prefer
* local Ollama regardless of provider preference.
*/
struct NL2SQLRequest {
std::string natural_language; ///< Natural language query text
std::string schema_name; ///< Current database/schema name
int max_latency_ms; ///< Max acceptable latency (ms)
bool allow_cache; ///< Enable semantic cache lookup
std::vector<std::string> context_tables; ///< Optional table hints for schema
struct LLMRequest {
std::string prompt; ///< Prompt text for LLM
std::string system_message; ///< Optional system role message
std::string schema_name; ///< Optional schema/database context
int max_latency_ms; ///< Max acceptable latency (ms)
bool allow_cache; ///< Enable semantic cache lookup
// Request tracking for correlation and debugging
std::string request_id; ///< Unique ID for this request (UUID-like)
std::string request_id; ///< Unique ID for this request (UUID-like)
// Retry configuration for transient failures
int max_retries; ///< Maximum retry attempts (default: 3)
int retry_backoff_ms; ///< Initial backoff in ms (default: 1000)
double retry_multiplier; ///< Backoff multiplier (default: 2.0)
int retry_max_backoff_ms; ///< Maximum backoff in ms (default: 30000)
int max_retries; ///< Maximum retry attempts (default: 3)
int retry_backoff_ms; ///< Initial backoff in ms (default: 1000)
double retry_multiplier; ///< Backoff multiplier (default: 2.0)
int retry_max_backoff_ms; ///< Maximum backoff in ms (default: 30000)
NL2SQLRequest() : max_latency_ms(0), allow_cache(true),
max_retries(3), retry_backoff_ms(1000),
retry_multiplier(2.0), retry_max_backoff_ms(30000) {
LLMRequest() : max_latency_ms(0), allow_cache(true),
max_retries(3), retry_backoff_ms(1000),
retry_multiplier(2.0), retry_max_backoff_ms(30000) {
// Generate UUID-like request ID for correlation
char uuid[64];
snprintf(uuid, sizeof(uuid), "%08lx-%04x-%04x-%04x-%012lx",
@ -117,7 +108,7 @@ struct NL2SQLRequest {
};
/**
* @brief Error codes for NL2SQL conversion
* @brief Error codes for LLM bridge processing
*
* Structured error codes that provide machine-readable error information
* for programmatic handling and user-friendly error messages.
@ -127,9 +118,9 @@ struct NL2SQLRequest {
* - Logging and monitoring
* - User error messages
*
* @see nl2sql_error_code_to_string()
* @see llm_error_code_to_string()
*/
enum class NL2SQLErrorCode {
enum class LLMErrorCode {
SUCCESS = 0, ///< No error
ERR_API_KEY_MISSING, ///< API key not configured
ERR_API_KEY_INVALID, ///< API key format is invalid
@ -139,7 +130,6 @@ enum class NL2SQLErrorCode {
ERR_SERVER_ERROR, ///< Server error (HTTP 5xx)
ERR_EMPTY_RESPONSE, ///< Empty response from LLM
ERR_INVALID_RESPONSE, ///< Malformed response from LLM
ERR_SQL_INJECTION_DETECTED, ///< SQL injection pattern detected
ERR_VALIDATION_FAILED, ///< Input validation failed
ERR_UNKNOWN_PROVIDER, ///< Invalid provider name
ERR_REQUEST_TOO_LARGE ///< Request exceeds size limit
@ -154,10 +144,10 @@ enum class NL2SQLErrorCode {
* @param code The error code to convert
* @return String representation of the error code
*/
const char* nl2sql_error_code_to_string(NL2SQLErrorCode code);
const char* llm_error_code_to_string(LLMErrorCode code);
/**
* @brief Model provider format types for NL2SQL conversion
* @brief Model provider format types for LLM bridge
*
* Defines the API format to use for generic providers:
* - GENERIC_OPENAI: Any OpenAI-compatible endpoint (including Ollama)
@ -176,34 +166,33 @@ enum class ModelProvider {
};
/**
* @brief NL2SQL Converter class
* @brief Generic LLM Bridge class
*
* Converts natural language queries to SQL using LLMs with hybrid
* local/cloud model support and vector cache.
* Processes prompts using LLMs with hybrid local/cloud model support
* and vector cache.
*
* Architecture:
* - Vector cache for semantic similarity (sqlite-vec)
* - Model selection based on latency/budget
* - Generic HTTP client (libcurl) supporting multiple API formats
* - Schema-aware prompt building
* - Generic prompt handling (not tied to SQL)
*
* Configuration Variables:
* - ai_nl2sql_provider: "ollama", "openai", or "anthropic"
* - ai_nl2sql_provider_url: Custom endpoint URL (for generic providers)
* - ai_nl2sql_provider_model: Model name
* - ai_nl2sql_provider_key: API key (optional for local)
* - genai_llm_provider: "ollama", "openai", or "anthropic"
* - genai_llm_provider_url: Custom endpoint URL (for generic providers)
* - genai_llm_provider_model: Model name
* - genai_llm_provider_key: API key (optional for local)
*
* Thread Safety:
* - This class is NOT thread-safe by itself
* - External locking must be provided by AI_Features_Manager
*
* @see AI_Features_Manager, NL2SQLRequest, NL2SQLResult
* @see AI_Features_Manager, LLMRequest, LLMResult
*/
class NL2SQL_Converter {
class LLM_Bridge {
private:
struct {
bool enabled;
char* query_prefix;
char* provider; ///< "openai" or "anthropic"
char* provider_url; ///< Generic endpoint URL
char* provider_model; ///< Model name
@ -215,7 +204,7 @@ private:
SQLite3DB* vector_db;
// Internal methods
std::string build_prompt(const NL2SQLRequest& req, const std::string& schema_context);
std::string build_prompt(const LLMRequest& req);
std::string call_generic_openai(const std::string& prompt, const std::string& model,
const std::string& url, const char* key,
const std::string& req_id = "");
@ -233,34 +222,31 @@ private:
const std::string& req_id,
int max_retries, int initial_backoff_ms,
double backoff_multiplier, int max_backoff_ms);
NL2SQLResult check_vector_cache(const NL2SQLRequest& req);
void store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result);
std::string get_schema_context(const std::vector<std::string>& tables);
ModelProvider select_model(const NL2SQLRequest& req);
std::vector<float> get_query_embedding(const std::string& text);
float validate_and_score_sql(const std::string& sql);
LLMResult check_cache(const LLMRequest& req);
void store_in_cache(const LLMRequest& req, const LLMResult& result);
ModelProvider select_model(const LLMRequest& req);
std::vector<float> get_text_embedding(const std::string& text);
public:
/**
* @brief Constructor - initializes with default configuration
*
* Sets up default values:
* - query_prefix: "NL2SQL:"
* - provider: "openai"
* - provider_url: "http://localhost:11434/v1/chat/completions" (Ollama default)
* - provider_model: "llama3.2"
* - cache_similarity_threshold: 85
* - timeout_ms: 30000
*/
NL2SQL_Converter();
LLM_Bridge();
/**
* @brief Destructor - frees allocated resources
*/
~NL2SQL_Converter();
~LLM_Bridge();
/**
* @brief Initialize the NL2SQL converter
* @brief Initialize the LLM bridge
*
* Initializes vector DB connection and validates configuration.
* The vector_db will be provided by AI_Features_Manager.
@ -270,7 +256,7 @@ public:
int init();
/**
* @brief Shutdown the NL2SQL converter
* @brief Shutdown the LLM bridge
*
* Closes vector DB connection and cleans up resources.
*/
@ -296,45 +282,38 @@ public:
const char* provider_key, int cache_threshold, int timeout);
/**
* @brief Convert natural language query to SQL
* @brief Process a prompt using the LLM
*
* This is the main entry point for NL2SQL conversion. The flow is:
* 1. Check vector cache for semantically similar queries
* 2. Build prompt with schema context
* This is the main entry point for LLM bridge processing. The flow is:
* 1. Check vector cache for semantically similar prompts
* 2. Build prompt with optional system message
* 3. Select appropriate model (Ollama or generic provider)
* 4. Call LLM API
* 5. Parse and clean SQL response
* 5. Parse response
* 6. Store in vector cache for future use
*
* @param req NL2SQL request containing natural language query and context
* @return NL2SQLResult with generated SQL, confidence score, and metadata
* @param req LLM request containing prompt and context
* @return LLMResult with text response and metadata
*
* @note This is a synchronous blocking call. For non-blocking behavior,
* use the async interface via MySQL_Session.
*
* @note The confidence score is heuristic-based. Actual SQL correctness
* should be verified before execution.
*
* @see NL2SQLRequest, NL2SQLResult, ModelProvider
*
* Example:
* @code
* NL2SQLRequest req;
* req.natural_language = "Find customers with orders > $1000";
* LLMRequest req;
* req.prompt = "Explain this query: SELECT * FROM users";
* req.allow_cache = true;
* NL2SQLResult result = converter.convert(req);
* if (result.confidence > 0.7f) {
* execute_sql(result.sql_query);
* }
* LLMResult result = bridge.process(req);
* std::cout << result.text_response << std::endl;
* @endcode
*/
NL2SQLResult convert(const NL2SQLRequest& req);
LLMResult process(const LLMRequest& req);
/**
* @brief Clear the vector cache
*
* Removes all cached NL2SQL conversions from the vector database.
* This is useful for testing or when schema changes significantly.
* Removes all cached LLM responses from the vector database.
* This is useful for testing or when context changes significantly.
*/
void clear_cache();
@ -342,7 +321,7 @@ public:
* @brief Get cache statistics
*
* Returns JSON string with cache metrics:
* - entries: Total number of cached conversions
* - entries: Total number of cached responses
* - hits: Number of cache hits
* - misses: Number of cache misses
*
@ -351,7 +330,4 @@ public:
std::string get_cache_stats();
};
// Global instance (defined by AI_Features_Manager)
// extern NL2SQL_Converter *GloNL2SQL;
#endif // __CLASS_NL2SQL_CONVERTER_H
#endif // __CLASS_LLM_BRIDGE_H

@ -284,7 +284,7 @@ class MySQL_Session: public Base_Session<MySQL_Session, MySQL_Data_Stream, MySQL
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_INIT_DB_replace_CLICKHOUSE(PtrSize_t& pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___not_mysql(PtrSize_t& pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai(const char* query, size_t query_len, PtrSize_t* pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(const char* query, size_t query_len, PtrSize_t* pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___llm(const char* query, size_t query_len, PtrSize_t* pkt);
#ifdef epoll_create1
/**
* @brief Handle GenAI response from socketpair

@ -63,7 +63,7 @@
// AI Features includes
#include "AI_Features_Manager.h"
#include "NL2SQL_Converter.h"
#include "LLM_Bridge.h"
#include "Anomaly_Detector.h"
#include "AI_Vector_Storage.h"

@ -1,6 +1,6 @@
#include "AI_Features_Manager.h"
#include "GenAI_Thread.h"
#include "NL2SQL_Converter.h"
#include "LLM_Bridge.h"
#include "Anomaly_Detector.h"
#include "sqlite3db.h"
#include "proxysql_utils.h"
@ -20,7 +20,7 @@ class ProxySQL_Admin;
extern ProxySQL_Admin *GloAdmin;
AI_Features_Manager::AI_Features_Manager()
: shutdown_(0), nl2sql_converter(NULL), anomaly_detector(NULL), vector_db(NULL)
: shutdown_(0), llm_bridge(NULL), anomaly_detector(NULL), vector_db(NULL)
{
pthread_rwlock_init(&rwlock, NULL);
@ -69,21 +69,21 @@ int AI_Features_Manager::init_vector_db() {
return -1;
}
// Create tables for NL2SQL cache
const char* create_nl2sql_cache =
"CREATE TABLE IF NOT EXISTS nl2sql_cache ("
// Create tables for LLM cache
const char* create_llm_cache =
"CREATE TABLE IF NOT EXISTS llm_cache ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"natural_language TEXT NOT NULL,"
"generated_sql TEXT NOT NULL,"
"schema_context TEXT,"
"prompt TEXT NOT NULL,"
"response TEXT NOT NULL,"
"system_message TEXT,"
"embedding BLOB,"
"hit_count INTEGER DEFAULT 0,"
"last_hit INTEGER,"
"created_at INTEGER DEFAULT (strftime('%s', 'now'))"
");";
if (vector_db->execute(create_nl2sql_cache) != 0) {
proxy_error("AI: Failed to create nl2sql_cache table\n");
if (vector_db->execute(create_llm_cache) != 0) {
proxy_error("AI: Failed to create llm_cache table\n");
return -1;
}
@ -108,8 +108,8 @@ int AI_Features_Manager::init_vector_db() {
const char* create_query_history =
"CREATE TABLE IF NOT EXISTS query_history ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"query_text TEXT NOT NULL,"
"generated_sql TEXT,"
"prompt TEXT NOT NULL,"
"response TEXT,"
"embedding BLOB,"
"execution_time_ms INTEGER,"
"success BOOLEAN,"
@ -124,16 +124,16 @@ int AI_Features_Manager::init_vector_db() {
// Create virtual vector tables for similarity search using sqlite-vec
// Note: sqlite-vec extension is auto-loaded in Admin_Bootstrap.cpp:612
// 1. NL2SQL cache virtual table
const char* create_nl2sql_vec =
"CREATE VIRTUAL TABLE IF NOT EXISTS nl2sql_cache_vec USING vec0("
// 1. LLM cache virtual table
const char* create_llm_vec =
"CREATE VIRTUAL TABLE IF NOT EXISTS llm_cache_vec USING vec0("
"embedding float(1536)"
");";
if (vector_db->execute(create_nl2sql_vec) != 0) {
proxy_error("AI: Failed to create nl2sql_cache_vec virtual table\n");
if (vector_db->execute(create_llm_vec) != 0) {
proxy_error("AI: Failed to create llm_cache_vec virtual table\n");
// Virtual table creation failure is not critical - log and continue
proxy_debug(PROXY_DEBUG_GENAI, 3, "Continuing without nl2sql_cache_vec");
proxy_debug(PROXY_DEBUG_GENAI, 3, "Continuing without llm_cache_vec");
}
// 2. Anomaly patterns virtual table
@ -162,37 +162,37 @@ int AI_Features_Manager::init_vector_db() {
return 0;
}
int AI_Features_Manager::init_nl2sql() {
if (!GloGATH->variables.genai_nl2sql_enabled) {
proxy_info("AI: NL2SQL disabled, skipping initialization\n");
int AI_Features_Manager::init_llm_bridge() {
if (!GloGATH->variables.genai_llm_enabled) {
proxy_info("AI: LLM bridge disabled, skipping initialization\n");
return 0;
}
proxy_info("AI: Initializing NL2SQL Converter\n");
proxy_info("AI: Initializing LLM Bridge\n");
nl2sql_converter = new NL2SQL_Converter();
llm_bridge = new LLM_Bridge();
// Set vector database
nl2sql_converter->set_vector_db(vector_db);
llm_bridge->set_vector_db(vector_db);
// Update config with current variables from GenAI module
nl2sql_converter->update_config(
GloGATH->variables.genai_nl2sql_provider,
GloGATH->variables.genai_nl2sql_provider_url,
GloGATH->variables.genai_nl2sql_provider_model,
GloGATH->variables.genai_nl2sql_provider_key,
GloGATH->variables.genai_nl2sql_cache_similarity_threshold,
GloGATH->variables.genai_nl2sql_timeout_ms
llm_bridge->update_config(
GloGATH->variables.genai_llm_provider,
GloGATH->variables.genai_llm_provider_url,
GloGATH->variables.genai_llm_provider_model,
GloGATH->variables.genai_llm_provider_key,
GloGATH->variables.genai_llm_cache_similarity_threshold,
GloGATH->variables.genai_llm_timeout_ms
);
if (nl2sql_converter->init() != 0) {
proxy_error("AI: Failed to initialize NL2SQL Converter\n");
delete nl2sql_converter;
nl2sql_converter = NULL;
if (llm_bridge->init() != 0) {
proxy_error("AI: Failed to initialize LLM Bridge\n");
delete llm_bridge;
llm_bridge = NULL;
return -1;
}
proxy_info("AI: NL2SQL Converter initialized\n");
proxy_info("AI: LLM Bridge initialized\n");
return 0;
}
@ -223,11 +223,11 @@ void AI_Features_Manager::close_vector_db() {
}
}
void AI_Features_Manager::close_nl2sql() {
if (nl2sql_converter) {
nl2sql_converter->close();
delete nl2sql_converter;
nl2sql_converter = NULL;
void AI_Features_Manager::close_llm_bridge() {
if (llm_bridge) {
llm_bridge->close();
delete llm_bridge;
llm_bridge = NULL;
}
}
@ -247,15 +247,15 @@ int AI_Features_Manager::init() {
return 0;
}
// Initialize vector storage first (needed by both NL2SQL and Anomaly Detector)
// Initialize vector storage first (needed by both LLM bridge and Anomaly Detector)
if (init_vector_db() != 0) {
proxy_error("AI: Failed to initialize vector storage\n");
return -1;
}
// Initialize NL2SQL
if (init_nl2sql() != 0) {
proxy_error("AI: Failed to initialize NL2SQL\n");
// Initialize LLM bridge
if (init_llm_bridge() != 0) {
proxy_error("AI: Failed to initialize LLM bridge\n");
return -1;
}
@ -275,7 +275,7 @@ void AI_Features_Manager::shutdown() {
proxy_info("AI: Shutting down AI Features Manager\n");
close_nl2sql();
close_llm_bridge();
close_anomaly_detector();
close_vector_db();
@ -299,7 +299,7 @@ std::string AI_Features_Manager::get_status_json() {
snprintf(buf, sizeof(buf),
"{"
"\"version\": \"%s\","
"\"nl2sql\": {"
"\"llm\": {"
"\"total_requests\": %llu,"
"\"cache_hits\": %llu,"
"\"local_calls\": %llu,"
@ -321,16 +321,16 @@ std::string AI_Features_Manager::get_status_json() {
"}"
"}",
AI_FEATURES_MANAGER_VERSION,
status_variables.nl2sql_total_requests,
status_variables.nl2sql_cache_hits,
status_variables.nl2sql_local_model_calls,
status_variables.nl2sql_cloud_model_calls,
status_variables.nl2sql_total_response_time_ms,
status_variables.nl2sql_cache_total_lookup_time_ms,
status_variables.nl2sql_cache_total_store_time_ms,
status_variables.nl2sql_cache_lookups,
status_variables.nl2sql_cache_stores,
status_variables.nl2sql_cache_misses,
status_variables.llm_total_requests,
status_variables.llm_cache_hits,
status_variables.llm_local_model_calls,
status_variables.llm_cloud_model_calls,
status_variables.llm_total_response_time_ms,
status_variables.llm_cache_total_lookup_time_ms,
status_variables.llm_cache_total_store_time_ms,
status_variables.llm_cache_lookups,
status_variables.llm_cache_stores,
status_variables.llm_cache_misses,
status_variables.anomaly_total_checks,
status_variables.anomaly_blocked_queries,
status_variables.anomaly_flagged_queries,

@ -9,7 +9,7 @@
*/
#include "AI_Tool_Handler.h"
#include "NL2SQL_Converter.h"
#include "LLM_Bridge.h"
#include "Anomaly_Detector.h"
#include "AI_Features_Manager.h"
#include "proxysql_debug.h"
@ -29,8 +29,8 @@ using json = nlohmann::json;
/**
* @brief Constructor using existing AI components
*/
AI_Tool_Handler::AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* anomaly)
: nl2sql_converter(nl2sql),
AI_Tool_Handler::AI_Tool_Handler(LLM_Bridge* llm, Anomaly_Detector* anomaly)
: llm_bridge(llm),
anomaly_detector(anomaly),
owns_components(false)
{
@ -42,13 +42,13 @@ AI_Tool_Handler::AI_Tool_Handler(NL2SQL_Converter* nl2sql, Anomaly_Detector* ano
* Note: This implementation uses global instances
*/
AI_Tool_Handler::AI_Tool_Handler()
: nl2sql_converter(NULL),
: llm_bridge(NULL),
anomaly_detector(NULL),
owns_components(false)
{
// Use global instances from AI_Features_Manager
if (GloAI) {
nl2sql_converter = GloAI->get_nl2sql();
llm_bridge = GloAI->get_llm_bridge();
anomaly_detector = GloAI->get_anomaly_detector();
}
proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler created (using global instances)\n");
@ -70,8 +70,8 @@ AI_Tool_Handler::~AI_Tool_Handler() {
* @brief Initialize the tool handler
*/
int AI_Tool_Handler::init() {
if (!nl2sql_converter) {
proxy_error("AI_Tool_Handler: NL2SQL converter not available\n");
if (!llm_bridge) {
proxy_error("AI_Tool_Handler: LLM bridge not available\n");
return -1;
}
proxy_info("AI_Tool_Handler initialized\n");
@ -199,73 +199,13 @@ json AI_Tool_Handler::execute_tool(const std::string& tool_name, const json& arg
proxy_debug(PROXY_DEBUG_GENAI, 3, "AI_Tool_Handler: execute_tool(%s)\n", tool_name.c_str());
try {
// NL2SQL conversion tool
// LLM processing tool (generic, replaces NL2SQL)
if (tool_name == "ai_nl2sql_convert") {
if (!nl2sql_converter) {
return create_error_response("NL2SQL converter not available");
}
// Extract parameters
std::string natural_language = get_json_string(arguments, "natural_language");
if (natural_language.empty()) {
return create_error_response("Missing required parameter: natural_language");
}
std::string schema = get_json_string(arguments, "schema");
int max_latency_ms = get_json_int(arguments, "max_latency_ms", 0);
bool allow_cache = true;
if (arguments.contains("allow_cache") && !arguments["allow_cache"].is_null()) {
if (arguments["allow_cache"].is_boolean()) {
allow_cache = arguments["allow_cache"].get<bool>();
} else if (arguments["allow_cache"].is_string()) {
std::string val = arguments["allow_cache"].get<std::string>();
allow_cache = (val == "true" || val == "1");
}
}
// Parse context_tables
std::vector<std::string> context_tables;
std::string tables_str = get_json_string(arguments, "context_tables");
if (!tables_str.empty()) {
std::istringstream ts(tables_str);
std::string table;
while (std::getline(ts, table, ',')) {
table.erase(0, table.find_first_not_of(" \t"));
table.erase(table.find_last_not_of(" \t") + 1);
if (!table.empty()) {
context_tables.push_back(table);
}
}
}
// Create NL2SQL request
NL2SQLRequest req;
req.natural_language = natural_language;
req.schema_name = schema;
req.max_latency_ms = max_latency_ms;
req.allow_cache = allow_cache;
req.context_tables = context_tables;
// Call NL2SQL converter
NL2SQLResult result = nl2sql_converter->convert(req);
// Build response
json response_data;
response_data["sql_query"] = result.sql_query;
response_data["confidence"] = result.confidence;
response_data["explanation"] = result.explanation;
response_data["cached"] = result.cached;
response_data["cache_id"] = result.cache_id;
// Add tables used if available
if (!result.tables_used.empty()) {
response_data["tables_used"] = result.tables_used;
}
proxy_info("AI_Tool_Handler: NL2SQL conversion complete. SQL: %s, Confidence: %.2f\n",
result.sql_query.c_str(), result.confidence);
return create_success_response(response_data);
// NOTE: The ai_nl2sql_convert tool is deprecated.
// NL2SQL functionality has been replaced with a generic LLM bridge.
// Future NL2SQL will be implemented as a Web UI using external agents (Claude Code + MCP server).
return create_error_response("The ai_nl2sql_convert tool is deprecated. "
"Use the generic LLM: queries via MySQL protocol instead.");
}
// Unknown tool

@ -1079,11 +1079,11 @@ void ProxySQL_Admin::flush_genai_variables___database_to_runtime(SQLite3DB* db,
pthread_mutex_unlock(&GloVars.checksum_mutex);
}
// Check if NL2SQL needs to be initialized
if (GloAI && GloGATH->variables.genai_nl2sql_enabled && !GloAI->get_nl2sql()) {
proxy_info("NL2SQL enabled but not initialized, initializing now\n");
if (GloAI->init_nl2sql() != 0) {
proxy_error("Failed to initialize NL2SQL converter\n");
// Check if LLM bridge needs to be initialized
if (GloAI && GloGATH->variables.genai_llm_enabled && !GloAI->get_llm_bridge()) {
proxy_info("LLM bridge enabled but not initialized, initializing now\n");
if (GloAI->init_llm_bridge() != 0) {
proxy_error("Failed to initialize LLM bridge\n");
}
}

@ -45,17 +45,17 @@ static const char* genai_thread_variables_names[] = {
// AI Features master switches
"enabled",
"nl2sql_enabled",
"llm_enabled",
"anomaly_enabled",
// NL2SQL configuration
"nl2sql_query_prefix",
"nl2sql_provider",
"nl2sql_provider_url",
"nl2sql_provider_model",
"nl2sql_provider_key",
"nl2sql_cache_similarity_threshold",
"nl2sql_timeout_ms",
// LLM bridge configuration
"llm_provider",
"llm_provider_url",
"llm_provider_model",
"llm_provider_key",
"llm_cache_similarity_threshold",
"llm_cache_enabled",
"llm_timeout_ms",
// Anomaly detection configuration
"anomaly_risk_threshold",
@ -153,17 +153,17 @@ GenAI_Threads_Handler::GenAI_Threads_Handler() {
// AI Features master switches
variables.genai_enabled = false;
variables.genai_nl2sql_enabled = false;
variables.genai_llm_enabled = false;
variables.genai_anomaly_enabled = false;
// NL2SQL configuration
variables.genai_nl2sql_query_prefix = strdup("NL2SQL:");
variables.genai_nl2sql_provider = strdup("openai");
variables.genai_nl2sql_provider_url = strdup("http://localhost:11434/v1/chat/completions");
variables.genai_nl2sql_provider_model = strdup("llama3.2");
variables.genai_nl2sql_provider_key = NULL;
variables.genai_nl2sql_cache_similarity_threshold = 85;
variables.genai_nl2sql_timeout_ms = 30000;
// LLM bridge configuration
variables.genai_llm_provider = strdup("openai");
variables.genai_llm_provider_url = strdup("http://localhost:11434/v1/chat/completions");
variables.genai_llm_provider_model = strdup("llama3.2");
variables.genai_llm_provider_key = NULL;
variables.genai_llm_cache_similarity_threshold = 85;
variables.genai_llm_cache_enabled = true;
variables.genai_llm_timeout_ms = 30000;
// Anomaly detection configuration
variables.genai_anomaly_risk_threshold = 70;
@ -197,17 +197,15 @@ GenAI_Threads_Handler::~GenAI_Threads_Handler() {
if (variables.genai_rerank_uri)
free(variables.genai_rerank_uri);
// Free NL2SQL string variables
if (variables.genai_nl2sql_query_prefix)
free(variables.genai_nl2sql_query_prefix);
if (variables.genai_nl2sql_provider)
free(variables.genai_nl2sql_provider);
if (variables.genai_nl2sql_provider_url)
free(variables.genai_nl2sql_provider_url);
if (variables.genai_nl2sql_provider_model)
free(variables.genai_nl2sql_provider_model);
if (variables.genai_nl2sql_provider_key)
free(variables.genai_nl2sql_provider_key);
// Free LLM bridge string variables
if (variables.genai_llm_provider)
free(variables.genai_llm_provider);
if (variables.genai_llm_provider_url)
free(variables.genai_llm_provider_url);
if (variables.genai_llm_provider_model)
free(variables.genai_llm_provider_model);
if (variables.genai_llm_provider_key)
free(variables.genai_llm_provider_key);
// Free vector storage string variables
if (variables.genai_vector_db_path)
@ -377,37 +375,34 @@ char* GenAI_Threads_Handler::get_variable(char* name) {
if (!strcmp(name, "enabled")) {
return strdup(variables.genai_enabled ? "true" : "false");
}
if (!strcmp(name, "nl2sql_enabled")) {
return strdup(variables.genai_nl2sql_enabled ? "true" : "false");
if (!strcmp(name, "llm_enabled")) {
return strdup(variables.genai_llm_enabled ? "true" : "false");
}
if (!strcmp(name, "anomaly_enabled")) {
return strdup(variables.genai_anomaly_enabled ? "true" : "false");
}
// NL2SQL configuration
if (!strcmp(name, "nl2sql_query_prefix")) {
return strdup(variables.genai_nl2sql_query_prefix ? variables.genai_nl2sql_query_prefix : "");
// LLM configuration
if (!strcmp(name, "llm_provider")) {
return strdup(variables.genai_llm_provider ? variables.genai_llm_provider : "");
}
if (!strcmp(name, "nl2sql_provider")) {
return strdup(variables.genai_nl2sql_provider ? variables.genai_nl2sql_provider : "");
if (!strcmp(name, "llm_provider_url")) {
return strdup(variables.genai_llm_provider_url ? variables.genai_llm_provider_url : "");
}
if (!strcmp(name, "nl2sql_provider_url")) {
return strdup(variables.genai_nl2sql_provider_url ? variables.genai_nl2sql_provider_url : "");
if (!strcmp(name, "llm_provider_model")) {
return strdup(variables.genai_llm_provider_model ? variables.genai_llm_provider_model : "");
}
if (!strcmp(name, "nl2sql_provider_model")) {
return strdup(variables.genai_nl2sql_provider_model ? variables.genai_nl2sql_provider_model : "");
if (!strcmp(name, "llm_provider_key")) {
return strdup(variables.genai_llm_provider_key ? variables.genai_llm_provider_key : "");
}
if (!strcmp(name, "nl2sql_provider_key")) {
return strdup(variables.genai_nl2sql_provider_key ? variables.genai_nl2sql_provider_key : "");
}
if (!strcmp(name, "nl2sql_cache_similarity_threshold")) {
if (!strcmp(name, "llm_cache_similarity_threshold")) {
char buf[64];
sprintf(buf, "%d", variables.genai_nl2sql_cache_similarity_threshold);
sprintf(buf, "%d", variables.genai_llm_cache_similarity_threshold);
return strdup(buf);
}
if (!strcmp(name, "nl2sql_timeout_ms")) {
if (!strcmp(name, "llm_timeout_ms")) {
char buf[64];
sprintf(buf, "%d", variables.genai_nl2sql_timeout_ms);
sprintf(buf, "%d", variables.genai_llm_timeout_ms);
return strdup(buf);
}
@ -512,8 +507,8 @@ bool GenAI_Threads_Handler::set_variable(char* name, const char* value) {
variables.genai_enabled = (strcmp(value, "true") == 0);
return true;
}
if (!strcmp(name, "nl2sql_enabled")) {
variables.genai_nl2sql_enabled = (strcmp(value, "true") == 0);
if (!strcmp(name, "llm_enabled")) {
variables.genai_llm_enabled = (strcmp(value, "true") == 0);
return true;
}
if (!strcmp(name, "anomaly_enabled")) {
@ -521,53 +516,47 @@ bool GenAI_Threads_Handler::set_variable(char* name, const char* value) {
return true;
}
// NL2SQL configuration
if (!strcmp(name, "nl2sql_query_prefix")) {
if (variables.genai_nl2sql_query_prefix)
free(variables.genai_nl2sql_query_prefix);
variables.genai_nl2sql_query_prefix = strdup(value);
return true;
}
if (!strcmp(name, "nl2sql_provider")) {
if (variables.genai_nl2sql_provider)
free(variables.genai_nl2sql_provider);
variables.genai_nl2sql_provider = strdup(value);
// LLM configuration
if (!strcmp(name, "llm_provider")) {
if (variables.genai_llm_provider)
free(variables.genai_llm_provider);
variables.genai_llm_provider = strdup(value);
return true;
}
if (!strcmp(name, "nl2sql_provider_url")) {
if (variables.genai_nl2sql_provider_url)
free(variables.genai_nl2sql_provider_url);
variables.genai_nl2sql_provider_url = strdup(value);
if (!strcmp(name, "llm_provider_url")) {
if (variables.genai_llm_provider_url)
free(variables.genai_llm_provider_url);
variables.genai_llm_provider_url = strdup(value);
return true;
}
if (!strcmp(name, "nl2sql_provider_model")) {
if (variables.genai_nl2sql_provider_model)
free(variables.genai_nl2sql_provider_model);
variables.genai_nl2sql_provider_model = strdup(value);
if (!strcmp(name, "llm_provider_model")) {
if (variables.genai_llm_provider_model)
free(variables.genai_llm_provider_model);
variables.genai_llm_provider_model = strdup(value);
return true;
}
if (!strcmp(name, "nl2sql_provider_key")) {
if (variables.genai_nl2sql_provider_key)
free(variables.genai_nl2sql_provider_key);
variables.genai_nl2sql_provider_key = strdup(value);
if (!strcmp(name, "llm_provider_key")) {
if (variables.genai_llm_provider_key)
free(variables.genai_llm_provider_key);
variables.genai_llm_provider_key = strdup(value);
return true;
}
if (!strcmp(name, "nl2sql_cache_similarity_threshold")) {
if (!strcmp(name, "llm_cache_similarity_threshold")) {
int val = atoi(value);
if (val < 0 || val > 100) {
proxy_error("Invalid value for genai_nl2sql_cache_similarity_threshold: %d (must be 0-100)\n", val);
proxy_error("Invalid value for genai_llm_cache_similarity_threshold: %d (must be 0-100)\n", val);
return false;
}
variables.genai_nl2sql_cache_similarity_threshold = val;
variables.genai_llm_cache_similarity_threshold = val;
return true;
}
if (!strcmp(name, "nl2sql_timeout_ms")) {
if (!strcmp(name, "llm_timeout_ms")) {
int val = atoi(value);
if (val < 1000 || val > 600000) {
proxy_error("Invalid value for genai_nl2sql_timeout_ms: %d (must be 1000-600000)\n", val);
proxy_error("Invalid value for genai_llm_timeout_ms: %d (must be 1000-600000)\n", val);
return false;
}
variables.genai_nl2sql_timeout_ms = val;
variables.genai_llm_timeout_ms = val;
return true;
}
@ -1709,30 +1698,30 @@ std::string GenAI_Threads_Handler::process_json_query(const std::string& json_qu
return result.dump();
}
// Handle nl2sql operation
if (op_type == "nl2sql") {
// Handle llm operation
if (op_type == "llm") {
// Check if AI manager is available
if (!GloAI) {
result["error"] = "AI features manager is not initialized";
return result.dump();
}
// Extract natural language query
if (!query_json.contains("query") || !query_json["query"].is_string()) {
result["error"] = "NL2SQL operation requires a 'query' string";
// Extract prompt
if (!query_json.contains("prompt") || !query_json["prompt"].is_string()) {
result["error"] = "LLM operation requires a 'prompt' string";
return result.dump();
}
std::string nl_query = query_json["query"].get<std::string>();
std::string prompt = query_json["prompt"].get<std::string>();
if (nl_query.empty()) {
result["error"] = "NL2SQL query cannot be empty";
if (prompt.empty()) {
result["error"] = "LLM prompt cannot be empty";
return result.dump();
}
// Extract optional schema name
std::string schema_name;
if (query_json.contains("schema") && query_json["schema"].is_string()) {
schema_name = query_json["schema"].get<std::string>();
// Extract optional system message
std::string system_message;
if (query_json.contains("system_message") && query_json["system_message"].is_string()) {
system_message = query_json["system_message"].get<std::string>();
}
// Extract optional cache flag
@ -1741,41 +1730,37 @@ std::string GenAI_Threads_Handler::process_json_query(const std::string& json_qu
allow_cache = query_json["allow_cache"].get<bool>();
}
// Get NL2SQL converter
NL2SQL_Converter* nl2sql = GloAI->get_nl2sql();
if (!nl2sql) {
result["error"] = "NL2SQL converter is not initialized";
// Get LLM bridge
LLM_Bridge* llm_bridge = GloAI->get_llm_bridge();
if (!llm_bridge) {
result["error"] = "LLM bridge is not initialized";
return result.dump();
}
// Build NL2SQL request
NL2SQLRequest req;
req.natural_language = nl_query;
req.schema_name = schema_name;
// Build LLM request
LLMRequest req;
req.prompt = prompt;
req.system_message = system_message;
req.allow_cache = allow_cache;
req.max_latency_ms = 0; // No specific latency requirement
// Convert (this will use cache if available)
NL2SQLResult sql_result = nl2sql->convert(req);
// Process (this will use cache if available)
LLMResult llm_result = llm_bridge->process(req);
if (sql_result.sql_query.empty() || sql_result.sql_query.find("NL2SQL conversion failed") == 0) {
result["error"] = "Failed to convert natural language to SQL: " + sql_result.explanation;
if (!llm_result.error_code.empty()) {
result["error"] = "LLM processing failed: " + llm_result.error_details;
return result.dump();
}
// Build result
result["columns"] = json::array({"sql_query", "confidence", "explanation", "cached"});
// Build result - return as single row with text_response
result["columns"] = json::array({"text_response", "explanation", "cached", "provider"});
json rows = json::array();
json row = json::array();
row.push_back(sql_result.sql_query);
char conf_buf[32];
snprintf(conf_buf, sizeof(conf_buf), "%.2f", sql_result.confidence);
row.push_back(std::string(conf_buf));
row.push_back(sql_result.explanation);
row.push_back(sql_result.cached ? "true" : "false");
row.push_back(llm_result.text_response);
row.push_back(llm_result.explanation);
row.push_back(llm_result.cached ? "true" : "false");
row.push_back(llm_result.provider_used);
rows.push_back(row);
result["rows"] = rows;
@ -1784,7 +1769,7 @@ std::string GenAI_Threads_Handler::process_json_query(const std::string& json_qu
}
// Unknown operation type
result["error"] = "Unknown operation type: " + op_type + ". Use 'embed', 'rerank', or 'nl2sql'";
result["error"] = "Unknown operation type: " + op_type + ". Use 'embed', 'rerank', or 'llm'";
return result.dump();
} catch (const json::parse_error& e) {

@ -0,0 +1,375 @@
/**
* @file LLM_Bridge.cpp
* @brief Implementation of Generic LLM Bridge
*
* This file implements the generic LLM bridge pipeline including:
* - Vector cache operations for semantic similarity
* - Model selection based on latency/budget
* - Generic LLM API calls (Ollama, OpenAI-compatible, Anthropic-compatible)
*
* @see LLM_Bridge.h
*/
#include "LLM_Bridge.h"
#include "sqlite3db.h"
#include "proxysql_utils.h"
#include "GenAI_Thread.h"
#include "cpp.h"
#include <cstring>
#include <cstdlib>
#include <sstream>
#include <algorithm>
#include <regex>
#include <chrono>
using json = nlohmann::json;
// Global GenAI handler for embedding generation
extern GenAI_Threads_Handler *GloGATH;
// Global AI Features Manager for status updates
extern AI_Features_Manager *GloAI;
// ============================================================================
// Error Handling Helper Functions
// ============================================================================
/**
* @brief Convert error code enum to string representation
*/
const char* llm_error_code_to_string(LLMErrorCode code) {
switch (code) {
case LLMErrorCode::SUCCESS: return "SUCCESS";
case LLMErrorCode::ERR_API_KEY_MISSING: return "ERR_API_KEY_MISSING";
case LLMErrorCode::ERR_API_KEY_INVALID: return "ERR_API_KEY_INVALID";
case LLMErrorCode::ERR_TIMEOUT: return "ERR_TIMEOUT";
case LLMErrorCode::ERR_CONNECTION_FAILED: return "ERR_CONNECTION_FAILED";
case LLMErrorCode::ERR_RATE_LIMITED: return "ERR_RATE_LIMITED";
case LLMErrorCode::ERR_SERVER_ERROR: return "ERR_SERVER_ERROR";
case LLMErrorCode::ERR_EMPTY_RESPONSE: return "ERR_EMPTY_RESPONSE";
case LLMErrorCode::ERR_INVALID_RESPONSE: return "ERR_INVALID_RESPONSE";
case LLMErrorCode::ERR_VALIDATION_FAILED: return "ERR_VALIDATION_FAILED";
case LLMErrorCode::ERR_UNKNOWN_PROVIDER: return "ERR_UNKNOWN_PROVIDER";
case LLMErrorCode::ERR_REQUEST_TOO_LARGE: return "ERR_REQUEST_TOO_LARGE";
default: return "UNKNOWN";
}
}
// Forward declarations of external functions from LLM_Clients.cpp
extern std::string call_generic_openai_with_retry(const std::string& prompt, const std::string& model,
const std::string& url, const char* key,
const std::string& req_id);
extern std::string call_generic_anthropic_with_retry(const std::string& prompt, const std::string& model,
const std::string& url, const char* key,
const std::string& req_id);
// ============================================================================
// LLM_Bridge Implementation
// ============================================================================
/**
* @brief Constructor - initializes with default configuration
*/
LLM_Bridge::LLM_Bridge()
: vector_db(nullptr)
{
// Set default configuration
config.enabled = false;
config.provider = strdup("openai");
config.provider_url = strdup("http://localhost:11434/v1/chat/completions");
config.provider_model = strdup("llama3.2");
config.provider_key = nullptr;
config.cache_similarity_threshold = 85;
config.timeout_ms = 30000;
proxy_debug(PROXY_DEBUG_GENAI, 3, "LLM_Bridge: Initialized with defaults\n");
}
/**
* @brief Destructor - frees allocated resources
*/
LLM_Bridge::~LLM_Bridge() {
if (config.provider) free(config.provider);
if (config.provider_url) free(config.provider_url);
if (config.provider_model) free(config.provider_model);
if (config.provider_key) free(config.provider_key);
proxy_debug(PROXY_DEBUG_GENAI, 3, "LLM_Bridge: Destroyed\n");
}
/**
* @brief Initialize the LLM bridge
*/
int LLM_Bridge::init() {
proxy_info("LLM_Bridge: Initialized successfully\n");
return 0;
}
/**
* @brief Shutdown the LLM bridge
*/
void LLM_Bridge::close() {
proxy_info("LLM_Bridge: Shutdown complete\n");
}
/**
* @brief Update configuration from AI_Features_Manager
*/
void LLM_Bridge::update_config(const char* provider, const char* provider_url, const char* provider_model,
const char* provider_key, int cache_threshold, int timeout) {
if (provider) {
if (config.provider) free(config.provider);
config.provider = strdup(provider);
}
if (provider_url) {
if (config.provider_url) free(config.provider_url);
config.provider_url = strdup(provider_url);
}
if (provider_model) {
if (config.provider_model) free(config.provider_model);
config.provider_model = strdup(provider_model);
}
if (provider_key) {
if (config.provider_key) free(config.provider_key);
config.provider_key = provider_key ? strdup(provider_key) : nullptr;
}
config.cache_similarity_threshold = cache_threshold;
config.timeout_ms = timeout;
proxy_debug(PROXY_DEBUG_GENAI, 3, "LLM_Bridge: Configuration updated\n");
}
/**
* @brief Build prompt from request
*/
std::string LLM_Bridge::build_prompt(const LLMRequest& req) {
std::string prompt = req.prompt;
// Add system message if provided
if (!req.system_message.empty()) {
// For most LLM APIs, the system message is handled separately
// This is a simplified implementation
}
return prompt;
}
/**
* @brief Check vector cache for similar prompts
*/
LLMResult LLM_Bridge::check_cache(const LLMRequest& req) {
LLMResult result;
result.cached = false;
result.cache_hit = false;
if (!vector_db || !req.allow_cache) {
return result;
}
auto start_time = std::chrono::high_resolution_clock::now();
// TODO: Implement vector similarity search
// This would involve:
// 1. Generate embedding for the prompt
// 2. Search vector database for similar prompts
// 3. If similarity >= threshold, return cached response
auto end_time = std::chrono::high_resolution_clock::now();
result.cache_lookup_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
return result;
}
/**
* @brief Store result in vector cache
*/
void LLM_Bridge::store_in_cache(const LLMRequest& req, const LLMResult& result) {
if (!vector_db || !req.allow_cache) {
return;
}
auto start_time = std::chrono::high_resolution_clock::now();
// TODO: Implement cache storage
// This would involve:
// 1. Generate embedding for the prompt
// 2. Store prompt embedding, response, and metadata in cache table
auto end_time = std::chrono::high_resolution_clock::now();
const_cast<LLMResult&>(result).cache_store_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
}
/**
* @brief Select appropriate model based on request
*/
ModelProvider LLM_Bridge::select_model(const LLMRequest& req) {
if (!config.provider) {
return ModelProvider::FALLBACK_ERROR;
}
if (strcmp(config.provider, "openai") == 0) {
return ModelProvider::GENERIC_OPENAI;
} else if (strcmp(config.provider, "anthropic") == 0) {
return ModelProvider::GENERIC_ANTHROPIC;
}
return ModelProvider::FALLBACK_ERROR;
}
/**
* @brief Get text embedding for vector cache
*/
std::vector<float> LLM_Bridge::get_text_embedding(const std::string& text) {
std::vector<float> embedding;
// Use GenAI module for embedding generation
if (GloGATH) {
std::vector<std::string> texts = {text};
GenAI_EmbeddingResult result = GloGATH->embed_documents(texts);
if (result.data && result.count > 0) {
// Copy embedding data
size_t dim = result.embedding_size;
embedding.assign(result.data, result.data + dim);
}
}
return embedding;
}
/**
* @brief Process a prompt using the LLM
*/
LLMResult LLM_Bridge::process(const LLMRequest& req) {
LLMResult result;
auto total_start = std::chrono::high_resolution_clock::now();
// Check cache first
result = check_cache(req);
if (result.cached) {
result.cache_hit = true;
result.total_time_ms = result.cache_lookup_time_ms;
if (GloAI) {
GloAI->increment_llm_cache_hits();
GloAI->add_llm_cache_lookup_time_ms(result.cache_lookup_time_ms);
GloAI->add_llm_response_time_ms(result.total_time_ms);
}
return result;
}
if (GloAI) {
GloAI->increment_llm_cache_misses();
GloAI->increment_llm_cache_lookups();
GloAI->add_llm_cache_lookup_time_ms(result.cache_lookup_time_ms);
}
// Build prompt
std::string prompt = build_prompt(req);
// Select model
ModelProvider provider = select_model(req);
if (provider == ModelProvider::FALLBACK_ERROR) {
result.error_code = llm_error_code_to_string(LLMErrorCode::ERR_UNKNOWN_PROVIDER);
result.error_details = "Unknown provider: " + std::string(config.provider ? config.provider : "null");
return result;
}
// Call LLM API
auto llm_start = std::chrono::high_resolution_clock::now();
std::string raw_response;
try {
if (provider == ModelProvider::GENERIC_OPENAI) {
raw_response = call_generic_openai_with_retry(
prompt,
config.provider_model ? config.provider_model : "",
config.provider_url ? config.provider_url : "",
config.provider_key,
req.request_id,
req.max_retries,
req.retry_backoff_ms,
req.retry_multiplier,
req.retry_max_backoff_ms
);
result.provider_used = "openai";
} else if (provider == ModelProvider::GENERIC_ANTHROPIC) {
raw_response = call_generic_anthropic_with_retry(
prompt,
config.provider_model ? config.provider_model : "",
config.provider_url ? config.provider_url : "",
config.provider_key,
req.request_id,
req.max_retries,
req.retry_backoff_ms,
req.retry_multiplier,
req.retry_max_backoff_ms
);
result.provider_used = "anthropic";
}
} catch (const std::exception& e) {
result.error_code = "ERR_EXCEPTION";
result.error_details = e.what();
result.http_status_code = 0;
}
auto llm_end = std::chrono::high_resolution_clock::now();
result.llm_call_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(llm_end - llm_start).count();
// Parse response
if (raw_response.empty() && result.error_code.empty()) {
result.error_code = llm_error_code_to_string(LLMErrorCode::ERR_EMPTY_RESPONSE);
result.error_details = "LLM returned empty response";
} else if (!result.error_code.empty()) {
// Error already set by exception handler
} else {
result.text_response = raw_response;
}
// Store in cache
store_in_cache(req, result);
auto total_end = std::chrono::high_resolution_clock::now();
result.total_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(total_end - total_start).count();
// Update status counters
if (GloAI) {
GloAI->add_llm_response_time_ms(result.total_time_ms);
if (result.cache_store_time_ms > 0) {
GloAI->add_llm_cache_store_time_ms(result.cache_store_time_ms);
GloAI->increment_llm_cache_stores();
}
GloAI->increment_llm_cloud_model_calls();
}
return result;
}
/**
* @brief Clear the vector cache
*/
void LLM_Bridge::clear_cache() {
if (!vector_db) {
return;
}
// TODO: Implement cache clearing
// This would involve deleting all rows from llm_cache table
proxy_info("LLM_Bridge: Cache cleared\n");
}
/**
* @brief Get cache statistics
*/
std::string LLM_Bridge::get_cache_stats() {
// TODO: Implement cache statistics
// This would involve querying the llm_cache table for metrics
json stats;
stats["entries"] = 0;
stats["hits"] = 0;
stats["misses"] = 0;
return stats.dump();
}

@ -19,7 +19,7 @@
* @see NL2SQL_Converter.h
*/
#include "NL2SQL_Converter.h"
#include "LLM_Bridge.h"
#include "sqlite3db.h"
#include "proxysql_utils.h"
#include <cstring>
@ -50,11 +50,11 @@ using json = nlohmann::json;
do { \
if (req_id && strlen(req_id) > 0) { \
proxy_debug(PROXY_DEBUG_NL2SQL, 2, \
"NL2SQL [%s]: REQUEST url=%s model=%s prompt_len=%zu\n", \
"LLM [%s]: REQUEST url=%s model=%s prompt_len=%zu\n", \
req_id, url, model, prompt.length()); \
} else { \
proxy_debug(PROXY_DEBUG_NL2SQL, 2, \
"NL2SQL: REQUEST url=%s model=%s prompt_len=%zu\n", \
"LLM: REQUEST url=%s model=%s prompt_len=%zu\n", \
url, model, prompt.length()); \
} \
} while(0)
@ -63,11 +63,11 @@ using json = nlohmann::json;
do { \
if (req_id && strlen(req_id) > 0) { \
proxy_debug(PROXY_DEBUG_NL2SQL, 3, \
"NL2SQL [%s]: RESPONSE status=%d duration_ms=%ld response=%s\n", \
"LLM [%s]: RESPONSE status=%d duration_ms=%ld response=%s\n", \
req_id, status, duration_ms, response_preview.c_str()); \
} else { \
proxy_debug(PROXY_DEBUG_NL2SQL, 3, \
"NL2SQL: RESPONSE status=%d duration_ms=%ld response=%s\n", \
"LLM: RESPONSE status=%d duration_ms=%ld response=%s\n", \
status, duration_ms, response_preview.c_str()); \
} \
} while(0)
@ -75,10 +75,10 @@ using json = nlohmann::json;
#define LOG_LLM_ERROR(req_id, phase, error, status) \
do { \
if (req_id && strlen(req_id) > 0) { \
proxy_error("NL2SQL [%s]: ERROR phase=%s error=%s status=%d\n", \
proxy_error("LLM [%s]: ERROR phase=%s error=%s status=%d\n", \
req_id, phase, error, status); \
} else { \
proxy_error("NL2SQL: ERROR phase=%s error=%s status=%d\n", \
proxy_error("LLM: ERROR phase=%s error=%s status=%d\n", \
phase, error, status); \
} \
} while(0)
@ -214,7 +214,7 @@ static void sleep_with_jitter(int base_delay_ms, double jitter_factor = 0.1) {
* @param req_id Request ID for correlation (optional)
* @return Generated SQL or empty string on error
*/
std::string NL2SQL_Converter::call_generic_openai(const std::string& prompt, const std::string& model,
std::string LLM_Bridge::call_generic_openai(const std::string& prompt, const std::string& model,
const std::string& url, const char* key,
const std::string& req_id) {
// Start timing
@ -381,7 +381,7 @@ std::string NL2SQL_Converter::call_generic_openai(const std::string& prompt, con
* @param req_id Request ID for correlation (optional)
* @return Generated SQL or empty string on error
*/
std::string NL2SQL_Converter::call_generic_anthropic(const std::string& prompt, const std::string& model,
std::string LLM_Bridge::call_generic_anthropic(const std::string& prompt, const std::string& model,
const std::string& url, const char* key,
const std::string& req_id) {
// Start timing
@ -544,7 +544,7 @@ std::string NL2SQL_Converter::call_generic_anthropic(const std::string& prompt,
* @param max_backoff_ms Maximum backoff delay in milliseconds
* @return Generated SQL or empty string if all retries fail
*/
std::string NL2SQL_Converter::call_generic_openai_with_retry(
std::string LLM_Bridge::call_generic_openai_with_retry(
const std::string& prompt,
const std::string& model,
const std::string& url,
@ -568,7 +568,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry(
// If we got a successful response, return it
if (!result.empty()) {
if (attempt > 0) {
proxy_info("NL2SQL [%s]: Request succeeded after %d retries\n",
proxy_info("LLM [%s]: Request succeeded after %d retries\n",
req_id.c_str(), attempt);
}
return result;
@ -580,7 +580,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry(
// If this was our last attempt, give up
if (attempt == max_retries) {
proxy_error("NL2SQL [%s]: Request failed after %d attempts. Max retries reached.\n",
proxy_error("LLM [%s]: Request failed after %d attempts. Max retries reached.\n",
req_id.c_str(), attempt + 1);
return "";
}
@ -590,10 +590,10 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry(
if (is_retryable_error(last_http_code, last_curl_code) || result.empty()) {
// Log retry attempt
if (result.empty()) {
proxy_warning("NL2SQL [%s]: Empty response, retrying in %dms (attempt %d/%d)\n",
proxy_warning("LLM [%s]: Empty response, retrying in %dms (attempt %d/%d)\n",
req_id.c_str(), current_backoff_ms, attempt + 1, max_retries + 1);
} else {
proxy_warning("NL2SQL [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n",
proxy_warning("LLM [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n",
req_id.c_str(), last_http_code, current_backoff_ms, attempt + 1, max_retries + 1);
}
@ -609,7 +609,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry(
attempt++;
} else {
// Non-retryable error, give up
proxy_error("NL2SQL [%s]: Non-retryable error (HTTP %d), giving up.\n",
proxy_error("LLM [%s]: Non-retryable error (HTTP %d), giving up.\n",
req_id.c_str(), last_http_code);
return "";
}
@ -638,7 +638,7 @@ std::string NL2SQL_Converter::call_generic_openai_with_retry(
* @param max_backoff_ms Maximum backoff delay in milliseconds
* @return Generated SQL or empty string if all retries fail
*/
std::string NL2SQL_Converter::call_generic_anthropic_with_retry(
std::string LLM_Bridge::call_generic_anthropic_with_retry(
const std::string& prompt,
const std::string& model,
const std::string& url,
@ -661,7 +661,7 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry(
// If we got a successful response, return it
if (!result.empty()) {
if (attempt > 0) {
proxy_info("NL2SQL [%s]: Request succeeded after %d retries\n",
proxy_info("LLM [%s]: Request succeeded after %d retries\n",
req_id.c_str(), attempt);
}
return result;
@ -669,7 +669,7 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry(
// If this was our last attempt, give up
if (attempt == max_retries) {
proxy_error("NL2SQL [%s]: Request failed after %d attempts. Max retries reached.\n",
proxy_error("LLM [%s]: Request failed after %d attempts. Max retries reached.\n",
req_id.c_str(), attempt + 1);
return "";
}
@ -679,10 +679,10 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry(
if (is_retryable_error(last_http_code, last_curl_code) || result.empty()) {
// Log retry attempt
if (result.empty()) {
proxy_warning("NL2SQL [%s]: Empty response, retrying in %dms (attempt %d/%d)\n",
proxy_warning("LLM [%s]: Empty response, retrying in %dms (attempt %d/%d)\n",
req_id.c_str(), current_backoff_ms, attempt + 1, max_retries + 1);
} else {
proxy_warning("NL2SQL [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n",
proxy_warning("LLM [%s]: Retryable error (HTTP %d), retrying in %dms (attempt %d/%d)\n",
req_id.c_str(), last_http_code, current_backoff_ms, attempt + 1, max_retries + 1);
}
@ -698,7 +698,7 @@ std::string NL2SQL_Converter::call_generic_anthropic_with_retry(
attempt++;
} else {
// Non-retryable error, give up
proxy_error("NL2SQL [%s]: Non-retryable error (HTTP %d), giving up.\n",
proxy_error("LLM [%s]: Non-retryable error (HTTP %d), giving up.\n",
req_id.c_str(), last_http_code);
return "";
}

@ -85,7 +85,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo
MySQL_Catalog.oo MySQL_Tool_Handler.oo \
Config_Tool_Handler.oo Query_Tool_Handler.oo \
Admin_Tool_Handler.oo Cache_Tool_Handler.oo Observe_Tool_Handler.oo \
AI_Features_Manager.oo NL2SQL_Converter.oo LLM_Clients.oo Anomaly_Detector.oo AI_Vector_Storage.oo AI_Tool_Handler.oo
AI_Features_Manager.oo LLM_Bridge.oo LLM_Clients.oo Anomaly_Detector.oo AI_Vector_Storage.oo AI_Tool_Handler.oo
OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX))
HEADERS := ../include/*.h ../include/*.hpp

@ -16,6 +16,7 @@ using json = nlohmann::json;
#include "MySQL_PreparedStatement.h"
#include "GenAI_Thread.h"
#include "AI_Features_Manager.h"
#include "LLM_Bridge.h"
#include "Anomaly_Detector.h"
#include "MySQL_Logger.hpp"
#include "StatCounters.h"
@ -3871,31 +3872,33 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
#endif // epoll_create1 - fallback blocking path
}
// Handler for NL2SQL: queries - Natural Language to SQL conversion
// Handler for LLM: queries - Generic LLM bridge processing
// Query format:
// NL2SQL: Show me top 10 customers by revenue
// Returns: Resultset with the generated SQL query
// LLM: Summarize the customer feedback
// LLM: Generate a Python function to validate emails
// LLM: Explain this SQL query: SELECT * FROM users
// Returns: Resultset with the text response from LLM
//
// Note: This now uses the async GENAI path to avoid blocking MySQL threads.
// The NL2SQL query is converted to a JSON GENAI request and sent asynchronously.
void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(const char* query, size_t query_len, PtrSize_t* pkt) {
// Skip leading space after "NL2SQL:"
// The LLM query is converted to a JSON GENAI request and sent asynchronously.
void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___llm(const char* query, size_t query_len, PtrSize_t* pkt) {
// Skip leading space after "LLM:"
while (query_len > 0 && (*query == ' ' || *query == '\t')) {
query++;
query_len--;
}
if (query_len == 0) {
// Empty query after NL2SQL:
// Empty query after LLM:
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1240, (char*)"HY000", "Empty NL2SQL: query", true);
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1240, (char*)"HY000", "Empty LLM: query", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Check GenAI module is initialized (NL2SQL now uses GenAI module)
// Check GenAI module is initialized (LLM now uses GenAI module)
if (!GloGATH) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1241, (char*)"HY000", "GenAI module is not initialized", true);
@ -3905,7 +3908,7 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
return;
}
// Check AI manager is available for NL2SQL converter
// Check AI manager is available for LLM bridge
if (!GloAI) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1242, (char*)"HY000", "AI features module is not initialized", true);
@ -3915,11 +3918,11 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
return;
}
// Get NL2SQL converter from AI manager
NL2SQL_Converter* nl2sql = GloAI->get_nl2sql();
if (!nl2sql) {
// Get LLM bridge from AI manager
LLM_Bridge* llm_bridge = GloAI->get_llm_bridge();
if (!llm_bridge) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1243, (char*)"HY000", "NL2SQL converter is not initialized", true);
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1243, (char*)"HY000", "LLM bridge is not initialized", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
@ -3927,16 +3930,16 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
}
// Increment total requests counter
GloAI->increment_nl2sql_total_requests();
GloAI->increment_llm_total_requests();
#ifdef epoll_create1
// Build JSON query for NL2SQL operation
// Build JSON query for LLM operation
json json_query;
json_query["type"] = "nl2sql";
json_query["query"] = std::string(query, query_len);
json_query["type"] = "llm";
json_query["prompt"] = std::string(query, query_len);
json_query["allow_cache"] = true;
// Add schema if available
// Add schema if available (for context)
if (client_myds->myconn->userinfo->schemaname) {
json_query["schema"] = std::string(client_myds->myconn->userinfo->schemaname);
}
@ -3954,38 +3957,38 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
// Request sent asynchronously - don't free pkt, will be freed in response handler
// Return immediately, session is now free to handle other queries
proxy_debug(PROXY_DEBUG_NL2SQL, 2, "NL2SQL: Query sent asynchronously via GenAI: %s\n", std::string(query, query_len).c_str());
proxy_debug(PROXY_DEBUG_GENAI, 2, "LLM: Query sent asynchronously via GenAI: %s\n", std::string(query, query_len).c_str());
#else
// Fallback to synchronous blocking path for systems without epoll
// Build NL2SQL request
NL2SQLRequest req;
req.natural_language = std::string(query, query_len);
// Build LLM request
LLMRequest req;
req.prompt = std::string(query, query_len);
req.schema_name = client_myds->myconn->userinfo->schemaname ? client_myds->myconn->userinfo->schemaname : "";
req.allow_cache = true;
req.max_latency_ms = 0; // No specific latency requirement
// Call NL2SQL converter (blocking fallback)
NL2SQLResult result = nl2sql->convert(req);
// Call LLM bridge (blocking fallback)
LLMResult result = llm_bridge->process(req);
// Update performance counters based on result
if (result.cache_hit) {
GloAI->increment_nl2sql_cache_hits();
GloAI->increment_llm_cache_hits();
} else {
GloAI->increment_nl2sql_cache_misses();
GloAI->increment_llm_cache_misses();
}
// Update timing counters
GloAI->add_nl2sql_response_time_ms(result.total_time_ms);
GloAI->add_nl2sql_cache_lookup_time_ms(result.cache_lookup_time_ms);
GloAI->increment_nl2sql_cache_lookups();
GloAI->add_llm_response_time_ms(result.total_time_ms);
GloAI->add_llm_cache_lookup_time_ms(result.cache_lookup_time_ms);
GloAI->increment_llm_cache_lookups();
if (result.cache_hit) {
// For cache hits, we're done
} else {
// For cache misses, also count LLM call time and cache store time
GloAI->add_nl2sql_cache_store_time_ms(result.cache_store_time_ms);
GloAI->add_llm_cache_store_time_ms(result.cache_store_time_ms);
if (result.cache_store_time_ms > 0) {
GloAI->increment_nl2sql_cache_stores();
GloAI->increment_llm_cache_stores();
}
// Update model call counters
@ -3998,19 +4001,23 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
if (prefer_local_models &&
(result.explanation.find("localhost") != std::string::npos ||
result.explanation.find("127.0.0.1") != std::string::npos)) {
GloAI->increment_nl2sql_local_model_calls();
GloAI->increment_llm_local_model_calls();
} else {
GloAI->increment_nl2sql_cloud_model_calls();
GloAI->increment_llm_cloud_model_calls();
}
} else if (result.provider_used == "anthropic") {
GloAI->increment_nl2sql_cloud_model_calls();
GloAI->increment_llm_cloud_model_calls();
}
}
if (result.sql_query.empty() || result.sql_query.find("NL2SQL conversion failed") == 0) {
// Conversion failed
std::string err_msg = "Failed to convert natural language to SQL: ";
err_msg += result.explanation;
if (result.text_response.empty() && !result.error_code.empty()) {
// LLM processing failed
std::string err_msg = "LLM processing failed: ";
err_msg += result.error_code;
if (!result.error_details.empty()) {
err_msg += " - ";
err_msg += result.error_details;
}
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1244, (char*)"HY000", (char*)err_msg.c_str(), true);
l_free(pkt->size, pkt->ptr);
@ -4019,8 +4026,8 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
return;
}
// Build resultset with the generated SQL
std::vector<std::string> columns = {"sql_query", "confidence", "explanation", "cached"};
// Build resultset with the generated text response
std::vector<std::string> columns = {"text_response", "explanation", "cached", "provider"};
std::unique_ptr<SQLite3_result> resultset(new SQLite3_result(columns.size()));
// Add column definitions
@ -4030,13 +4037,10 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
// Add single row with the result
char** row_data = (char**)malloc(columns.size() * sizeof(char*));
row_data[0] = strdup(result.sql_query.c_str());
char conf_buf[32];
snprintf(conf_buf, sizeof(conf_buf), "%.2f", result.confidence);
row_data[1] = strdup(conf_buf);
row_data[2] = strdup(result.explanation.c_str());
row_data[3] = strdup(result.cached ? "true" : "false");
row_data[0] = strdup(result.text_response.c_str());
row_data[1] = strdup(result.explanation.c_str());
row_data[2] = strdup(result.cached ? "true" : "false");
row_data[3] = strdup(result.provider_used.c_str());
resultset->add_row(row_data);
@ -4054,8 +4058,8 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
proxy_debug(PROXY_DEBUG_NL2SQL, 2, "NL2SQL: Converted '%s' to SQL (confidence: %.2f) [blocking fallback]\n",
req.natural_language.c_str(), result.confidence);
proxy_debug(PROXY_DEBUG_GENAI, 2, "LLM: Processed prompt '%s' [blocking fallback]\n",
req.prompt.c_str());
#endif
}
@ -7037,10 +7041,10 @@ bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
return true;
}
// Check for NL2SQL: queries - Natural Language to SQL conversion
if (query_len >= 8 && strncasecmp(query_ptr, "NL2SQL:", 7) == 0) {
// This is a NL2SQL: query - handle with NL2SQL converter
handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(query_ptr + 7, query_len - 7, pkt);
// Check for LLM: queries - Generic LLM bridge processing
if (query_len >= 5 && strncasecmp(query_ptr, "LLM:", 4) == 0) {
// This is a LLM: query - handle with LLM bridge
handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___llm(query_ptr + 4, query_len - 4, pkt);
return true;
}
}

@ -1,790 +0,0 @@
/**
* @file NL2SQL_Converter.cpp
* @brief Implementation of Natural Language to SQL Converter
*
* This file implements the NL2SQL conversion pipeline including:
* - Vector cache operations for semantic similarity
* - Model selection based on latency/budget
* - Generic LLM API calls (Ollama, OpenAI-compatible, Anthropic-compatible)
* - SQL validation and cleaning
*
* @see NL2SQL_Converter.h
*/
#include "NL2SQL_Converter.h"
#include "sqlite3db.h"
#include "proxysql_utils.h"
#include "GenAI_Thread.h"
#include <cstring>
#include <cstdlib>
#include <sstream>
#include <algorithm>
#include <regex>
#include <chrono>
using json = nlohmann::json;
// Global GenAI handler for embedding generation
extern GenAI_Threads_Handler *GloGATH;
// Global instance is defined elsewhere if needed
// NL2SQL_Converter *GloNL2SQL = NULL;
// ============================================================================
// Error Handling Helper Functions
// ============================================================================
/**
* @brief Convert error code enum to string representation
*
* Returns the string representation of an error code for logging
* and display purposes.
*
* @param code The error code to convert
* @return String representation of the error code
*/
const char* nl2sql_error_code_to_string(NL2SQLErrorCode code) {
switch (code) {
case NL2SQLErrorCode::SUCCESS: return "SUCCESS";
case NL2SQLErrorCode::ERR_API_KEY_MISSING: return "ERR_API_KEY_MISSING";
case NL2SQLErrorCode::ERR_API_KEY_INVALID: return "ERR_API_KEY_INVALID";
case NL2SQLErrorCode::ERR_TIMEOUT: return "ERR_TIMEOUT";
case NL2SQLErrorCode::ERR_CONNECTION_FAILED: return "ERR_CONNECTION_FAILED";
case NL2SQLErrorCode::ERR_RATE_LIMITED: return "ERR_RATE_LIMITED";
case NL2SQLErrorCode::ERR_SERVER_ERROR: return "ERR_SERVER_ERROR";
case NL2SQLErrorCode::ERR_EMPTY_RESPONSE: return "ERR_EMPTY_RESPONSE";
case NL2SQLErrorCode::ERR_INVALID_RESPONSE: return "ERR_INVALID_RESPONSE";
case NL2SQLErrorCode::ERR_SQL_INJECTION_DETECTED: return "ERR_SQL_INJECTION_DETECTED";
case NL2SQLErrorCode::ERR_VALIDATION_FAILED: return "ERR_VALIDATION_FAILED";
case NL2SQLErrorCode::ERR_UNKNOWN_PROVIDER: return "ERR_UNKNOWN_PROVIDER";
case NL2SQLErrorCode::ERR_REQUEST_TOO_LARGE: return "ERR_REQUEST_TOO_LARGE";
default: return "UNKNOWN_ERROR";
}
}
/**
* @brief Format detailed error context for logging and user display
*
* Creates a structured error message including:
* - Query (truncated if too long)
* - Schema name
* - Provider attempted
* - Endpoint URL
* - Specific error message
*
* @param req The NL2SQL request that failed
* @param provider The provider that was attempted
* @param url The endpoint URL that was used
* @param error The specific error message
* @return Formatted error context string
*/
static std::string format_error_context(const NL2SQLRequest& req,
const std::string& provider,
const std::string& url,
const std::string& error)
{
std::ostringstream oss;
oss << "NL2SQL conversion failed:\n"
<< " Query: " << req.natural_language.substr(0, 100)
<< (req.natural_language.length() > 100 ? "..." : "") << "\n"
<< " Schema: " << (req.schema_name.empty() ? "(none)" : req.schema_name) << "\n"
<< " Provider: " << provider << "\n"
<< " URL: " << url << "\n"
<< " Error: " << error;
return oss.str();
}
/**
* @brief Set error details in NL2SQLResult
*
* Helper function to populate error fields in result struct.
*
* @param result The result to update
* @param error_code The error code string
* @param error_details Detailed error context
* @param http_status HTTP status code (0 if N/A)
* @param provider Provider that was attempted
*/
static void set_error_details(NL2SQLResult& result,
const std::string& error_code,
const std::string& error_details,
int http_status,
const std::string& provider)
{
result.error_code = error_code;
result.error_details = error_details;
result.http_status_code = http_status;
result.provider_used = provider;
}
// ============================================================================
// Constructor/Destructor
// ============================================================================
/**
* Constructor initializes with default configuration values.
* The vector_db will be set by AI_Features_Manager during init().
*/
NL2SQL_Converter::NL2SQL_Converter() : vector_db(NULL) {
config.enabled = true;
config.query_prefix = strdup("NL2SQL:");
config.provider = strdup("openai");
config.provider_url = strdup("http://localhost:11434/v1/chat/completions"); // Ollama default
config.provider_model = strdup("llama3.2");
config.provider_key = NULL;
config.cache_similarity_threshold = 85;
config.timeout_ms = 30000;
}
NL2SQL_Converter::~NL2SQL_Converter() {
free(config.query_prefix);
free(config.provider);
free(config.provider_url);
free(config.provider_model);
free(config.provider_key);
}
// ============================================================================
// Lifecycle
// ============================================================================
/**
* Initialize the NL2SQL converter.
* The vector DB will be provided by AI_Features_Manager during initialization.
*/
int NL2SQL_Converter::init() {
proxy_info("NL2SQL: Initializing NL2SQL Converter v%s\n", NL2SQL_CONVERTER_VERSION);
// Vector DB will be provided by AI_Features_Manager
// This is a stub implementation for Phase 1
proxy_info("NL2SQL: NL2SQL Converter initialized (stub)\n");
return 0;
}
void NL2SQL_Converter::close() {
proxy_info("NL2SQL: NL2SQL Converter closed\n");
}
void NL2SQL_Converter::update_config(const char* provider, const char* provider_url,
const char* provider_model, const char* provider_key,
int cache_threshold, int timeout) {
// Free old values
free(config.provider);
free(config.provider_url);
free(config.provider_model);
free(config.provider_key);
// Set new values
config.provider = strdup(provider ? provider : "openai");
config.provider_url = strdup(provider_url ? provider_url : "http://localhost:11434/v1/chat/completions");
config.provider_model = strdup(provider_model ? provider_model : "llama3.2");
config.provider_key = provider_key ? strdup(provider_key) : NULL;
config.cache_similarity_threshold = cache_threshold;
config.timeout_ms = timeout;
}
// ============================================================================
// Vector Cache Operations (semantic similarity cache)
// ============================================================================
/**
* @brief Generate vector embedding for text
*
* Generates a 1536-dimensional embedding using the GenAI module.
* This embedding represents the semantic meaning of the text.
*
* @param text Input text to embed
* @return Vector embedding (empty if not available)
*/
std::vector<float> NL2SQL_Converter::get_query_embedding(const std::string& text) {
if (!GloGATH) {
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: GenAI handler not available for embedding");
return {};
}
// Generate embedding using GenAI
GenAI_EmbeddingResult emb_result = GloGATH->embed_documents({text});
if (!emb_result.data || emb_result.count == 0) {
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Failed to generate embedding");
return {};
}
// Convert to std::vector<float>
std::vector<float> embedding(emb_result.data, emb_result.data + emb_result.embedding_size);
// Free the result data (GenAI allocates with malloc)
if (emb_result.data) {
free(emb_result.data);
}
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Generated embedding with %zu dimensions", embedding.size());
return embedding;
}
/**
* @brief Check vector cache for semantically similar previous conversions
*
* Uses sqlite-vec to find previous NL2SQL conversions with similar
* natural language queries. This allows caching based on semantic meaning
* rather than exact string matching.
*/
NL2SQLResult NL2SQL_Converter::check_vector_cache(const NL2SQLRequest& req) {
NL2SQLResult result;
result.cached = false;
if (!vector_db || !req.allow_cache) {
return result;
}
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Checking vector cache for: %s\n",
req.natural_language.c_str());
// Generate embedding for the query
std::vector<float> query_embedding = get_query_embedding(req.natural_language);
if (query_embedding.empty()) {
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Failed to generate embedding for cache lookup");
return result;
}
// Convert embedding to JSON for sqlite-vec MATCH
std::string embedding_json = "[";
for (size_t i = 0; i < query_embedding.size(); i++) {
if (i > 0) embedding_json += ",";
embedding_json += std::to_string(query_embedding[i]);
}
embedding_json += "]";
// Calculate distance threshold from similarity
// Similarity 0-100 -> Distance 0-2 (cosine distance: 0=similar, 2=dissimilar)
float distance_threshold = 2.0f - (config.cache_similarity_threshold / 50.0f);
// Build KNN search query
char search[1024];
snprintf(search, sizeof(search),
"SELECT c.natural_language, c.generated_sql, c.schema_context, "
" vec_distance_cosine(v.embedding, '%s') as distance "
"FROM nl2sql_cache c "
"JOIN nl2sql_cache_vec v ON c.id = v.rowid "
"WHERE v.embedding MATCH '%s' "
"AND distance < %f "
"ORDER BY distance "
"LIMIT 1",
embedding_json.c_str(), embedding_json.c_str(), distance_threshold);
// Execute search
sqlite3* db = vector_db->get_db();
sqlite3_stmt* stmt = NULL;
int rc = sqlite3_prepare_v2(db, search, -1, &stmt, NULL);
if (rc != SQLITE_OK) {
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Cache search prepare failed: %s", sqlite3_errmsg(db));
return result;
}
// Check if any cached queries matched
rc = sqlite3_step(stmt);
if (rc == SQLITE_ROW) {
// Found similar cached query
result.cached = true;
// Extract cached result (natural_lang and schema_ctx available but not currently used)
// const char* natural_lang = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
const char* generated_sql = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
// const char* schema_ctx = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
double distance = sqlite3_column_double(stmt, 3);
// Calculate similarity score from distance
float similarity = 1.0f - (distance / 2.0f);
result.confidence = similarity;
result.sql_query = generated_sql ? generated_sql : "";
result.explanation = "Retrieved from semantic cache (similarity: " +
std::to_string((int)(similarity * 100)) + "%)";
proxy_info("NL2SQL: Cache hit! (distance: %.3f, similarity: %.0f%%)\n",
distance, similarity * 100);
}
sqlite3_finalize(stmt);
return result;
}
/**
* @brief Store a new NL2SQL conversion in the vector cache
*
* Stores both the original query and generated SQL, along with
* the query embedding for semantic similarity search.
*/
void NL2SQL_Converter::store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result) {
if (!vector_db || !req.allow_cache) {
return;
}
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Storing in vector cache: %s -> %s\n",
req.natural_language.c_str(), result.sql_query.c_str());
// Generate embedding for the natural language query
std::vector<float> embedding = get_query_embedding(req.natural_language);
if (embedding.empty()) {
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "NL2SQL: Failed to generate embedding for cache storage");
return;
}
// Insert into main table with embedding BLOB
sqlite3* db = vector_db->get_db();
sqlite3_stmt* stmt = NULL;
const char* insert = "INSERT INTO nl2sql_cache "
"(natural_language, generated_sql, schema_context, embedding) "
"VALUES (?, ?, ?, ?)";
int rc = sqlite3_prepare_v2(db, insert, -1, &stmt, NULL);
if (rc != SQLITE_OK) {
proxy_error("NL2SQL: Failed to prepare cache insert: %s\n", sqlite3_errmsg(db));
return;
}
// Bind values
sqlite3_bind_text(stmt, 1, req.natural_language.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 2, result.sql_query.c_str(), -1, SQLITE_TRANSIENT);
// Schema context (may be empty)
std::string schema_context;
if (!req.context_tables.empty()) {
schema_context = "{"; // Simple format: table names
for (size_t i = 0; i < req.context_tables.size(); i++) {
if (i > 0) schema_context += ",";
schema_context += req.context_tables[i];
}
schema_context += "}";
}
sqlite3_bind_text(stmt, 3, schema_context.c_str(), -1, SQLITE_TRANSIENT);
// Bind embedding as BLOB
sqlite3_bind_blob(stmt, 4, embedding.data(), embedding.size() * sizeof(float), SQLITE_TRANSIENT);
// Execute insert
rc = sqlite3_step(stmt);
if (rc != SQLITE_DONE) {
proxy_error("NL2SQL: Failed to insert into cache: %s\n", sqlite3_errmsg(db));
sqlite3_finalize(stmt);
return;
}
sqlite3_finalize(stmt);
// Get the inserted rowid
sqlite3_int64 rowid = sqlite3_last_insert_rowid(db);
// Update virtual table (sqlite-vec needs explicit rowid insertion)
char update_vec[256];
snprintf(update_vec, sizeof(update_vec),
"INSERT INTO nl2sql_cache_vec(rowid) VALUES (%lld)", rowid);
char* err = NULL;
rc = sqlite3_exec(db, update_vec, NULL, NULL, &err);
if (rc != SQLITE_OK) {
proxy_error("NL2SQL: Failed to update vec table: %s\n", err ? err : "unknown");
if (err) sqlite3_free(err);
return;
}
proxy_info("NL2SQL: Stored in cache (id: %lld)\n", rowid);
}
// ============================================================================
// Model Selection Logic
// ============================================================================
/**
* @brief Select the best model provider for the given request
*
* Selection criteria:
* 1. Explicit provider preference -> use that
* 2. For generic providers: check API key availability (only for cloud)
*
* @note For local endpoints (like Ollama), API key is optional
*/
ModelProvider NL2SQL_Converter::select_model(const NL2SQLRequest& req) {
// Check provider preference
std::string provider(config.provider ? config.provider : "openai");
if (provider == "openai") {
// For local endpoints, API key is optional
// Check if this is a local endpoint
std::string url(config.provider_url ? config.provider_url : "");
bool is_local = (url.find("localhost") != std::string::npos ||
url.find("127.0.0.1") != std::string::npos ||
url.find("http://localhost:11434") != std::string::npos);
if (!is_local && !config.provider_key) {
proxy_error("NL2SQL: OpenAI-compatible provider requested but API key not configured\n");
return ModelProvider::FALLBACK_ERROR;
}
return ModelProvider::GENERIC_OPENAI;
} else if (provider == "anthropic") {
// Anthropic always requires API key
if (!config.provider_key) {
proxy_error("NL2SQL: Anthropic-compatible provider requested but API key not configured\n");
return ModelProvider::FALLBACK_ERROR;
}
return ModelProvider::GENERIC_ANTHROPIC;
}
// Unknown provider, default to OpenAI format
proxy_warning("NL2SQL: Unknown provider '%s', defaulting to OpenAI format\n", provider.c_str());
return ModelProvider::GENERIC_OPENAI;
}
// ============================================================================
// Prompt Building
// ============================================================================
/**
* @brief Build the prompt for LLM with schema context
*
* Constructs a comprehensive prompt including:
* - System instructions
* - Schema information (tables, columns)
* - User's natural language query
*/
std::string NL2SQL_Converter::build_prompt(const NL2SQLRequest& req, const std::string& schema_context) {
std::ostringstream prompt;
// System instructions
prompt << "You are a SQL expert. Convert the following natural language question to a SQL query.\n\n";
// Add schema context if available
if (!schema_context.empty()) {
prompt << "Database Schema:\n";
prompt << schema_context;
prompt << "\n";
}
// User's question
prompt << "Question: " << req.natural_language << "\n\n";
prompt << "Return ONLY the SQL query. No explanations, no markdown formatting.\n";
return prompt.str();
}
/**
* @brief Get schema context for the specified tables
*
* Retrieves table and column information from the MySQL_Tool_Handler
* or from cached schema information.
*/
std::string NL2SQL_Converter::get_schema_context(const std::vector<std::string>& tables) {
// TODO: Implement schema context retrieval via MySQL_Tool_Handler
// For Phase 2, return empty string
return "";
}
// ============================================================================
// SQL Validation
// ============================================================================
/**
* @brief Validate SQL and generate confidence score
*
* Performs multi-factor validation:
* 1. SQL keyword detection
* 2. Structural validation (parentheses, quotes)
* 3. Common SQL injection pattern detection
* 4. Length and complexity checks
*
* @param sql The SQL to validate
* @return Confidence score 0.0-1.0
*/
float NL2SQL_Converter::validate_and_score_sql(const std::string& sql) {
if (sql.empty()) {
return 0.0f;
}
float confidence = 0.0f;
int checks_passed = 0;
int total_checks = 0;
// Trim leading whitespace for validation
size_t start = sql.find_first_not_of(" \t\n\r");
if (start == std::string::npos) {
return 0.0f; // Empty or whitespace only
}
std::string trimmed_sql = sql.substr(start);
std::string upper_sql = trimmed_sql;
std::transform(upper_sql.begin(), upper_sql.end(), upper_sql.begin(), ::toupper);
// Check 1: SQL keyword detection
total_checks++;
static const std::vector<std::string> sql_keywords = {
"SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP",
"TRUNCATE", "REPLACE", "GRANT", "REVOKE", "SHOW", "DESCRIBE",
"EXPLAIN", "WITH", "CALL", "BEGIN", "COMMIT", "ROLLBACK"
};
for (const auto& keyword : sql_keywords) {
if (upper_sql.find(keyword) == 0 || upper_sql.find("-- " + keyword) == 0) {
confidence += 0.4f;
checks_passed++;
break;
}
}
// Check 2: Structural validation - balanced parentheses
total_checks++;
int paren_count = 0;
bool balanced_parens = true;
for (char c : sql) {
if (c == '(') paren_count++;
else if (c == ')') paren_count--;
if (paren_count < 0) {
balanced_parens = false;
break;
}
}
if (balanced_parens && paren_count == 0) {
confidence += 0.15f;
checks_passed++;
} else if (paren_count != 0) {
// Unbalanced parentheses reduce confidence
confidence -= 0.1f;
}
// Check 3: Balanced quotes
total_checks++;
int single_quotes = 0;
int double_quotes = 0;
for (size_t i = 0; i < sql.length(); i++) {
if (sql[i] == '\'' && (i == 0 || sql[i-1] != '\\')) {
single_quotes++;
}
if (sql[i] == '"' && (i == 0 || sql[i-1] != '\\')) {
double_quotes++;
}
}
if (single_quotes % 2 == 0 && double_quotes % 2 == 0) {
confidence += 0.15f;
checks_passed++;
} else {
confidence -= 0.1f;
}
// Check 4: Minimum length check
total_checks++;
if (sql.length() >= 10) {
confidence += 0.1f;
checks_passed++;
}
// Check 5: Contains FROM clause for SELECT statements (quality indicator)
total_checks++;
if (upper_sql.find("SELECT") == 0 && upper_sql.find("FROM") != std::string::npos) {
confidence += 0.1f;
checks_passed++;
}
// Check 6: SQL injection pattern detection (negative impact)
total_checks++;
static const std::vector<std::string> injection_patterns = {
"; DROP", "; DELETE", "; INSERT", "; UPDATE",
"1=1", "1 = 1", "OR TRUE", "AND TRUE",
"UNION SELECT", "'; --", "\"; --"
};
bool has_injection = false;
std::string check_upper = upper_sql;
for (const auto& pattern : injection_patterns) {
std::string pattern_upper = pattern;
std::transform(pattern_upper.begin(), pattern_upper.end(), pattern_upper.begin(), ::toupper);
if (check_upper.find(pattern_upper) != std::string::npos) {
has_injection = true;
break;
}
}
if (!has_injection) {
confidence += 0.1f;
checks_passed++;
} else {
confidence -= 0.3f; // Significant penalty for injection patterns
proxy_warning("NL2SQL: Potential SQL injection pattern detected in generated SQL\n");
}
// Normalize confidence to 0.0-1.0 range
if (confidence < 0.0f) confidence = 0.0f;
if (confidence > 1.0f) confidence = 1.0f;
// Additional logging for low confidence
if (confidence < 0.5f) {
proxy_debug(PROXY_DEBUG_NL2SQL, 2,
"NL2SQL: Low confidence score %.2f (passed %d/%d checks). SQL: %s\n",
confidence, checks_passed, total_checks, sql.c_str());
}
return confidence;
}
// ============================================================================
// Main Conversion Method
// ============================================================================
/**
* @brief Convert natural language to SQL (main entry point)
*
* Conversion Pipeline:
* 1. Check vector cache for semantically similar queries
* 2. Build prompt with schema context
* 3. Select appropriate model (Ollama or generic provider)
* 4. Call LLM API via HTTP
* 5. Parse and clean SQL response
* 6. Store in vector cache for future use
*
* The confidence score is calculated based on:
* - SQL keyword validation (does it look like SQL?)
* - Response quality (non-empty, well-formed)
* - Default score of 0.85 for valid-looking SQL
*
* @note This is a synchronous blocking call.
*/
NL2SQLResult NL2SQL_Converter::convert(const NL2SQLRequest& req) {
NL2SQLResult result;
// Start timing the entire conversion
auto start_time = std::chrono::steady_clock::now();
proxy_info("NL2SQL: Converting query: %s\n", req.natural_language.c_str());
// Check vector cache first
auto cache_start = std::chrono::steady_clock::now();
if (req.allow_cache) {
result = check_vector_cache(req);
if (result.cached && !result.sql_query.empty()) {
proxy_info("NL2SQL: Cache hit! Returning cached SQL\n");
// Set timing information for cache hit
auto cache_end = std::chrono::steady_clock::now();
int cache_lookup_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(cache_end - cache_start).count();
result.total_time_ms = cache_lookup_time_ms;
result.cache_lookup_time_ms = cache_lookup_time_ms;
result.cache_hit = true;
return result;
}
}
auto cache_end = std::chrono::steady_clock::now();
int cache_lookup_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(cache_end - cache_start).count();
// Build prompt with schema context
std::string schema_context = get_schema_context(req.context_tables);
std::string prompt = build_prompt(req, schema_context);
// Select model provider
ModelProvider provider = select_model(req);
// Call appropriate LLM
std::string raw_sql;
std::string url;
const char* model = NULL;
const char* key = config.provider_key;
// Time the LLM call
auto llm_start = std::chrono::steady_clock::now();
switch (provider) {
case ModelProvider::GENERIC_OPENAI:
// Use configured URL or default Ollama endpoint
url = (config.provider_url && strlen(config.provider_url) > 0)
? config.provider_url
: "http://localhost:11434/v1/chat/completions";
model = config.provider_model ? config.provider_model : "llama3.2";
raw_sql = call_generic_openai_with_retry(prompt, model, url, key, req.request_id,
req.max_retries, req.retry_backoff_ms,
req.retry_multiplier, req.retry_max_backoff_ms);
result.explanation = "Generated by OpenAI-compatible provider (" + std::string(model) + ")";
result.provider_used = "openai";
break;
case ModelProvider::GENERIC_ANTHROPIC:
// Use configured URL or default Anthropic endpoint
url = (config.provider_url && strlen(config.provider_url) > 0)
? config.provider_url
: "https://api.anthropic.com/v1/messages";
model = config.provider_model ? config.provider_model : "claude-3-haiku";
raw_sql = call_generic_anthropic_with_retry(prompt, model, url, key, req.request_id,
req.max_retries, req.retry_backoff_ms,
req.retry_multiplier, req.retry_max_backoff_ms);
result.explanation = "Generated by Anthropic-compatible provider (" + std::string(model) + ")";
result.provider_used = "anthropic";
break;
case ModelProvider::FALLBACK_ERROR:
default: {
// Format error context
std::string provider_str(config.provider ? config.provider : "unknown");
std::string url_str(config.provider_url ? config.provider_url : "not configured");
std::string error_msg = "API key not configured or provider error";
std::string context = format_error_context(req, provider_str, url_str, error_msg);
proxy_error("NL2SQL: %s\n", context.c_str());
set_error_details(result, "ERR_API_KEY_MISSING", context, 0, provider_str);
result.sql_query = "-- NL2SQL conversion failed: " + error_msg + "\n";
result.confidence = 0.0f;
result.explanation = "Error: " + error_msg;
return result;
}
}
auto llm_end = std::chrono::steady_clock::now();
int llm_call_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(llm_end - llm_start).count();
// Validate and clean SQL
if (raw_sql.empty()) {
std::string provider_str(config.provider ? config.provider : "unknown");
std::string url_str(config.provider_url ? config.provider_url : "not configured");
std::string error_msg = "empty response from LLM";
std::string context = format_error_context(req, provider_str, url_str, error_msg);
proxy_error("NL2SQL: %s\n", context.c_str());
set_error_details(result, "ERR_EMPTY_RESPONSE", context, 0, provider_str);
result.sql_query = "-- NL2SQL conversion failed: " + error_msg + "\n";
result.confidence = 0.0f;
result.explanation += " (empty response)";
return result;
}
// Improved SQL validation
float confidence = validate_and_score_sql(raw_sql);
result.sql_query = raw_sql;
result.confidence = confidence;
// Store in vector cache for future use if confidence is good enough
auto cache_store_start = std::chrono::steady_clock::now();
if (req.allow_cache && confidence >= 0.5f) {
store_in_vector_cache(req, result);
}
auto cache_store_end = std::chrono::steady_clock::now();
int cache_store_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(cache_store_end - cache_store_start).count();
proxy_info("NL2SQL: Conversion complete. Confidence: %.2f\n", result.confidence);
// Calculate total time
auto end_time = std::chrono::steady_clock::now();
int total_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
// Populate timing information in result
result.total_time_ms = total_time_ms;
result.cache_lookup_time_ms = cache_lookup_time_ms;
result.cache_store_time_ms = cache_store_time_ms;
result.llm_call_time_ms = llm_call_time_ms;
result.cache_hit = false; // This will be set to true if we return from cache hit
return result;
}
// ============================================================================
// Cache Management
// ============================================================================
void NL2SQL_Converter::clear_cache() {
proxy_info("NL2SQL: Cache cleared\n");
// TODO: Implement cache clearing
}
std::string NL2SQL_Converter::get_cache_stats() {
return "{\"entries\": 0, \"hits\": 0, \"misses\": 0}";
// TODO: Implement real cache statistics
}

@ -121,10 +121,10 @@ ProxySQL_MCP_Server::ProxySQL_MCP_Server(int p, MCP_Threads_Handler* h)
proxy_info("Observe Tool Handler initialized\n");
}
// 6. AI Tool Handler (for NL2SQL and other AI features)
// 6. AI Tool Handler (for LLM and other AI features)
extern AI_Features_Manager *GloAI;
if (GloAI) {
handler->ai_tool_handler = new AI_Tool_Handler(GloAI->get_nl2sql(), GloAI->get_anomaly_detector());
handler->ai_tool_handler = new AI_Tool_Handler(GloAI->get_llm_bridge(), GloAI->get_anomaly_detector());
if (handler->ai_tool_handler->init() == 0) {
proxy_info("AI Tool Handler initialized\n");
} else {
@ -164,7 +164,7 @@ ProxySQL_MCP_Server::ProxySQL_MCP_Server(int p, MCP_Threads_Handler* h)
ws->register_resource("/mcp/cache", cache_resource.get(), true);
_endpoints.push_back({"/mcp/cache", std::move(cache_resource)});
// 6. AI endpoint (for NL2SQL and other AI features)
// 6. AI endpoint (for LLM and other AI features)
if (handler->ai_tool_handler) {
std::unique_ptr<httpserver::http_resource> ai_resource =
std::unique_ptr<httpserver::http_resource>(new MCP_JSONRPC_Resource(handler, handler->ai_tool_handler, "ai"));

Loading…
Cancel
Save