Fix critical issues identified by gemini-code-assist

This commit addresses critical issues identified in the code review:

1. Fix non-blocking read handling:
   - lib/GenAI_Thread.cpp (listener_loop): Properly handle EAGAIN/EWOULDBLOCK
     - Return early on EAGAIN/EWOULDBLOCK instead of closing connection
     - Handle EOF (n==0) separately from errors (n<0)
   - lib/MySQL_Session.cpp (handle_genai_response): Properly handle EAGAIN/EWOULDBLOCK
     - Return early on EAGAIN/EWOULDBLOCK instead of cleaning up request
     - Use goto for cleaner control flow

2. Refactor JSON building/parsing to use nlohmann/json:
   - lib/GenAI_Thread.cpp (call_llama_batch_embedding):
     - Replace manual stringstream JSON building with nlohmann/json
     - Replace fragile string-based parsing with nlohmann/json::parse()
     - Support multiple response formats (results, data, embeddings)
     - Add proper error handling with try/catch
   - lib/GenAI_Thread.cpp (call_llama_rerank):
     - Replace manual stringstream JSON building with nlohmann/json
     - Replace fragile string-based parsing with nlohmann/json::parse()
     - Support multiple response formats and field names
     - Add proper error handling with try/catch

These changes:
- Fix potential connection drops due to incorrect EAGAIN handling
- Improve security and robustness of JSON handling
- Reduce code complexity and improve maintainability
- Add support for multiple API response formats
pull/5310/head
Rene Cannao 3 months ago
parent b77d38c2ca
commit 33a87c66a7

