From bc4fff12ce5e0a2d003d7039ac65d9ef773e3a8f Mon Sep 17 00:00:00 2001 From: Rene Cannao Date: Fri, 16 Jan 2026 10:51:24 +0000 Subject: [PATCH] feat: Add NL2SQL query interception in MySQL_Session - Add NL2SQL handler declaration - Add routing for 'NL2SQL:' prefix - Return resultset with generated SQL and metadata --- include/MySQL_Session.h | 1 + lib/MySQL_Session.cpp | 110 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/include/MySQL_Session.h b/include/MySQL_Session.h index b44eea8a5..90da6b618 100644 --- a/include/MySQL_Session.h +++ b/include/MySQL_Session.h @@ -284,6 +284,7 @@ class MySQL_Session: public Base_Session 0 && (*query == ' ' || *query == '\t')) { + query++; + query_len--; + } + + if (query_len == 0) { + // Empty query after NL2SQL: + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1240, (char*)"HY000", "Empty NL2SQL: query", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Check AI module is initialized + if (!GloAI) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1241, (char*)"HY000", "AI features module is not initialized", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Get NL2SQL converter from AI manager + NL2SQL_Converter* nl2sql = GloAI->get_nl2sql(); + if (!nl2sql) { + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1242, (char*)"HY000", "NL2SQL converter is not initialized", true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Build NL2SQL request + NL2SQLRequest req; + req.natural_language = std::string(query, query_len); + req.schema_name = client_myds->myconn->userinfo->schemaname ? client_myds->myconn->userinfo->schemaname : ""; + req.allow_cache = true; + req.max_latency_ms = 0; // No specific latency requirement + + // Call NL2SQL converter (synchronous for Phase 2) + NL2SQLResult result = nl2sql->convert(req); + + if (result.sql_query.empty() || result.sql_query.find("NL2SQL conversion failed") == 0) { + // Conversion failed + std::string err_msg = "Failed to convert natural language to SQL: "; + err_msg += result.explanation; + client_myds->DSS = STATE_QUERY_SENT_NET; + client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 1243, (char*)"HY000", (char*)err_msg.c_str(), true); + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + return; + } + + // Build resultset with the generated SQL + std::vector columns = {"sql_query", "confidence", "explanation", "cached"}; + std::unique_ptr resultset(new SQLite3_result(columns.size())); + + // Add column definitions + for (size_t i = 0; i < columns.size(); i++) { + resultset->add_column_definition(SQLITE_TEXT, (char*)columns[i].c_str()); + } + + // Add single row with the result + char** row_data = (char**)malloc(columns.size() * sizeof(char*)); + row_data[0] = strdup(result.sql_query.c_str()); + + char conf_buf[32]; + snprintf(conf_buf, sizeof(conf_buf), "%.2f", result.confidence); + row_data[1] = strdup(conf_buf); + row_data[2] = strdup(result.explanation.c_str()); + row_data[3] = strdup(result.cached ? "true" : "false"); + + resultset->add_row(row_data); + + // Free row data + for (size_t i = 0; i < columns.size(); i++) { + free(row_data[i]); + } + free(row_data); + + // Send resultset to client + SQLite3_to_MySQL(resultset.get(), NULL, 0, &client_myds->myprot, false, + (client_myds->myconn->options.client_flag & CLIENT_DEPRECATE_EOF)); + + l_free(pkt->size, pkt->ptr); + client_myds->DSS = STATE_SLEEP; + status = WAITING_CLIENT_DATA; + + proxy_debug(PROXY_DEBUG_NL2SQL, 2, "NL2SQL: Converted '%s' to SQL (confidence: %.2f)\n", + req.natural_language.c_str(), result.confidence); +} + #ifdef epoll_create1 /** * @brief Send GenAI request asynchronously via socketpair @@ -6759,6 +6862,13 @@ bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai(query_ptr + 6, query_len - 6, pkt); return true; } + + // Check for NL2SQL: queries - Natural Language to SQL conversion + if (query_len >= 8 && strncasecmp(query_ptr, "NL2SQL:", 7) == 0) { + // This is a NL2SQL: query - handle with NL2SQL converter + handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___nl2sql(query_ptr + 7, query_len - 7, pkt); + return true; + } } if (qpo->new_query) {