Add experimental GenAI RERANK: query support for MySQL

This commit adds experimental support for reranking documents directly
from MySQL queries using a special RERANK: syntax.

Changes:
- Add handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai_rerank()
- Add RERANK: query detection alongside EMBED: detection
- Implement JSON parsing for query, documents array, and optional top_n
- Build resultset with index, score, and document columns
- Use MySQL ERR_Packet for error handling

Query format: RERANK: {"query": "search query", "documents": ["doc1", "doc2", ...], "top_n": 5}
Result format: 1 row per result, 3 columns (index, score, document)
pull/5310/head
Rene Cannao 1 month ago
parent 253591d262
commit 39939f598b

@ -284,6 +284,7 @@ class MySQL_Session: public Base_Session<MySQL_Session, MySQL_Data_Stream, MySQL
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_INIT_DB_replace_CLICKHOUSE(PtrSize_t& pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___not_mysql(PtrSize_t& pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai_embedding(const char* query, size_t query_len, PtrSize_t* pkt);
void handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai_rerank(const char* query, size_t query_len, PtrSize_t* pkt);
bool handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_detect_SQLi();
bool handler___status_WAITING_CLIENT_DATA___STATE_SLEEP_MULTI_PACKET(PtrSize_t& pkt);
bool handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM__various(PtrSize_t* pkt, bool* wrong_pass);

@ -3736,6 +3736,172 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
}
}
// Handler for RERANK: queries - experimental GenAI integration
// Query format: RERANK: {"query": "search query", "documents": ["doc1", "doc2", ...], "top_n": 5}
// Returns: Resultset with reranked documents (index, score, document)
void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai_rerank(const char* query, size_t query_len, PtrSize_t* pkt) {
// Skip leading space after "RERANK:"
while (query_len > 0 && (*query == ' ' || *query == '\t')) {
query++;
query_len--;
}
if (query_len == 0) {
// Empty query after RERANK:
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2234, (char*)"HY000", "Empty RERANK: query", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Parse JSON object with query, documents, and optional top_n
try {
json j = json::parse(std::string(query, query_len));
if (!j.is_object()) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2235, (char*)"HY000", "RERANK: query requires a JSON object with 'query' and 'documents' fields", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Extract query
if (!j.contains("query") || !j["query"].is_string()) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2236, (char*)"HY000", "RERANK: query requires a 'query' string field", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
std::string query_str = j["query"].get<std::string>();
if (query_str.empty()) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2237, (char*)"HY000", "RERANK: query field cannot be empty", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Extract documents
if (!j.contains("documents") || !j["documents"].is_array()) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2238, (char*)"HY000", "RERANK: query requires a 'documents' array field", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
std::vector<std::string> documents;
for (const auto& doc : j["documents"]) {
if (doc.is_string()) {
documents.push_back(doc.get<std::string>());
} else {
documents.push_back(doc.dump());
}
}
if (documents.empty()) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2239, (char*)"HY000", "RERANK: documents array cannot be empty", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Extract optional top_n
uint32_t top_n = 0; // 0 means return all
if (j.contains("top_n") && j["top_n"].is_number()) {
top_n = j["top_n"].get<uint32_t>();
}
// Call GenAI module to rerank documents
// Note: This is a synchronous call for the experimental implementation
// TODO: Make this asynchronous using socketpair
if (!GloGATH) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2240, (char*)"HY000", "GenAI module is not initialized", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
GenAI_RerankResultArray result = GloGATH->rerank_documents(query_str, documents, top_n);
if (!result.data || result.count == 0) {
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2241, (char*)"HY000", "Failed to rerank documents", true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
return;
}
// Build resultset: 1 row per result, 3 columns (index, score, document)
std::unique_ptr<SQLite3_result> resultset(new SQLite3_result(3));
resultset->add_column_definition(SQLITE_TEXT, "index");
resultset->add_column_definition(SQLITE_TEXT, "score");
resultset->add_column_definition(SQLITE_TEXT, "document");
for (size_t i = 0; i < result.count; i++) {
const GenAI_RerankResult& r = result.data[i];
// Convert values to strings
std::string index_str = std::to_string(r.index);
std::string score_str = std::to_string(r.score);
const std::string& doc_str = documents[r.index];
// Add row to resultset
char* row_data[3];
char* index_copy = strdup(index_str.c_str());
char* score_copy = strdup(score_str.c_str());
char* doc_copy = strdup(doc_str.c_str());
row_data[0] = index_copy;
row_data[1] = score_copy;
row_data[2] = doc_copy;
resultset->add_row(row_data);
free(index_copy);
free(score_copy);
free(doc_copy);
}
// Send resultset to client
SQLite3_to_MySQL(resultset.get(), NULL, 0, &client_myds->myprot, false,
(client_myds->myconn->options.client_flag & CLIENT_DEPRECATE_EOF));
// Clean up
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
} catch (const json::parse_error& e) {
std::string err_msg = "JSON parse error in RERANK: query: ";
err_msg += e.what();
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2242, (char*)"HY000", err_msg.c_str(), true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
} catch (const std::exception& e) {
std::string err_msg = "Error processing RERANK: query: ";
err_msg += e.what();
client_myds->DSS = STATE_QUERY_SENT_NET;
client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, 1, 2243, (char*)"HY000", err_msg.c_str(), true);
l_free(pkt->size, pkt->ptr);
client_myds->DSS = STATE_SLEEP;
status = WAITING_CLIENT_DATA;
}
}
// this function was inline inside MySQL_Session::get_pkts_from_client
// where:
// status = WAITING_CLIENT_DATA
@ -6205,6 +6371,12 @@ bool MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C
handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai_embedding(query_ptr + 6, query_len - 6, pkt);
return true;
}
if (query_len >= 8 && strncasecmp(query_ptr, "RERANK:", 7) == 0) {
// This is a RERANK: query - handle with GenAI module
handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY___genai_rerank(query_ptr + 7, query_len - 7, pkt);
return true;
}
}
if (qpo->new_query) {

Loading…
Cancel
Save