You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
proxysql/lib/AI_Features_Manager.cpp

423 lines
12 KiB

#include "AI_Features_Manager.h"
#include "NL2SQL_Converter.h"
#include "Anomaly_Detector.h"
#include "sqlite3db.h"
#include "proxysql_utils.h"
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#include <sys/stat.h>
#include <libgen.h> // for dirname
// ProxySQL_Admin is only referenced through a pointer here, so a forward
// declaration is enough and avoids header ordering issues.
class ProxySQL_Admin;

// Process-wide singletons referenced by this translation unit.
// GloAI (the AI features manager) is defined in src/main.cpp.
extern AI_Features_Manager *GloAI;
extern ProxySQL_Admin *GloAdmin;
/**
 * Build a manager with every AI feature switched off and all tunables at their
 * compiled-in defaults. String settings are heap-allocated (strdup) and owned by
 * this object; the optional cloud API keys start unset (null).
 */
AI_Features_Manager::AI_Features_Manager()
	: shutdown_(0), nl2sql_converter(nullptr), anomaly_detector(nullptr), vector_db(nullptr)
{
	pthread_rwlock_init(&rwlock, nullptr);

	auto& cfg = variables;

	// Master switches.
	cfg.ai_features_enabled = false;
	cfg.ai_nl2sql_enabled = false;
	cfg.ai_anomaly_detection_enabled = false;

	// NL2SQL: trigger prefix, provider and per-provider model names.
	cfg.ai_nl2sql_query_prefix = strdup("NL2SQL:");
	cfg.ai_nl2sql_model_provider = strdup("ollama");
	cfg.ai_nl2sql_ollama_model = strdup("llama3.2");
	cfg.ai_nl2sql_openai_model = strdup("gpt-4o-mini");
	cfg.ai_nl2sql_anthropic_model = strdup("claude-3-haiku");
	cfg.ai_nl2sql_cache_similarity_threshold = 85;
	cfg.ai_nl2sql_timeout_ms = 30000;
	cfg.ai_nl2sql_openai_key = nullptr;
	cfg.ai_nl2sql_anthropic_key = nullptr;

	// Anomaly detection.
	cfg.ai_anomaly_risk_threshold = 70;
	cfg.ai_anomaly_similarity_threshold = 80;
	cfg.ai_anomaly_rate_limit = 100;
	cfg.ai_anomaly_auto_block = true;
	cfg.ai_anomaly_log_only = false;

	// Model routing and cloud cost limits.
	cfg.ai_prefer_local_models = true;
	cfg.ai_daily_budget_usd = 10.0;
	cfg.ai_max_cloud_requests_per_hour = 100;

	// Vector storage (dimension matches OpenAI text-embedding-3-small).
	cfg.ai_vector_db_path = strdup("/var/lib/proxysql/ai_features.db");
	cfg.ai_vector_dimension = 1536;

	// All status counters start at zero.
	memset(&status_variables, 0, sizeof(status_variables));
}
/**
 * Tear down all sub-components, release every owned configuration string and
 * destroy the lock. free() of a null pointer is a no-op, so unset API keys need
 * no special handling.
 */