@ -521,32 +521,10 @@ GenAI_EmbeddingResult GenAI_Threads_Handler::call_llama_batch_embedding(const st
return result;
}
// Build JSON request
std::stringstream json;
json << "{\"input\":[";
for (size_t i = 0; i < texts.size(); i++) {
if (i > 0) json << ",";
json << "\"";
// Escape JSON special characters
for (char c : texts[i]) {
switch (c) {
case '"': json << "\\\""; break;
case '\\': json << "\\\\"; break;
case '\n': json << "\\n"; break;
case '\r': json << "\\r"; break;
case '\t': json << "\\t"; break;
default: json << c; break;
}
}
json << "\"";
}
json << "]}";
std::string json_str = json.str();
// Build JSON request using nlohmann/json
json payload;
payload["input"] = texts;
std::string json_str = payload.dump();
// Configure curl
curl_easy_setopt(curl, CURLOPT_URL, variables.genai_embedding_uri);
@ -571,80 +549,57 @@ GenAI_EmbeddingResult GenAI_Threads_Handler::call_llama_batch_embedding(const st
proxy_error("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
status_variables.failed_requests++;
} else {
// Parse JSON response to extract embeddings
std::vector<std::vector<float>> all_embeddings;
size_t pos = 0;
while ((pos = response_data.find("\"embedding\":", pos)) != std::string::npos) {
size_t array_start = response_data.find("[", pos);
if (array_start == std::string::npos) break;
size_t inner_start = array_start + 1;
if (inner_start >= response_data.size() || response_data[inner_start] != '[') {
inner_start = array_start;
}
size_t array_end = inner_start;
int bracket_count = 0;
bool in_array = false;
for (size_t i = inner_start; i < response_data.size(); i++) {
if (response_data[i] == '[') {
bracket_count++;
in_array = true;
} else if (response_data[i] == ']') {
bracket_count--;
if (bracket_count == 0 && in_array) {
array_end = i;
break;
// Parse JSON response using nlohmann/json
try {
json response_json = json::parse(response_data);
std::vector<std::vector<float>> all_embeddings;
// Handle different response formats
if (response_json.contains("results") && response_json["results"].is_array()) {
// Format: {"results": [{"embedding": [...]}, ...]}
for (const auto& result_item : response_json["results"]) {
if (result_item.contains("embedding") && result_item["embedding"].is_array()) {
std::vector<float> embedding = result_item["embedding"].get<std::vector<float>>();
all_embeddings.push_back(std::move(embedding));
}
}
}
std::string array_str = response_data.substr(inner_start + 1, array_end - inner_start - 1);
std::vector<float> embedding;
std::stringstream ss(array_str);
std::string token;
while (std::getline(ss, token, ',')) {
token.erase(0, token.find_first_not_of(" \t\n\r"));
token.erase(token.find_last_not_of(" \t\n\r") + 1);
if (token == "null" || token.empty()) {
continue;
}
try {
float val = std::stof(token);
embedding.push_back(val);
} catch (...) {
// Skip invalid values
} else if (response_json.contains("data") && response_json["data"].is_array()) {
// Format: {"data": [{"embedding": [...]}]}
for (const auto& item : response_json["data"]) {
if (item.contains("embedding") && item["embedding"].is_array()) {
std::vector<float> embedding = item["embedding"].get<std::vector<float>>();
all_embeddings.push_back(std::move(embedding));
}
}
} else if (response_json.contains("embeddings") && response_json["embeddings"].is_array()) {
// Format: {"embeddings": [[...], ...]}
all_embeddings = response_json["embeddings"].get<std::vector<std::vector<float>>>();
}
if (!embedding.empty()) {
all_embeddings.push_back(std::move(embedding));
}
pos = array_end + 1;
}
// Convert to contiguous array
if (!all_embeddings.empty()) {
result.count = all_embeddings.size();
result.embedding_size = all_embeddings[0].size();
// Convert to contiguous array
if (!all_embeddings.empty()) {
result.count = all_embeddings.size();
result.embedding_size = all_embeddings[0].size();
size_t total_floats = result.embedding_size * result.count;
result.data = new float[total_floats];
size_t total_floats = result.embedding_size * result.count;
result.data = new float[total_floats];
for (size_t i = 0; i < all_embeddings.size(); i++) {
size_t offset = i * result.embedding_size;
const auto& emb = all_embeddings[i];
std::copy(emb.begin(), emb.end(), result.data + offset);
}
for (size_t i = 0; i < all_embeddings.size(); i++) {
size_t offset = i * result.embedding_size;
const auto& emb = all_embeddings[i];
std::copy(emb.begin(), emb.end(), result.data + offset);
status_variables.completed_requests++;
} else {
status_variables.failed_requests++;
}
status_variables.completed_requests++;
} else {
} catch (const json::parse_error& e) {
proxy_error("Failed to parse embedding response JSON: %s\n", e.what());
status_variables.failed_requests++;
} catch (const std::exception& e) {
proxy_error("Error processing embedding response: %s\n", e.what());
status_variables.failed_requests++;
}
}
@ -717,44 +672,11 @@ GenAI_RerankResultArray GenAI_Threads_Handler::call_llama_rerank(const std::stri
return result;
}
// Build JSON request
std::stringstream json;
json << "{\"query\":\"";
for (char c : query) {
switch (c) {
case '"': json << "\\\""; break;
case '\\': json << "\\\\"; break;
case '\n': json << "\\n"; break;
case '\r': json << "\\r"; break;
case '\t': json << "\\t"; break;
default: json << c; break;
}
}
json << "\",\"documents\":[";
for (size_t i = 0; i < texts.size(); i++) {
if (i > 0) json << ",";
json << "\"";
for (char c : texts[i]) {
switch (c) {
case '"': json << "\\\""; break;
case '\\': json << "\\\\"; break;
case '\n': json << "\\n"; break;
case '\r': json << "\\r"; break;
case '\t': json << "\\t"; break;
default: json << c; break;
}
}
json << "\"";
}
json << "]}";
std::string json_str = json.str();
// Build JSON request using nlohmann/json
json payload;
payload["query"] = query;
payload["documents"] = texts;
std::string json_str = payload.dump();
// Configure curl
curl_easy_setopt(curl, CURLOPT_URL, variables.genai_rerank_uri);
@ -776,100 +698,62 @@ GenAI_RerankResultArray GenAI_Threads_Handler::call_llama_rerank(const std::stri
proxy_error("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
status_variables.failed_requests++;
} else {
size_t results_pos = response_data.find("\"results\":");
if (results_pos != std::string::npos) {
size_t array_start = response_data.find("[", results_pos);
if (array_start != std::string::npos) {
size_t array_end = array_start;
int bracket_count = 0;
bool in_array = false;
for (size_t i = array_start; i < response_data.size(); i++) {
if (response_data[i] == '[') {
bracket_count++;
in_array = true;
} else if (response_data[i] == ']') {
bracket_count--;
if (bracket_count == 0 && in_array) {
array_end = i;
break;
}
// Parse JSON response using nlohmann/json
try {
json response_json = json::parse(response_data);
std::vector<GenAI_RerankResult> results;
// Handle different response formats
if (response_json.contains("results") && response_json["results"].is_array()) {
// Format: {"results": [{"index": 0, "relevance_score": 0.95}, ...]}
for (const auto& result_item : response_json["results"]) {
GenAI_RerankResult r;
r.index = result_item.value("index", 0);
// Support both "relevance_score" and "score" field names
if (result_item.contains("relevance_score")) {
r.score = result_item.value("relevance_score", 0.0f);
} else {
r.score = result_item.value("score", 0.0f);
}
results.push_back(r);
}
std::string array_str = response_data.substr(array_start + 1, array_end - array_start - 1);
std::vector<GenAI_RerankResult> results;
size_t pos = 0;
while (pos < array_str.size()) {
size_t index_pos = array_str.find("\"index\":", pos);
if (index_pos == std::string::npos) break;
size_t num_start = index_pos + 8;
while (num_start < array_str.size() &&
(array_str[num_start] == ' ' || array_str[num_start] == '\t')) {
num_start++;
}
size_t num_end = num_start;
while (num_end < array_str.size() &&
(isdigit(array_str[num_end]) || array_str[num_end] == '-')) {
num_end++;
}
uint32_t index = 0;
if (num_start < num_end) {
try {
index = std::stoul(array_str.substr(num_start, num_end - num_start));
} catch (...) {}
}
size_t score_pos = array_str.find("\"relevance_score\":", index_pos);
if (score_pos == std::string::npos) break;
size_t score_start = score_pos + 18;
while (score_start < array_str.size() &&
(array_str[score_start] == ' ' || array_str[score_start] == '\t')) {
score_start++;
}
size_t score_end = score_start;
while (score_end < array_str.size() &&
(isdigit(array_str[score_end]) ||
array_str[score_end] == '.' ||
array_str[score_end] == '-' ||
array_str[score_end] == 'e' ||
array_str[score_end] == 'E')) {
score_end++;
}
float score = 0.0f;
if (score_start < score_end) {
try {
score = std::stof(array_str.substr(score_start, score_end - score_start));
} catch (...) {}
} else if (response_json.contains("data") && response_json["data"].is_array()) {
// Alternative format: {"data": [...]}
for (const auto& result_item : response_json["data"]) {
GenAI_RerankResult r;
r.index = result_item.value("index", 0);
// Support both "relevance_score" and "score" field names
if (result_item.contains("relevance_score")) {
r.score = result_item.value("relevance_score", 0.0f);
} else {
r.score = result_item.value("score", 0.0f);
}
results.push_back({index, score});
pos = score_end + 1;
results.push_back(r);
}
}
if (!results.empty() && top_n > 0) {
size_t count = std::min(static_cast<size_t>(top_n), results.size());
result.count = count;
result.data = new GenAI_RerankResult[count];
std::copy(results.begin(), results.begin() + count, result.data);
} else {
result.count = results.size();
result.data = new GenAI_RerankResult[results.size()];
std::copy(results.begin(), results.end(), result.data);
}
// Apply top_n limit if specified
if (!results.empty() && top_n > 0 && top_n < results.size()) {
result.count = top_n;
result.data = new GenAI_RerankResult[top_n];
std::copy(results.begin(), results.begin() + top_n, result.data);
} else if (!results.empty()) {
result.count = results.size();
result.data = new GenAI_RerankResult[results.size()];
std::copy(results.begin(), results.end(), result.data);
}
if (!results.empty()) {
status_variables.completed_requests++;
} else {
status_variables.failed_requests++;
}
} else {
} catch (const json::parse_error& e) {
proxy_error("Failed to parse rerank response JSON: %s\n", e.what());
status_variables.failed_requests++;
} catch (const std::exception& e) {
proxy_error("Error processing rerank response: %s\n", e.what());
status_variables.failed_requests++;
}
}
@ -985,13 +869,25 @@ void GenAI_Threads_Handler::listener_loop() {
GenAI_RequestHeader header;
ssize_t n = read(client_fd, &header, sizeof(header));
if (n <= 0) {
// Client disconnected or error
if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
proxy_error("GenAI: Error reading from client fd %d: %s\n",
client_fd, strerror(errno));
if (n < 0) {
// Check for non-blocking read - not an error, just no data yet
if (errno == EAGAIN || errno == EWOULDBLOCK) {
continue;
}
// Remove from epoll
// Real error - log and close connection
proxy_error("GenAI: Error reading from client fd %d: %s\n",
client_fd, strerror(errno));
epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr);
close(client_fd);
{
std::lock_guard<std::mutex> lock(clients_mutex_);
client_fds_.erase(client_fd);
}
continue;
}
if (n == 0) {
// Client disconnected (EOF)
epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr);
close(client_fd);
{

@ -3972,22 +3972,32 @@ void MySQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___handle_
GenAI_ResponseHeader resp;
ssize_t n = read(fd, &resp, sizeof(resp));
if (n <= 0) {
// Connection closed or error
if (n < 0) {
proxy_error("GenAI: Error reading response header from fd %d: %s\n",
fd, strerror(errno));
}
// Find and cleanup the pending request
for (auto& pair : pending_genai_requests_) {
if (pair.second.client_fd == fd) {
genai_cleanup_request(pair.first);
break;
}
if (n < 0) {
// Check for non-blocking read - not an error, just no data yet
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return;
}
return;
// Real error - log and cleanup
proxy_error("GenAI: Error reading response header from fd %d: %s\n",
fd, strerror(errno));
} else if (n == 0) {
// Connection closed (EOF) - cleanup
} else {
// Successfully read header, continue processing
goto process_response;
}
// Cleanup path for error or EOF
for (auto& pair : pending_genai_requests_) {
if (pair.second.client_fd == fd) {
genai_cleanup_request(pair.first);
break;
}
}
return;
process_response:
if (n != sizeof(resp)) {
proxy_error("GenAI: Incomplete response header from fd %d: got %zd, expected %zu\n",
fd, n, sizeof(resp));

Loading…
Cancel
Save