# Vector Features API Reference

## Overview

This document describes the C++ API for Vector Features in ProxySQL, including the NL2SQL vector cache and Anomaly Detection embedding similarity.

## Table of Contents

- [NL2SQL_Converter API](#nl2sql_converter-api)
- [Anomaly_Detector API](#anomaly_detector-api)
- [Data Structures](#data-structures)
- [Error Handling](#error-handling)
- [Usage Examples](#usage-examples)

---

## NL2SQL_Converter API

### Class: NL2SQL_Converter

Location: `include/NL2SQL_Converter.h`

The NL2SQL_Converter class provides natural language to SQL conversion with vector-based semantic caching.

---

### Method: `get_query_embedding()`

Generate a vector embedding for a text query.

```cpp
std::vector<float> get_query_embedding(const std::string& text);
```

**Parameters:**
- `text`: The input text to generate an embedding for

**Returns:**
- `std::vector<float>`: 1536-dimensional embedding vector, or an empty vector on failure

**Description:**

Calls the GenAI module to generate a text embedding using llama-server. The embedding is a 1536-dimensional float array representing the semantic meaning of the text.

**Example:**

```cpp
NL2SQL_Converter* converter = GloAI->get_nl2sql();
std::vector<float> embedding = converter->get_query_embedding("Show all customers");

if (embedding.size() == 1536) {
    proxy_info("Generated embedding with %zu dimensions\n", embedding.size());
} else {
    proxy_error("Failed to generate embedding\n");
}
```

**Memory Management:**
- GenAI allocates embedding data with `malloc()`
- This method copies the data into a `std::vector<float>` and frees the original buffer
- The caller owns the returned vector

---

### Method: `check_vector_cache()`

Search for semantically similar queries in the vector cache.

```cpp
NL2SQLResult check_vector_cache(const NL2SQLRequest& req);
```

**Parameters:**
- `req`: NL2SQL request containing the natural language query

**Returns:**
- `NL2SQLResult`: Result with cached SQL if found, `cached=false` if not

**Description:**

Performs KNN search using cosine distance to find the most similar cached query. Returns cached SQL if similarity > threshold.

**Algorithm:**
1. Generate embedding for query text
2. Convert embedding to JSON for sqlite-vec MATCH clause
3. Calculate distance threshold from similarity threshold
4. Execute KNN search: `WHERE embedding MATCH '[...]' AND distance < threshold ORDER BY distance LIMIT 1`
5. Return cached result if found

**Distance Calculation:**

```cpp
float distance_threshold = 2.0f - (similarity_threshold / 50.0f);
// Example: similarity=85 → distance=0.3
```

**Example:**

```cpp
NL2SQLRequest req;
req.natural_language = "Display USA customers";
req.allow_cache = true;

NL2SQLResult result = converter->check_vector_cache(req);

if (result.cached) {
    proxy_info("Cache hit! Score: %.2f\n", result.confidence);
    // Use result.sql_query
} else {
    proxy_info("Cache miss, calling LLM\n");
}
```

---

### Method: `store_in_vector_cache()`

Store a NL2SQL conversion in the vector cache.

```cpp
void store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result);
```

**Parameters:**
- `req`: Original NL2SQL request
- `result`: NL2SQL conversion result to cache

**Description:**

Stores the conversion with its embedding for future similarity search. Updates both the main table and virtual vector table.

**Storage Process:**
1. Generate embedding for the natural language query
2. Insert into `nl2sql_cache` table with embedding BLOB
3. Get `rowid` from last insert
4. Insert `rowid` into `nl2sql_cache_vec` virtual table
5. Log cache entry

**Example:**

```cpp
NL2SQLRequest req;
req.natural_language = "Show all customers";

NL2SQLResult result;
result.sql_query = "SELECT * FROM customers";
result.confidence = 0.95f;

converter->store_in_vector_cache(req, result);
```

---

### Method: `convert()`

Convert natural language to SQL (main entry point).

```cpp
NL2SQLResult convert(const NL2SQLRequest& req);
```

**Parameters:**
- `req`: NL2SQL request with natural language query and context

**Returns:**
- `NL2SQLResult`: Generated SQL with confidence score and metadata

**Description:**

Complete conversion pipeline with vector caching:
1. Check vector cache for similar queries
2. If cache miss, build prompt with schema context
3. Select model provider (Ollama/OpenAI/Anthropic)
4. Call LLM API
5. Validate and clean SQL
6. Store result in vector cache

**Example:**

```cpp
NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.allow_cache = true;

NL2SQLResult result = converter->convert(req);

if (result.confidence > 0.7f) {
    execute_sql(result.sql_query);
    proxy_info("Generated by: %s\n", result.explanation.c_str());
}
```

---

### Method: `clear_cache()`

Clear the NL2SQL vector cache.

```cpp
void clear_cache();
```

**Description:**

Deletes all entries from both `nl2sql_cache` and `nl2sql_cache_vec` tables.

**Example:**

```cpp
converter->clear_cache();
proxy_info("NL2SQL cache cleared\n");
```

---

### Method: `get_cache_stats()`

Get cache statistics.

```cpp
std::string get_cache_stats();
```

**Returns:**
- `std::string`: JSON string with cache statistics

**Statistics Include:**
- Total entries
- Cache hits
- Cache misses
- Hit rate

**Example:**

```cpp
std::string stats = converter->get_cache_stats();
proxy_info("Cache stats: %s\n", stats.c_str());
// Output: {"entries": 150, "hits": 1200, "misses": 300, "hit_rate": 0.80}
```

---

## Anomaly_Detector API

### Class: Anomaly_Detector

Location: `include/Anomaly_Detector.h`

The Anomaly_Detector class provides SQL threat detection using embedding similarity.

---

### Method: `get_query_embedding()`

Generate a vector embedding for a SQL query.

```cpp
std::vector<float> get_query_embedding(const std::string& query);
```

**Parameters:**
- `query`: The SQL query to generate an embedding for

**Returns:**
- `std::vector<float>`: 1536-dimensional embedding vector, or an empty vector on failure

**Description:**

Normalizes the query (lowercase, remove extra whitespace) and generates an embedding via the GenAI module.

**Normalization Process:**
1. Convert to lowercase
2. Remove extra whitespace
3. Standardize SQL keywords
4. Generate embedding

**Example:**

```cpp
Anomaly_Detector* detector = GloAI->get_anomaly();
std::vector<float> embedding = detector->get_query_embedding(
    "SELECT * FROM users WHERE id = 1 OR 1=1--"
);

if (embedding.size() == 1536) {
    // Check similarity against threat patterns
}
```

---

### Method: `check_embedding_similarity()`

Check whether a query is similar to known threat patterns.

```cpp
AnomalyResult check_embedding_similarity(const std::string& query);
```

**Parameters:**
- `query`: The SQL query to check

**Returns:**
- `AnomalyResult`: Detection result with risk score and matched pattern

**Detection Algorithm:**
1. Normalize and generate embedding for query
2. KNN search against `anomaly_patterns_vec`
3. For each match within threshold:
   - Calculate risk score: `(severity / 10) * (1 - distance / 2)`
4. Return highest risk match

**Risk Score Formula:**

```cpp
risk_score = (severity / 10.0f) * (1.0f - (distance / 2.0f));
// severity: 1-10 from threat pattern
// distance: 0-2 from cosine distance
// risk_score: 0-1 (multiply by 100 for percentage)
```

**Example:**

```cpp
AnomalyResult result = detector->check_embedding_similarity(
    "SELECT * FROM users WHERE id = 5 OR 2=2--"
);

if (result.risk_score > 0.7f) {
    proxy_warning("High risk query detected! Score: %.2f\n", result.risk_score);
    proxy_warning("Matched pattern: %s\n", result.matched_pattern.c_str());
    // Block query
}

if (result.detected) {
    proxy_info("Threat type: %s\n", result.threat_type.c_str());
}
```

---

### Method: `add_threat_pattern()`

Add a new threat pattern to the database.

```cpp
bool add_threat_pattern(
    const std::string& pattern_name,
    const std::string& query_example,
    const std::string& pattern_type,
    int severity
);
```

**Parameters:**
- `pattern_name`: Human-readable name for the pattern
- `query_example`: Example SQL query representing this threat
- `pattern_type`: Type of threat (`sql_injection`, `dos`, `privilege_escalation`, etc.)
- `severity`: Severity level (1-10, where 10 is most severe)

**Returns:**
- `bool`: `true` if the pattern was added successfully, `false` on error

**Description:**

Stores the threat pattern with its embedding in both the `anomaly_patterns` and `anomaly_patterns_vec` tables.

**Storage Process:**
1. Generate embedding for query example
2. Insert into `anomaly_patterns` with embedding BLOB
3. Get `rowid` from last insert
4. Insert `rowid` into `anomaly_patterns_vec` virtual table

**Example:**

```cpp
bool success = detector->add_threat_pattern(
    "OR 1=1 Tautology",
    "SELECT * FROM users WHERE username='admin' OR 1=1--'",
    "sql_injection",
    9  // high severity
);

if (success) {
    proxy_info("Threat pattern added\n");
} else {
    proxy_error("Failed to add threat pattern\n");
}
```

---

### Method: `list_threat_patterns()`

List all threat patterns in the database.

```cpp
std::string list_threat_patterns();
```

**Returns:**
- `std::string`: JSON array of threat patterns

**JSON Format:**

```json
[
  {
    "id": 1,
    "pattern_name": "OR 1=1 Tautology",
    "pattern_type": "sql_injection",
    "query_example": "SELECT * FROM users WHERE username='admin' OR 1=1--'",
    "severity": 9,
    "created_at": 1705334400
  }
]
```

**Example:**

```cpp
std::string patterns_json = detector->list_threat_patterns();
proxy_info("Threat patterns:\n%s\n", patterns_json.c_str());

// Parse with nlohmann/json (assumes `using json = nlohmann::json;`)
json patterns = json::parse(patterns_json);
for (const auto& pattern : patterns) {
    proxy_info("- %s (severity: %d)\n",
               pattern["pattern_name"].get<std::string>().c_str(),
               pattern["severity"].get<int>());
}
```

---

### Method: `remove_threat_pattern()`

Remove a threat pattern from the database.

```cpp
bool remove_threat_pattern(int pattern_id);
```

**Parameters:**
- `pattern_id`: ID of the pattern to remove

**Returns:**
- `bool`: `true` if removed successfully, `false` on error

**Description:**

Deletes the pattern from both `anomaly_patterns_vec` (virtual table) and `anomaly_patterns` (main table).

**Example:**

```cpp
bool success = detector->remove_threat_pattern(5);

if (success) {
    proxy_info("Threat pattern 5 removed\n");
} else {
    proxy_error("Failed to remove pattern\n");
}
```

---

### Method: `get_statistics()`

Get anomaly detection statistics.

```cpp
std::string get_statistics();
```

**Returns:**
- `std::string`: JSON string with detailed statistics

**Statistics Include:**

```json
{
  "total_checks": 1500,
  "detected_anomalies": 45,
  "blocked_queries": 12,
  "flagged_queries": 33,
  "threat_patterns_count": 10,
  "threat_patterns_by_type": {
    "sql_injection": 6,
    "dos": 2,
    "privilege_escalation": 1,
    "data_exfiltration": 1
  }
}
```

**Example:**

```cpp
std::string stats = detector->get_statistics();
proxy_info("Anomaly stats: %s\n", stats.c_str());
```

---

## Data Structures

### NL2SQLRequest

```cpp
struct NL2SQLRequest {
    std::string natural_language;             // Input natural language query
    std::string schema_name;                  // Target schema name
    std::vector<std::string> context_tables;  // Relevant tables
    bool allow_cache;                         // Whether to check cache
    int max_latency_ms;                       // Max acceptable latency (0 = no limit)
};
```

### NL2SQLResult

```cpp
struct NL2SQLResult {
    std::string sql_query;    // Generated SQL query
    float confidence;         // Confidence score (0.0-1.0)
    std::string explanation;  // Which model was used
    bool cached;              // Whether from cache
};
```

### AnomalyResult

```cpp
struct AnomalyResult {
    bool detected;                // Whether anomaly was detected
    float risk_score;             // Risk score (0.0-1.0)
    std::string threat_type;      // Type of threat
    std::string matched_pattern;  // Name of matched pattern
    std::string action_taken;     // "blocked", "flagged", "allowed"
};
```

---

## Error Handling

### Return Values

- **bool functions**: Return `false` on error
- **vector functions**: Return an empty vector on error
- **string functions**: Return an empty string or a JSON error object

### Logging

Use ProxySQL logging macros:

```cpp
proxy_error("Failed to generate embedding: %s\n", error_msg);
proxy_warning("Low confidence result: %.2f\n", confidence);
proxy_info("Cache hit for query: %s\n", query.c_str());
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "Embedding generated with %zu dimensions\n", size);
```

### Error Checking Example

```cpp
std::vector<float> embedding = converter->get_query_embedding(text);

if (embedding.empty()) {
    proxy_error("Failed to generate embedding for: %s\n", text.c_str());
    // Handle error - return error or use fallback
    return error_result;
}

if (embedding.size() != 1536) {
    proxy_warning("Unexpected embedding size: %zu (expected 1536)\n", embedding.size());
    // May still work, but log warning
}
```

---

## Usage Examples

### Complete NL2SQL Conversion with Cache

```cpp
// Get converter
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (!converter) {
    proxy_error("NL2SQL converter not initialized\n");
    return;
}

// Prepare request
NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.context_tables = {"customers", "orders"};
req.allow_cache = true;
req.max_latency_ms = 0;  // No latency constraint

// Convert
NL2SQLResult result = converter->convert(req);

// Check result
if (result.confidence > 0.7f) {
    proxy_info("Generated SQL: %s\n", result.sql_query.c_str());
    proxy_info("Confidence: %.2f\n", result.confidence);
    proxy_info("Source: %s\n", result.explanation.c_str());

    if (result.cached) {
        proxy_info("Retrieved from semantic cache\n");
    }

    // Execute the SQL
    execute_sql(result.sql_query);
} else {
    proxy_warning("Low confidence conversion: %.2f\n", result.confidence);
}
```

### Complete Anomaly Detection Flow

```cpp
// Get detector
Anomaly_Detector* detector = GloAI->get_anomaly();
if (!detector) {
    proxy_error("Anomaly detector not initialized\n");
    return;
}

// Add threat pattern
detector->add_threat_pattern(
    "Sleep-based DoS",
    "SELECT * FROM users WHERE id=1 AND sleep(10)",
    "dos",
    6
);

// Check incoming query
std::string query = "SELECT * FROM users WHERE id=5 AND SLEEP(5)--";
AnomalyResult result = detector->check_embedding_similarity(query);

if (result.detected) {
    proxy_warning("Anomaly detected! Risk: %.2f\n", result.risk_score);

    // Get risk threshold from config (0-100) and normalize to 0.0-1.0
    int risk_threshold = GloAI->variables.ai_anomaly_risk_threshold;
    float risk_threshold_normalized = risk_threshold / 100.0f;

    if (result.risk_score > risk_threshold_normalized) {
        proxy_error("Blocking high-risk query\n");
        // Block the query
        return error_response("Query blocked by anomaly detection");
    } else {
        proxy_warning("Flagging medium-risk query\n");
        // Flag but allow
        log_flagged_query(query, result);
    }
}

// Allow query to proceed
execute_query(query);
```

### Threat Pattern Management

```cpp
// Add multiple threat patterns: (name, example query, type, severity)
std::vector<std::tuple<std::string, std::string, std::string, int>> patterns = {
    {"OR 1=1", "SELECT * FROM users WHERE id=1 OR 1=1--", "sql_injection", 9},
    {"UNION SELECT", "SELECT name FROM products WHERE id=1 UNION SELECT password FROM users", "sql_injection", 8},
    {"DROP TABLE", "SELECT * FROM users; DROP TABLE users--", "privilege_escalation", 10}
};

for (const auto& [name, example, type, severity] : patterns) {
    if (detector->add_threat_pattern(name, example, type, severity)) {
        proxy_info("Added pattern: %s\n", name.c_str());
    }
}

// List all patterns (the variable must not be named `json`,
// or it would shadow the nlohmann `json` type alias)
std::string patterns_json = detector->list_threat_patterns();
auto patterns_data = json::parse(patterns_json);
proxy_info("Total patterns: %zu\n", patterns_data.size());

// Remove a pattern, guarding against an empty list
if (!patterns_data.empty()) {
    int pattern_id = patterns_data[0]["id"].get<int>();
    if (detector->remove_threat_pattern(pattern_id)) {
        proxy_info("Removed pattern %d\n", pattern_id);
    }
}

// Get statistics
std::string stats = detector->get_statistics();
proxy_info("Statistics: %s\n", stats.c_str());
```

---

## Integration Points

### From MySQL_Session

Query interception happens in `MySQL_Session::execute_query()`:

```cpp
// Check if this is a NL2SQL query
if (query.find("NL2SQL:") == 0) {
    NL2SQL_Converter* converter = GloAI->get_nl2sql();
    NL2SQLRequest req;
    req.natural_language = query.substr(7);  // Remove "NL2SQL:" prefix
    NL2SQLResult result = converter->convert(req);
    return result.sql_query;
}

// Check for anomalies
Anomaly_Detector* detector = GloAI->get_anomaly();
AnomalyResult result = detector->check_embedding_similarity(query);
if (result.detected && result.risk_score > threshold) {
    return error("Query blocked");
}
```

### From MCP Tools

MCP tools can call these methods via JSON-RPC:

```json
{
  "jsonrpc": "2.0",
  "method": "tools/call",
  "params": {
    "name": "ai_add_threat_pattern",
    "arguments": {
      "pattern_name": "...",
      "query_example": "...",
      "pattern_type": "sql_injection",
      "severity": 9
    }
  }
}
```

---

## Thread Safety

- **Read operations** (`check_vector_cache`, `check_embedding_similarity`): Thread-safe, use read locks
- **Write operations** (`store_in_vector_cache`, `add_threat_pattern`): Thread-safe, use write locks
- **Global access**: Always access via `GloAI`, which manages locks

```cpp
// Safe pattern
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (converter) {
    // Method handles locking internally
    NL2SQLResult result = converter->convert(req);
}
```
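
This reference does not show how the read and write locks are implemented. As a minimal sketch of the read/write split described above, assuming a C++17 `std::shared_mutex` (ProxySQL itself may use a different primitive), the hypothetical `GuardedCache` class below takes a shared lock on the read path and an exclusive lock on the write path. The class name, its methods, and the exact-match `std::map` lookup are illustrative stand-ins, not part of the ProxySQL API:

```cpp
#include <cstddef>
#include <map>
#include <mutex>
#include <shared_mutex>
#include <string>

// Hypothetical stand-in for a cache guarded by a reader/writer lock.
// It is NOT the real NL2SQL_Converter; it only illustrates the locking pattern.
class GuardedCache {
public:
    // Write path: exclusive lock, analogous to store_in_vector_cache().
    void store(const std::string& question, const std::string& sql) {
        std::unique_lock<std::shared_mutex> lock(mutex_);
        entries_[question] = sql;
    }

    // Read path: shared lock, analogous to check_vector_cache().
    // Returns true and fills sql_out on a hit, false on a miss.
    bool lookup(const std::string& question, std::string& sql_out) const {
        std::shared_lock<std::shared_mutex> lock(mutex_);
        auto it = entries_.find(question);
        if (it == entries_.end()) {
            return false;
        }
        sql_out = it->second;
        return true;
    }

    // Read path: number of cached entries.
    std::size_t size() const {
        std::shared_lock<std::shared_mutex> lock(mutex_);
        return entries_.size();
    }

private:
    mutable std::shared_mutex mutex_;  // mutable so const readers can lock
    std::map<std::string, std::string> entries_;
};
```

With this layout, any number of threads can call `lookup()` or `size()` at the same time, while `store()` waits for exclusive access. Callers never touch the mutex directly, which matches the "method handles locking internally" convention shown in the safe pattern above.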