AI_Features_Manager::~AI_Features_Manager() {
	shutdown();

	char* const owned_strings[] = {
		variables.ai_nl2sql_query_prefix,
		variables.ai_nl2sql_model_provider,
		variables.ai_nl2sql_ollama_model,
		variables.ai_nl2sql_openai_model,
		variables.ai_nl2sql_anthropic_model,
		variables.ai_nl2sql_openai_key,
		variables.ai_nl2sql_anthropic_key,
		variables.ai_vector_db_path,
	};
	for (char* owned : owned_strings) {
		free(owned);
	}

	pthread_rwlock_destroy(&rwlock);
}
int AI_Features_Manager::init_vector_db() {
proxy_info("AI: Initializing vector storage at %s\n", variables.ai_vector_db_path);
// Ensure directory exists
char* path_copy = strdup(variables.ai_vector_db_path);
char* dir = dirname(path_copy);
struct stat st;
if (stat(dir, &st) != 0) {
// Create directory if it doesn't exist
char cmd[512];
snprintf(cmd, sizeof(cmd), "mkdir -p %s", dir);
system(cmd);
}
free(path_copy);
vector_db = new SQLite3DB();
char path_buf[512];
strncpy(path_buf, variables.ai_vector_db_path, sizeof(path_buf) - 1);
path_buf[sizeof(path_buf) - 1] = '\0';
int rc = vector_db->open(path_buf, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE);
if (rc != SQLITE_OK) {
proxy_error("AI: Failed to open vector database: %s\n", variables.ai_vector_db_path);
delete vector_db;
vector_db = NULL;
return -1;
}
// Create tables for NL2SQL cache
const char* create_nl2sql_cache =
"CREATE TABLE IF NOT EXISTS nl2sql_cache ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"natural_language TEXT NOT NULL,"
"generated_sql TEXT NOT NULL,"
"schema_context TEXT,"
"embedding BLOB,"
"hit_count INTEGER DEFAULT 0,"
"last_hit INTEGER,"
"created_at INTEGER DEFAULT (strftime('%s', 'now'))"
");";
if (vector_db->execute(create_nl2sql_cache) != 0) {
proxy_error("AI: Failed to create nl2sql_cache table\n");
return -1;
}
// Create table for anomaly patterns
const char* create_anomaly_patterns =
"CREATE TABLE IF NOT EXISTS anomaly_patterns ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"pattern_name TEXT,"
"pattern_type TEXT," // 'sql_injection', 'dos', 'privilege_escalation'
"query_example TEXT,"
"embedding BLOB,"
"severity INTEGER," // 1-10
"created_at INTEGER DEFAULT (strftime('%s', 'now'))"
");";
if (vector_db->execute(create_anomaly_patterns) != 0) {
proxy_error("AI: Failed to create anomaly_patterns table\n");
return -1;
}
// Create table for query history
const char* create_query_history =
"CREATE TABLE IF NOT EXISTS query_history ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"query_text TEXT NOT NULL,"
"generated_sql TEXT,"
"embedding BLOB,"
"execution_time_ms INTEGER,"
"success BOOLEAN,"
"timestamp INTEGER DEFAULT (strftime('%s', 'now'))"
");";
if (vector_db->execute(create_query_history) != 0) {
proxy_error("AI: Failed to create query_history table\n");
return -1;
}
proxy_info("AI: Vector storage initialized successfully\n");
return 0;
}
/**
 * Create and initialise the NL2SQL converter when ai_nl2sql_enabled is set.
 * @return 0 if disabled or initialised. Returns -1 if the converter's init()
 *         fails; the converter is then freed and nl2sql_converter reset to NULL.
 */
int AI_Features_Manager::init_nl2sql() {
	if (variables.ai_nl2sql_enabled) {
		proxy_info("AI: Initializing NL2SQL Converter\n");
		nl2sql_converter = new NL2SQL_Converter();
		const bool ready = (nl2sql_converter->init() == 0);
		if (!ready) {
			proxy_error("AI: Failed to initialize NL2SQL Converter\n");
			delete nl2sql_converter;
			nl2sql_converter = NULL;
			return -1;
		}
		proxy_info("AI: NL2SQL Converter initialized\n");
		return 0;
	}
	proxy_info("AI: NL2SQL disabled, skipping initialization\n");
	return 0;
}
/**
 * Create and initialise the anomaly detector when ai_anomaly_detection_enabled
 * is set.
 * @return 0 if disabled or initialised. Returns -1 if the detector's init()
 *         fails; the detector is then freed and anomaly_detector reset to NULL.
 */
