feat: Add NL2SQL query interception in MySQL_Session

- Add NL2SQL handler declaration
- Add routing for 'NL2SQL:' prefix
- Return resultset with generated SQL and metadata
pull/5310/head
Rene Cannao 3 months ago
parent 147a059781
commit bc4fff12ce

@ -284,6 +284,7 @@ class MySQL_Session: public Base_Session<MySQL_Session, MySQL_Data_Stream, MySQL
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_INIT_DB_replace_CLICKHOUSE(PtrSize_t& pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___not_mysql(PtrSize_t& pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai(const char* query, size_t query_len, PtrSize_t* pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(const char* query, size_t query_len, PtrSize_t* pkt);
#ifdef epoll_create1
/**
* @brief Handle GenAI response from socketpair

@ -3789,6 +3789,109 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
#endif // epoll_create1 - fallback blocking path
}
// Handler for NL2SQL: queries - Natural Language to SQL conversion
// Query format:
// NL2SQL: Show me top 10 customers by revenue
// Returns: Resultset with the generated SQL query
void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(const char* query, size_t query_len, PtrSize_t* pkt) {
// Skip leading space after "NL2SQL:"
while (query_len > 0 && (*query == ' ' || *query == '\t')) {
query++;
query_len--;
}
if (query_len == 0) {
// Empty query after NL2SQL:
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1240, (char*)"HY000", "Empty NL2SQL: query", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Check AI module is initialized
if (!GloAI) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1241, (char*)"HY000", "AI features module is not initialized", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Get NL2SQL converter from AI manager
NL2SQL_Converter* nl2sql = GloAI->get_nl2sql();
if (!nl2sql) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1242, (char*)"HY000", "NL2SQL converter is not initialized", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Build NL2SQL request
NL2SQLRequest req;
req.natural_language = std::string(query, query_len);
req.schema_name = client_myds->myconn->userinfo->schemaname ? client_myds->myconn->userinfo->schemaname : "";
req.allow_cache = true;
req.max_latency_ms = 0; // No specific latency requirement
// Call NL2SQL converter (synchronous for Phase 2)
NL2SQLResult result = nl2sql->convert(req);
if (result.sql_query.empty() || result.sql_query.find("NL2SQL conversion failed") == 0) {
// Conversion failed
std::string err_msg = "Failed to convert natural language to SQL: ";
err_msg += result.explanation;
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1243, (char*)"HY000", (char*)err_msg.c_str(), true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Build resultset with the generated SQL
std::vector<std::string> columns = {"sql_query", "confidence", "explanation", "cached"};
std::unique_ptr<SQLite3_result> resultset(new SQLite3_result(columns.size()));
// Add column definitions
for (size_t i = 0; i < columns.size(); i++) {
resultset->add_column_definition(SQLITE_TEXT, (char*)columns[i].c_str());
}
// Add single row with the result
char** row_data = (char**)malloc(columns.size() * sizeof(char*));
row_data[0] = strdup(result.sql_query.c_str());
char conf_buf[32];
snprintf(conf_buf, sizeof(conf_buf), "%.2f", result.confidence);
row_data[1] = strdup(conf_buf);
row_data[2] = strdup(result.explanation.c_str());
row_data[3] = strdup(result.cached ? "true" : "false");
resultset->add_row(row_data);
// Free row data
for (size_t i = 0; i < columns.size(); i++) {
free(row_data[i]);
}
free(row_data);
// Send resultset to client
SQLite3_to_MySQL(resultset.get(), NULL, 0, &client_myds->myprot, false,
(client_myds->myconn->options.client_flag & CLIENT_DEPRECATE_EOF));
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
proxy_debug(PROXY_DEBUG_NL2SQL, 2, "NL2SQL: Converted '%s' to SQL (confidence: %.2f)\n",
req.natural_language.c_str(), result.confidence);
}
#ifdef epoll_create1
/**
* @brief Send GenAI request asynchronously via socketpair
@ -6759,6 +6862,13 @@ bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai(query_ptr + 6, query_len - 6, pkt);
return true;
}
// Check for NL2SQL: queries - Natural Language to SQL conversion
if (query_len >= 8 && strncasecmp(query_ptr, "NL2SQL:", 7) == 0) {
// This is a NL2SQL: query - handle with NL2SQL converter
handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(query_ptr + 7, query_len - 7, pkt);
return true;
}
}
if (qpo->new_query) {

Loading…
Cancel
Save