int AI_Features_Manager::init_anomaly_detector() {
	if (variables.ai_anomaly_detection_enabled) {
		proxy_info("AI: Initializing Anomaly Detector\n");
		anomaly_detector = new Anomaly_Detector();
		const bool ready = (anomaly_detector->init() == 0);
		if (!ready) {
			proxy_error("AI: Failed to initialize Anomaly Detector\n");
			delete anomaly_detector;
			anomaly_detector = NULL;
			return -1;
		}
		proxy_info("AI: Anomaly Detector initialized\n");
		return 0;
	}
	proxy_info("AI: Anomaly detection disabled, skipping initialization\n");
	return 0;
}
// Release the vector database (if open) and clear the pointer.
// Deleting a null pointer is a no-op, so no guard is needed.
void AI_Features_Manager::close_vector_db() {
	delete vector_db;
	vector_db = NULL;
}
// Close and free the NL2SQL converter if one exists; safe to call repeatedly.
void AI_Features_Manager::close_nl2sql() {
	if (nl2sql_converter == NULL) {
		return;
	}
	nl2sql_converter->close();
	delete nl2sql_converter;
	nl2sql_converter = NULL;
}
// Close and free the anomaly detector if one exists; safe to call repeatedly.
void AI_Features_Manager::close_anomaly_detector() {
	if (anomaly_detector == NULL) {
		return;
	}
	anomaly_detector->close();
	delete anomaly_detector;
	anomaly_detector = NULL;
}
/**
 * Bring up the AI subsystems in dependency order: vector storage first (shared
 * by NL2SQL and anomaly detection), then NL2SQL, then the anomaly detector.
 * @return 0 on success or when AI features are disabled; -1 at the first failing step.
 */
int AI_Features_Manager::init() {
	proxy_info("AI: Initializing AI Features Manager v%s\n", AI_FEATURES_MANAGER_VERSION);

	if (variables.ai_features_enabled == false) {
		proxy_info("AI: AI features disabled by configuration\n");
		return 0;
	}

	if (init_vector_db()) {
		proxy_error("AI: Failed to initialize vector storage\n");
		return -1;
	}
	if (init_nl2sql()) {
		proxy_error("AI: Failed to initialize NL2SQL\n");
		return -1;
	}
	if (init_anomaly_detector()) {
		proxy_error("AI: Failed to initialize Anomaly Detector\n");
		return -1;
	}

	proxy_info("AI: AI Features Manager initialized successfully\n");
	return 0;
}
// Close all subsystems once. Later calls return immediately because of the
// shutdown_ flag. Storage is closed last, after its users.
void AI_Features_Manager::shutdown() {
	if (!shutdown_) {
		shutdown_ = 1;
		proxy_info("AI: Shutting down AI Features Manager\n");
		close_nl2sql();
		close_anomaly_detector();
		close_vector_db();
		proxy_info("AI: AI Features Manager shutdown complete\n");
	}
}
// Acquire the configuration rwlock exclusively (return code is not checked).
void AI_Features_Manager::wrlock() {
	(void)pthread_rwlock_wrlock(&rwlock);
}
// Release the configuration rwlock taken by wrlock() (return code is not checked).
void AI_Features_Manager::wrunlock() {
	(void)pthread_rwlock_unlock(&rwlock);
}
// Heap copy of the "true"/"false" spelling used for boolean settings.
static char* ai_bool_cstr(bool v) {
	return strdup(v ? "true" : "false");
}

// Heap copy of a numeric setting's decimal text. std::to_string() produces the
// same digits as "%d" for int and adapts to whatever width the header declares.
template <typename T>
static char* ai_num_cstr(T v) {
	return strdup(std::to_string(v).c_str());
}

/**
 * Look up a configuration variable by name.
 *
 * Covers every name returned by get_variables_list(). Before this change, ten
 * of those names returned NULL here.
 *
 * @param name variable name; NULL is treated as unknown.
 * @return a malloc'd string the caller must free(), or NULL for unknown names.
 */
char* AI_Features_Manager::get_variable(const char* name) {
	if (name == NULL) {
		return NULL;
	}

	// Master switches
	if (strcmp(name, "ai_features_enabled") == 0)
		return ai_bool_cstr(variables.ai_features_enabled);
	if (strcmp(name, "ai_nl2sql_enabled") == 0)
		return ai_bool_cstr(variables.ai_nl2sql_enabled);
	if (strcmp(name, "ai_anomaly_detection_enabled") == 0)
		return ai_bool_cstr(variables.ai_anomaly_detection_enabled);

	// NL2SQL
	if (strcmp(name, "ai_nl2sql_query_prefix") == 0)
		return strdup(variables.ai_nl2sql_query_prefix);
	if (strcmp(name, "ai_nl2sql_model_provider") == 0)
		return strdup(variables.ai_nl2sql_model_provider);
	if (strcmp(name, "ai_nl2sql_ollama_model") == 0)
		return strdup(variables.ai_nl2sql_ollama_model);
	if (strcmp(name, "ai_nl2sql_openai_model") == 0)
		return strdup(variables.ai_nl2sql_openai_model);
	if (strcmp(name, "ai_nl2sql_anthropic_model") == 0)
		return strdup(variables.ai_nl2sql_anthropic_model);
	if (strcmp(name, "ai_nl2sql_cache_similarity_threshold") == 0)
		return ai_num_cstr(variables.ai_nl2sql_cache_similarity_threshold);
	if (strcmp(name, "ai_nl2sql_timeout_ms") == 0)
		return ai_num_cstr(variables.ai_nl2sql_timeout_ms);

	// Anomaly detection
	if (strcmp(name, "ai_anomaly_risk_threshold") == 0)
		return ai_num_cstr(variables.ai_anomaly_risk_threshold);
	if (strcmp(name, "ai_anomaly_similarity_threshold") == 0)
		return ai_num_cstr(variables.ai_anomaly_similarity_threshold);
	if (strcmp(name, "ai_anomaly_rate_limit") == 0)
		return ai_num_cstr(variables.ai_anomaly_rate_limit);
	if (strcmp(name, "ai_anomaly_auto_block") == 0)
		return ai_bool_cstr(variables.ai_anomaly_auto_block);
	if (strcmp(name, "ai_anomaly_log_only") == 0)
		return ai_bool_cstr(variables.ai_anomaly_log_only);

	// Model routing and cost limits
	if (strcmp(name, "ai_prefer_local_models") == 0)
		return ai_bool_cstr(variables.ai_prefer_local_models);
	if (strcmp(name, "ai_daily_budget_usd") == 0) {
		// Same two-decimal USD formatting as get_status_json().
		char buf[64];
		snprintf(buf, sizeof(buf), "%.2f", static_cast<double>(variables.ai_daily_budget_usd));
		return strdup(buf);
	}
	if (strcmp(name, "ai_max_cloud_requests_per_hour") == 0)
		return ai_num_cstr(variables.ai_max_cloud_requests_per_hour);

	// Vector storage
	if (strcmp(name, "ai_vector_db_path") == 0)
		return strdup(variables.ai_vector_db_path);
	if (strcmp(name, "ai_vector_dimension") == 0)
		return ai_num_cstr(variables.ai_vector_dimension);

	return NULL;
}
// Replace an owned configuration string with a heap copy of `value`.
static void ai_replace_cstr(char*& slot, const char* value) {
	free(slot);
	slot = strdup(value);
}

// Admin boolean spelling: only the exact text "true" means enabled.
static bool ai_parse_bool(const char* value) {
	return strcmp(value, "true") == 0;
}

/**
 * Assign a configuration variable from its textual value, under the write lock.
 *
 * Covers every name returned by get_variables_list(). Before this change, ten
 * of those names were silently ignored. Numeric values are parsed with
 * atoi()/atof(), as in the existing branches.
 *
 * @param name  variable name; NULL is rejected.
 * @param value textual value; NULL is rejected.
 * @return false for NULL arguments, unknown names, or a master switch set to
 *         its current value. Returns true otherwise.
 */
bool AI_Features_Manager::set_variable(const char* name, const char* value) {
	// Every branch dereferences both arguments; refuse NULLs up front instead of
	// crashing in strcmp()/strdup().
	if (name == NULL || value == NULL) {
		return false;
	}

	wrlock();
	bool changed = false;
	if (strcmp(name, "ai_features_enabled") == 0) {
		const bool new_val = ai_parse_bool(value);
		changed = (variables.ai_features_enabled != new_val);
		variables.ai_features_enabled = new_val;
	}
	else if (strcmp(name, "ai_nl2sql_enabled") == 0) {
		const bool new_val = ai_parse_bool(value);
		changed = (variables.ai_nl2sql_enabled != new_val);
		variables.ai_nl2sql_enabled = new_val;
	}
	else if (strcmp(name, "ai_anomaly_detection_enabled") == 0) {
		const bool new_val = ai_parse_bool(value);
		changed = (variables.ai_anomaly_detection_enabled != new_val);
		variables.ai_anomaly_detection_enabled = new_val;
	}
	else if (strcmp(name, "ai_nl2sql_query_prefix") == 0) {
		ai_replace_cstr(variables.ai_nl2sql_query_prefix, value);
		changed = true;
	}
	else if (strcmp(name, "ai_nl2sql_model_provider") == 0) {
		ai_replace_cstr(variables.ai_nl2sql_model_provider, value);
		changed = true;
	}
	else if (strcmp(name, "ai_nl2sql_ollama_model") == 0) {
		ai_replace_cstr(variables.ai_nl2sql_ollama_model, value);
		changed = true;
	}
	else if (strcmp(name, "ai_nl2sql_openai_model") == 0) {
		ai_replace_cstr(variables.ai_nl2sql_openai_model, value);
		changed = true;
	}
	else if (strcmp(name, "ai_nl2sql_anthropic_model") == 0) {
		ai_replace_cstr(variables.ai_nl2sql_anthropic_model, value);
		changed = true;
	}
	else if (strcmp(name, "ai_nl2sql_cache_similarity_threshold") == 0) {
		variables.ai_nl2sql_cache_similarity_threshold = atoi(value);
		changed = true;
	}
	else if (strcmp(name, "ai_nl2sql_timeout_ms") == 0) {
		variables.ai_nl2sql_timeout_ms = atoi(value);
		changed = true;
	}
	else if (strcmp(name, "ai_anomaly_risk_threshold") == 0) {
		variables.ai_anomaly_risk_threshold = atoi(value);
		changed = true;
	}
	else if (strcmp(name, "ai_anomaly_similarity_threshold") == 0) {
		variables.ai_anomaly_similarity_threshold = atoi(value);
		changed = true;
	}
	else if (strcmp(name, "ai_anomaly_rate_limit") == 0) {
		variables.ai_anomaly_rate_limit = atoi(value);
		changed = true;
	}
	else if (strcmp(name, "ai_anomaly_auto_block") == 0) {
		variables.ai_anomaly_auto_block = ai_parse_bool(value);
		changed = true;
	}
	else if (strcmp(name, "ai_anomaly_log_only") == 0) {
		variables.ai_anomaly_log_only = ai_parse_bool(value);
		changed = true;
	}
	else if (strcmp(name, "ai_prefer_local_models") == 0) {
		variables.ai_prefer_local_models = ai_parse_bool(value);
		changed = true;
	}
	else if (strcmp(name, "ai_daily_budget_usd") == 0) {
		variables.ai_daily_budget_usd = atof(value);
		changed = true;
	}
	else if (strcmp(name, "ai_max_cloud_requests_per_hour") == 0) {
		variables.ai_max_cloud_requests_per_hour = atoi(value);
		changed = true;
	}
	else if (strcmp(name, "ai_vector_db_path") == 0) {
		ai_replace_cstr(variables.ai_vector_db_path, value);
		changed = true;
	}
	else if (strcmp(name, "ai_vector_dimension") == 0) {
		variables.ai_vector_dimension = atoi(value);
		changed = true;
	}
	wrunlock();
	return changed;
}
/**
 * List the names of all configuration variables.
 *
 * The array length is derived from the name table, so adding or removing a
 * name can no longer overflow the result. The old code hard-coded 21 and
 * result[20].
 *
 * @return a malloc'd NULL-terminated array of malloc'd strings (caller frees
 *         each entry and the array), or NULL if allocation fails.
 */
char** AI_Features_Manager::get_variables_list() {
	static const char* const vars[] = {
		"ai_features_enabled",
		"ai_nl2sql_enabled",
		"ai_anomaly_detection_enabled",
		"ai_nl2sql_query_prefix",
		"ai_nl2sql_model_provider",
		"ai_nl2sql_ollama_model",
		"ai_nl2sql_openai_model",
		"ai_nl2sql_anthropic_model",
		"ai_nl2sql_cache_similarity_threshold",
		"ai_nl2sql_timeout_ms",
		"ai_anomaly_risk_threshold",
		"ai_anomaly_similarity_threshold",
		"ai_anomaly_rate_limit",
		"ai_anomaly_auto_block",
		"ai_anomaly_log_only",
		"ai_prefer_local_models",
		"ai_daily_budget_usd",
		"ai_max_cloud_requests_per_hour",
		"ai_vector_db_path",
		"ai_vector_dimension",
	};
	const size_t count = sizeof(vars) / sizeof(vars[0]);

	// One extra slot for the NULL terminator.
	char** result = static_cast<char**>(malloc(sizeof(char*) * (count + 1)));
	if (result == NULL) {
		return NULL;
	}
	for (size_t i = 0; i < count; ++i) {
		result[i] = strdup(vars[i]);
	}
	result[count] = NULL;
	return result;
}
/**
 * Render the status counters and daily cloud spend as a compact JSON object.
 *
 * Each counter is cast to unsigned long long because that is the exact type
 * "%llu" requires. The counters' real type is declared in the header, and
 * uint64_t is `unsigned long` on LP64 Linux, so passing them uncast is a
 * mismatched vararg.
 */
std::string AI_Features_Manager::get_status_json() {
	char buf[1024];
	snprintf(buf, sizeof(buf),
		"{"
		"\"version\": \"%s\","
		"\"nl2sql\": {"
		"\"total_requests\": %llu,"
		"\"cache_hits\": %llu,"
		"\"local_calls\": %llu,"
		"\"cloud_calls\": %llu"
		"},"
		"\"anomaly\": {"
		"\"total_checks\": %llu,"
		"\"blocked\": %llu,"
		"\"flagged\": %llu"
		"},"
		"\"spend\": {"
		"\"daily_usd\": %.2f"
		"}"
		"}",
		AI_FEATURES_MANAGER_VERSION,
		static_cast<unsigned long long>(status_variables.nl2sql_total_requests),
		static_cast<unsigned long long>(status_variables.nl2sql_cache_hits),
		static_cast<unsigned long long>(status_variables.nl2sql_local_model_calls),
		static_cast<unsigned long long>(status_variables.nl2sql_cloud_model_calls),
		static_cast<unsigned long long>(status_variables.anomaly_total_checks),
		static_cast<unsigned long long>(status_variables.anomaly_blocked_queries),
		static_cast<unsigned long long>(status_variables.anomaly_flagged_queries),
		static_cast<double>(status_variables.daily_cloud_spend_usd)
	);
	return std::string(buf);
}