Vector Features API Reference
Overview
This document describes the C++ API for Vector Features in ProxySQL, including NL2SQL vector cache and Anomaly Detection embedding similarity.
NL2SQL_Converter API
Class: NL2SQL_Converter
Location: include/NL2SQL_Converter.h
The NL2SQL_Converter class provides natural language to SQL conversion with vector-based semantic caching.
Method: get_query_embedding()
Generate a vector embedding for a text query.
std::vector<float> get_query_embedding(const std::string& text);
Parameters:
text: The input text to generate embedding for
Returns:
std::vector<float>: 1536-dimensional embedding vector, or empty vector on failure
Description: Calls the GenAI module to generate a text embedding using llama-server. The embedding is a 1536-dimensional float array representing the semantic meaning of the text.
Example:
NL2SQL_Converter* converter = GloAI->get_nl2sql();
std::vector<float> embedding = converter->get_query_embedding("Show all customers");
if (embedding.size() == 1536) {
proxy_info("Generated embedding with %zu dimensions\n", embedding.size());
} else {
proxy_error("Failed to generate embedding\n");
}
Memory Management:
- GenAI allocates the embedding data with malloc()
- This method copies the data into a std::vector<float> and frees the original buffer
- The caller owns the returned vector
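The copy-then-free contract can be sketched as follows. `fake_genai_embed()` and `take_embedding()` are hypothetical names. The stub only mimics the documented behavior of GenAI returning a malloc()-allocated float buffer, and the real method's internals may differ.

```cpp
#include <cstddef>
#include <cstdlib>
#include <vector>

// Hypothetical stand-in for the GenAI embedding call: returns a
// malloc()-allocated buffer of `dims` floats, or nullptr on failure.
float* fake_genai_embed(std::size_t dims) {
    float* data = static_cast<float*>(std::malloc(dims * sizeof(float)));
    if (data == nullptr) return nullptr;
    for (std::size_t i = 0; i < dims; ++i) data[i] = static_cast<float>(i) * 0.5f;
    return data;
}

// Copy the C buffer into a std::vector<float>, then release the original.
// An empty vector signals failure, matching the documented return value.
std::vector<float> take_embedding(float* data, std::size_t dims) {
    if (data == nullptr) return {};
    std::vector<float> embedding(data, data + dims);  // copy
    std::free(data);                                   // GenAI buffer no longer needed
    return embedding;                                  // caller owns the vector
}
```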
Method: check_vector_cache()
Search for semantically similar queries in the vector cache.
NL2SQLResult check_vector_cache(const NL2SQLRequest& req);
Parameters:
req: NL2SQL request containing the natural language query
Returns:
NL2SQLResult: Result containing the cached SQL if a match is found; cached = false otherwise
Description: Performs KNN search using cosine distance to find the most similar cached query. Returns cached SQL if similarity > threshold.
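To make the search semantics concrete, here is a brute-force, in-memory approximation of what the sqlite-vec query computes: the cosine distance (1 - cosine similarity, so 0 for identical direction and 2 for opposite) to every cached embedding, filtered by the threshold, keeping the closest. `CacheEntry`, `cosine_distance()`, and `find_nearest()` are hypothetical; the actual lookup runs inside SQLite against the virtual table.

```cpp
#include <cmath>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical in-memory cache row: an embedding plus the SQL it maps to.
struct CacheEntry {
    std::vector<float> embedding;
    std::string sql_query;
};

// Cosine distance = 1 - cos(a, b): 0 for identical direction, 2 for opposite.
float cosine_distance(const std::vector<float>& a, const std::vector<float>& b) {
    double dot = 0.0, na = 0.0, nb = 0.0;
    for (std::size_t i = 0; i < a.size() && i < b.size(); ++i) {
        dot += static_cast<double>(a[i]) * b[i];
        na += static_cast<double>(a[i]) * a[i];
        nb += static_cast<double>(b[i]) * b[i];
    }
    if (na == 0.0 || nb == 0.0) return 2.0f;  // treat zero vectors as maximally distant
    return static_cast<float>(1.0 - dot / (std::sqrt(na) * std::sqrt(nb)));
}

// Equivalent of: WHERE distance < threshold ORDER BY distance LIMIT 1
std::optional<CacheEntry> find_nearest(const std::vector<CacheEntry>& cache,
                                       const std::vector<float>& query,
                                       float distance_threshold) {
    std::optional<CacheEntry> best;
    float best_distance = distance_threshold;
    for (const auto& entry : cache) {
        float d = cosine_distance(query, entry.embedding);
        if (d < best_distance) {  // strictly closer than the threshold and any prior hit
            best_distance = d;
            best = entry;
        }
    }
    return best;
}
```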
Algorithm:
- Generate an embedding for the query text
- Convert the embedding to JSON for the sqlite-vec MATCH clause
- Calculate the distance threshold from the similarity threshold
- Execute the KNN search: WHERE embedding MATCH '[...]' AND distance < threshold ORDER BY distance LIMIT 1
- Return the cached result if found
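The JSON conversion and KNN query steps amount to string building, sketched below under stated assumptions. sqlite-vec accepts a vector written as a JSON array on the right-hand side of MATCH. The helper names are hypothetical, the `rowid` and `distance` columns follow the query shown above, and `%.9g` is used so each float survives the text round trip. In production code, binding the vector and threshold as statement parameters is safer than splicing them into the SQL text.

```cpp
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Serialize an embedding as a JSON array, e.g. [0.5,0.25,1].
// 9 significant digits are enough to round-trip any 32-bit float.
std::string embedding_to_json(const std::vector<float>& embedding) {
    std::string out = "[";
    char buf[32];
    for (std::size_t i = 0; i < embedding.size(); ++i) {
        if (i > 0) out += ',';
        std::snprintf(buf, sizeof(buf), "%.9g", embedding[i]);
        out += buf;
    }
    out += ']';
    return out;
}

// Build the KNN lookup against the virtual table (illustrative only).
std::string build_knn_query(const std::vector<float>& embedding, float distance_threshold) {
    char threshold[32];
    std::snprintf(threshold, sizeof(threshold), "%.9g", distance_threshold);
    return "SELECT rowid, distance FROM nl2sql_cache_vec WHERE embedding MATCH '" +
           embedding_to_json(embedding) + "' AND distance < " + threshold +
           " ORDER BY distance LIMIT 1";
}
```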
Distance Calculation:
float distance_threshold = 2.0f - (similarity_threshold / 50.0f);
// Example: similarity=85 → distance=0.3
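Both directions of the mapping, as a small sketch: `similarity_to_distance()` is the formula above, and `distance_to_similarity()` is a hypothetical inverse obtained by solving that formula for the similarity. The inverse could be used to report a hit's distance as a 0-100 score.

```cpp
// Map a 0-100 similarity threshold onto a cosine distance in [0, 2]:
// 100 maps to 0.0 (identical), 0 maps to 2.0 (opposite).
float similarity_to_distance(float similarity_threshold) {
    return 2.0f - (similarity_threshold / 50.0f);
}

// Inverse mapping: turn a measured distance back into a 0-100 score.
float distance_to_similarity(float distance) {
    return (2.0f - distance) * 50.0f;
}
```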
Example:
NL2SQLRequest req;
req.natural_language = "Display USA customers";
req.allow_cache = true;
NL2SQLResult result = converter->check_vector_cache(req);
if (result.cached) {
proxy_info("Cache hit! Score: %.2f\n", result.confidence);
// Use result.sql_query
} else {
proxy_info("Cache miss, calling LLM\n");
}
Method: store_in_vector_cache()
Store an NL2SQL conversion in the vector cache.
void store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result);
Parameters:
req: Original NL2SQL request
result: NL2SQL conversion result to cache
Description: Stores the conversion with its embedding for future similarity searches, updating both the main table (nl2sql_cache) and the virtual vector table (nl2sql_cache_vec).
Storage Process:
- Generate an embedding for the natural language query
- Insert into the nl2sql_cache table with the embedding BLOB
- Get the rowid from the last insert
- Insert the rowid into the nl2sql_cache_vec virtual table
- Log the cache entry
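The embedding BLOB written during storage can be produced by copying the raw float bytes. This sketch assumes the column holds packed 32-bit floats in host byte order (the compact format sqlite-vec reads on little-endian machines); `embedding_to_blob()` and `blob_to_embedding()` are hypothetical helpers.

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Pack floats into a byte buffer suitable for binding as a BLOB.
std::vector<unsigned char> embedding_to_blob(const std::vector<float>& embedding) {
    std::vector<unsigned char> blob(embedding.size() * sizeof(float));
    if (!blob.empty()) std::memcpy(blob.data(), embedding.data(), blob.size());
    return blob;
}

// Reverse the packing when reading the BLOB back.
std::vector<float> blob_to_embedding(const std::vector<unsigned char>& blob) {
    std::vector<float> embedding(blob.size() / sizeof(float));
    if (!embedding.empty()) {
        std::memcpy(embedding.data(), blob.data(), embedding.size() * sizeof(float));
    }
    return embedding;
}
```

At 4 bytes per float, a 1536-dimensional embedding occupies 6,144 bytes.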
Example:
NL2SQLRequest req;
req.natural_language = "Show all customers";
NL2SQLResult result;
result.sql_query = "SELECT * FROM customers";
result.confidence = 0.95f;
converter->store_in_vector_cache(req, result);
Method: convert()
Convert natural language to SQL (main entry point).
NL2SQLResult convert(const NL2SQLRequest& req);
Parameters:
req: NL2SQL request with natural language query and context
Returns:
NL2SQLResult: Generated SQL with confidence score and metadata
Description: Complete conversion pipeline with vector caching:
- Check vector cache for similar queries
- If cache miss, build prompt with schema context
- Select model provider (Ollama/OpenAI/Anthropic)
- Call LLM API
- Validate and clean SQL
- Store result in vector cache
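The control flow of this pipeline can be sketched with both the cache and the model stubbed out. Everything below is hypothetical: `SketchConverter` uses an exact-match `std::map` in place of the vector cache, `call_llm()` returns a canned string instead of contacting Ollama, OpenAI, or Anthropic, and `clean_sql()` only trims trailing whitespace and semicolons. Only the lookup is gated on `allow_cache`, following its description as "whether to check cache".

```cpp
#include <map>
#include <string>

struct SketchResult {
    std::string sql_query;
    bool cached = false;
};

class SketchConverter {
public:
    SketchResult convert(const std::string& natural_language, bool allow_cache) {
        // Check the cache first (exact match stands in for vector similarity).
        if (allow_cache) {
            auto it = cache_.find(natural_language);
            if (it != cache_.end()) return {it->second, true};
        }
        // Cache miss: build the prompt and call the (stubbed) model.
        std::string prompt = "Translate to SQL: " + natural_language;
        std::string sql = clean_sql(call_llm(prompt));
        // Store the result for future requests.
        cache_[natural_language] = sql;
        return {sql, false};
    }

private:
    std::map<std::string, std::string> cache_;

    // Stubbed model call; a real converter would route to a provider.
    std::string call_llm(const std::string& /*prompt*/) {
        return "SELECT * FROM customers;  ";
    }

    // Minimal cleanup: trim trailing spaces and semicolons.
    static std::string clean_sql(std::string sql) {
        while (!sql.empty() && (sql.back() == ' ' || sql.back() == ';')) sql.pop_back();
        return sql;
    }
};
```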
Example:
NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.allow_cache = true;
NL2SQLResult result = converter->convert(req);
if (result.confidence > 0.7f) {
execute_sql(result.sql_query);
proxy_info("Generated by: %s\n", result.explanation.c_str());
}
Method: clear_cache()
Clear the NL2SQL vector cache.
void clear_cache();
Description:
Deletes all entries from both nl2sql_cache and nl2sql_cache_vec tables.
Example:
converter->clear_cache();
proxy_info("NL2SQL cache cleared\n");
Method: get_cache_stats()
Get cache statistics.
std::string get_cache_stats();
Returns:
std::string: JSON string with cache statistics
Statistics Include:
- Total entries
- Cache hits
- Cache misses
- Hit rate
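Hit rate is hits divided by total lookups (hits plus misses). The sketch below, with hypothetical helper names, formats the counters in the same shape as the example output and treats zero lookups as a hit rate of 0, which is an assumption.

```cpp
#include <cstdio>
#include <string>

// Hit rate = hits / (hits + misses); defined here as 0 when there were no lookups.
double hit_rate(unsigned long hits, unsigned long misses) {
    unsigned long total = hits + misses;
    return total == 0 ? 0.0 : static_cast<double>(hits) / static_cast<double>(total);
}

// Render the counters as a JSON object with two-decimal hit rate.
std::string format_cache_stats(unsigned long entries, unsigned long hits, unsigned long misses) {
    char buf[128];
    std::snprintf(buf, sizeof(buf),
                  "{\"entries\": %lu, \"hits\": %lu, \"misses\": %lu, \"hit_rate\": %.2f}",
                  entries, hits, misses, hit_rate(hits, misses));
    return buf;
}
```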
Example:
std::string stats = converter->get_cache_stats();
proxy_info("Cache stats: %s\n", stats.c_str());
// Output: {"entries": 150, "hits": 1200, "misses": 300, "hit_rate": 0.80}
Anomaly_Detector API
Class: Anomaly_Detector
Location: include/Anomaly_Detector.h
The Anomaly_Detector class provides SQL threat detection using embedding similarity.
Method: get_query_embedding()
Generate a vector embedding for a SQL query.
std::vector<float> get_query_embedding(const std::string& query);
Parameters:
query: The SQL query to generate embedding for
Returns:
std::vector<float>: 1536-dimensional embedding vector, or empty vector on failure
Description: Normalizes the query (lowercasing, removing extra whitespace) and generates an embedding via the GenAI module.
Normalization Process:
- Convert to lowercase
- Remove extra whitespace
- Standardize SQL keywords
- Generate embedding
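The lowercasing and whitespace steps can be sketched as below; keyword standardization is omitted because its rules are not specified here. `normalize_query()` is a hypothetical helper, not the detector's actual implementation.

```cpp
#include <cctype>
#include <string>

// Lowercase the query and collapse runs of whitespace into single spaces,
// dropping leading and trailing whitespace.
std::string normalize_query(const std::string& query) {
    std::string out;
    out.reserve(query.size());
    bool pending_space = false;
    for (unsigned char c : query) {
        if (std::isspace(c)) {
            pending_space = !out.empty();  // never emit a leading space
            continue;
        }
        if (pending_space) {
            out += ' ';
            pending_space = false;
        }
        out += static_cast<char>(std::tolower(c));
    }
    return out;  // a trailing space is never flushed
}
```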
Example:
Anomaly_Detector* detector = GloAI->get_anomaly();
std::vector<float> embedding = detector->get_query_embedding(
"SELECT * FROM users WHERE id = 1 OR 1=1--"
);
if (embedding.size() == 1536) {
// Check similarity against threat patterns
}
Method: check_embedding_similarity()
Check if query is similar to known threat patterns.
AnomalyResult check_embedding_similarity(const std::string& query);
Parameters:
query: The SQL query to check
Returns:
AnomalyResult: Detection result with risk score and matched pattern
Detection Algorithm:
- Normalize the query and generate its embedding
- Run a KNN search against anomaly_patterns_vec
- For each match within the threshold, calculate the risk score: (severity / 10) * (1 - distance / 2)
- Return the highest-risk match
Risk Score Formula:
risk_score = (severity / 10.0f) * (1.0f - (distance / 2.0f));
// severity: 1-10 from threat pattern
// distance: 0-2 from cosine distance
// risk_score: 0-1 (multiply by 100 for percentage)
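The scoring loop can be sketched as follows. `PatternMatch`, `ScoredMatch`, and `highest_risk()` are hypothetical, and the distance threshold is passed in as a parameter because its configured value is not given here.

```cpp
#include <string>
#include <vector>

// Hypothetical KNN hit: a stored pattern and its cosine distance to the query.
struct PatternMatch {
    std::string pattern_name;
    int severity;    // 1-10
    float distance;  // 0-2
};

struct ScoredMatch {
    std::string pattern_name;  // empty when nothing matched
    float risk_score = 0.0f;
};

float risk_score(int severity, float distance) {
    return (severity / 10.0f) * (1.0f - (distance / 2.0f));
}

// Score every match within the distance threshold and keep the riskiest.
ScoredMatch highest_risk(const std::vector<PatternMatch>& matches, float distance_threshold) {
    ScoredMatch best;
    for (const auto& m : matches) {
        if (m.distance >= distance_threshold) continue;
        float score = risk_score(m.severity, m.distance);
        if (score > best.risk_score) best = {m.pattern_name, score};
    }
    return best;
}
```

For severity 9 at distance 0.2, the score is 0.9 * (1 - 0.1) = 0.81.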
Example:
AnomalyResult result = detector->check_embedding_similarity(
"SELECT * FROM users WHERE id = 5 OR 2=2--"
);
if (result.risk_score > 0.7f) {
proxy_warning("High risk query detected! Score: %.2f\n", result.risk_score);
proxy_warning("Matched pattern: %s\n", result.matched_pattern.c_str());
// Block query
}
if (result.detected) {
proxy_info("Threat type: %s\n", result.threat_type.c_str());
}
Method: add_threat_pattern()
Add a new threat pattern to the database.
bool add_threat_pattern(
const std::string& pattern_name,
const std::string& query_example,
const std::string& pattern_type,
int severity
);
Parameters:
pattern_name: Human-readable name for the pattern
query_example: Example SQL query representing this threat
pattern_type: Type of threat (sql_injection, dos, privilege_escalation, etc.)
severity: Severity level (1-10, where 10 is most severe)
Returns:
bool: true if the pattern was added successfully, false on error
Description:
Stores the threat pattern with its embedding in both the anomaly_patterns and anomaly_patterns_vec tables.
Storage Process:
- Generate an embedding for the query example
- Insert into anomaly_patterns with the embedding BLOB
- Get the rowid from the last insert
- Insert the rowid into the anomaly_patterns_vec virtual table
Example:
bool success = detector->add_threat_pattern(
"OR 1=1 Tautology",
"SELECT * FROM users WHERE username='admin' OR 1=1--'",
"sql_injection",
9 // high severity
);
if (success) {
proxy_info("Threat pattern added\n");
} else {
proxy_error("Failed to add threat pattern\n");
}
Method: list_threat_patterns()
List all threat patterns in the database.
std::string list_threat_patterns();
Returns:
std::string: JSON array of threat patterns
JSON Format:
[
{
"id": 1,
"pattern_name": "OR 1=1 Tautology",
"pattern_type": "sql_injection",
"query_example": "SELECT * FROM users WHERE username='admin' OR 1=1--'",
"severity": 9,
"created_at": 1705334400
}
]
Example:
std::string patterns_json = detector->list_threat_patterns();
proxy_info("Threat patterns:\n%s\n", patterns_json.c_str());
// Parse with nlohmann/json (#include <nlohmann/json.hpp>)
nlohmann::json patterns = nlohmann::json::parse(patterns_json);
for (const auto& pattern : patterns) {
proxy_info("- %s (severity: %d)\n",
pattern["pattern_name"].get<std::string>().c_str(),
pattern["severity"].get<int>());
}
Method: remove_threat_pattern()
Remove a threat pattern from the database.
bool remove_threat_pattern(int pattern_id);
Parameters:
pattern_id: ID of the pattern to remove
Returns:
bool: true if the pattern was removed successfully, false on error
Description:
Deletes from both anomaly_patterns_vec (virtual table) and anomaly_patterns (main table).
Example:
bool success = detector->remove_threat_pattern(5);
if (success) {
proxy_info("Threat pattern 5 removed\n");
} else {
proxy_error("Failed to remove pattern\n");
}
Method: get_statistics()
Get anomaly detection statistics.
std::string get_statistics();
Returns:
std::string: JSON string with detailed statistics
Statistics Include:
{
"total_checks": 1500,
"detected_anomalies": 45,
"blocked_queries": 12,
"flagged_queries": 33,
"threat_patterns_count": 10,
"threat_patterns_by_type": {
"sql_injection": 6,
"dos": 2,
"privilege_escalation": 1,
"data_exfiltration": 1
}
}
Example:
std::string stats = detector->get_statistics();
proxy_info("Anomaly stats: %s\n", stats.c_str());
Data Structures
NL2SQLRequest
struct NL2SQLRequest {
std::string natural_language; // Input natural language query
std::string schema_name; // Target schema name
std::vector<std::string> context_tables; // Relevant tables
bool allow_cache; // Whether to check cache
int max_latency_ms; // Max acceptable latency (0 = no limit)
};
NL2SQLResult
struct NL2SQLResult {
std::string sql_query; // Generated SQL query
float confidence; // Confidence score (0.0-1.0)
std::string explanation; // Which model was used
bool cached; // Whether from cache
};
AnomalyResult
struct AnomalyResult {
bool detected; // Whether anomaly was detected
float risk_score; // Risk score (0.0-1.0)
std::string threat_type; // Type of threat
std::string matched_pattern; // Name of matched pattern
std::string action_taken; // "blocked", "flagged", "allowed"
};
Error Handling
Return Values
- bool functions: Return false on error
- vector functions: Return an empty vector on error
- string functions: Return an empty string or a JSON error object
Logging
Use ProxySQL logging macros:
proxy_error("Failed to generate embedding: %s\n", error_msg);
proxy_warning("Low confidence result: %.2f\n", confidence);
proxy_info("Cache hit for query: %s\n", query.c_str());
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "Embedding generated with %zu dimensions", size);
Error Checking Example
std::vector<float> embedding = converter->get_query_embedding(text);
if (embedding.empty()) {
proxy_error("Failed to generate embedding for: %s\n", text.c_str());
// Handle error - return error or use fallback
return error_result;
}
if (embedding.size() != 1536) {
proxy_warning("Unexpected embedding size: %zu (expected 1536)\n", embedding.size());
// May still work, but log warning
}
Usage Examples
Complete NL2SQL Conversion with Cache
// Get converter
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (!converter) {
proxy_error("NL2SQL converter not initialized\n");
return;
}
// Prepare request
NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.context_tables = {"customers", "orders"};
req.allow_cache = true;
req.max_latency_ms = 0; // No latency constraint
// Convert
NL2SQLResult result = converter->convert(req);
// Check result
if (result.confidence > 0.7f) {
proxy_info("Generated SQL: %s\n", result.sql_query.c_str());
proxy_info("Confidence: %.2f\n", result.confidence);
proxy_info("Source: %s\n", result.explanation.c_str());
if (result.cached) {
proxy_info("Retrieved from semantic cache\n");
}
// Execute the SQL
execute_sql(result.sql_query);
} else {
proxy_warning("Low confidence conversion: %.2f\n", result.confidence);
}
Complete Anomaly Detection Flow
// Get detector
Anomaly_Detector* detector = GloAI->get_anomaly();
if (!detector) {
proxy_error("Anomaly detector not initialized\n");
return;
}
// Add threat pattern
detector->add_threat_pattern(
"Sleep-based DoS",
"SELECT * FROM users WHERE id=1 AND sleep(10)",
"dos",
6
);
// Check incoming query
std::string query = "SELECT * FROM users WHERE id=5 AND SLEEP(5)--";
AnomalyResult result = detector->check_embedding_similarity(query);
if (result.detected) {
proxy_warning("Anomaly detected! Risk: %.2f\n", result.risk_score);
// Get risk threshold from config
int risk_threshold = GloAI->variables.ai_anomaly_risk_threshold;
float risk_threshold_normalized = risk_threshold / 100.0f;
if (result.risk_score > risk_threshold_normalized) {
proxy_error("Blocking high-risk query\n");
// Block the query
return error_response("Query blocked by anomaly detection");
} else {
proxy_warning("Flagging medium-risk query\n");
// Flag but allow
log_flagged_query(query, result);
}
}
// Allow query to proceed
execute_query(query);
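The branch structure above maps onto the three action_taken values of the AnomalyResult struct. A minimal sketch, assuming `risk_threshold` is the 0-100 integer read from `ai_anomaly_risk_threshold`; `decide_action()` is a hypothetical helper.

```cpp
#include <string>

// Hypothetical helper: map a detection result onto "blocked", "flagged",
// or "allowed". risk_threshold is a 0-100 percentage.
std::string decide_action(bool detected, float risk_score, int risk_threshold) {
    if (!detected) return "allowed";
    float normalized = risk_threshold / 100.0f;  // same normalization as above
    return risk_score > normalized ? "blocked" : "flagged";
}
```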
Threat Pattern Management
// Add multiple threat patterns
std::vector<std::tuple<std::string, std::string, std::string, int>> patterns = {
{"OR 1=1", "SELECT * FROM users WHERE id=1 OR 1=1--", "sql_injection", 9},
{"UNION SELECT", "SELECT name FROM products WHERE id=1 UNION SELECT password FROM users", "sql_injection", 8},
{"DROP TABLE", "SELECT * FROM users; DROP TABLE users--", "privilege_escalation", 10}
};
for (const auto& [name, example, type, severity] : patterns) {
if (detector->add_threat_pattern(name, example, type, severity)) {
proxy_info("Added pattern: %s\n", name.c_str());
}
}
// List all patterns
std::string json = detector->list_threat_patterns();
auto patterns_data = json::parse(json);
proxy_info("Total patterns: %zu\n", patterns_data.size());
// Remove a pattern
int pattern_id = patterns_data[0]["id"];
if (detector->remove_threat_pattern(pattern_id)) {
proxy_info("Removed pattern %d\n", pattern_id);
}
// Get statistics
std::string stats = detector->get_statistics();
proxy_info("Statistics: %s\n", stats.c_str());
Integration Points
From MySQL_Session
Query interception happens in MySQL_Session::execute_query():
// Check if this is an NL2SQL query (prefix match only)
if (query.rfind("NL2SQL:", 0) == 0) {
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (converter) {
NL2SQLRequest req;
req.natural_language = query.substr(7); // Remove the 7-character "NL2SQL:" prefix
NL2SQLResult result = converter->convert(req);
return result.sql_query;
}
}
// Check for anomalies
Anomaly_Detector* detector = GloAI->get_anomaly();
if (detector) {
AnomalyResult result = detector->check_embedding_similarity(query);
// threshold: normalized risk threshold (0.0-1.0)
if (result.detected && result.risk_score > threshold) {
return error("Query blocked");
}
}
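The hard-coded `substr(7)` ties the offset to the prefix length by hand; deriving it from the prefix string keeps the two in sync. A sketch with a hypothetical helper; like the snippet above, it matches the prefix case-sensitively.

```cpp
#include <optional>
#include <string>

// Return the natural-language text after a leading "NL2SQL:" prefix,
// or std::nullopt when the query is ordinary SQL.
std::optional<std::string> extract_nl2sql_text(const std::string& query) {
    static const std::string prefix = "NL2SQL:";
    if (query.compare(0, prefix.size(), prefix) != 0) return std::nullopt;
    return query.substr(prefix.size());
}
```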
From MCP Tools
MCP tools can call these methods via JSON-RPC:
{
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "ai_add_threat_pattern",
"arguments": {
"pattern_name": "...",
"query_example": "...",
"pattern_type": "sql_injection",
"severity": 9
}
}
}
Thread Safety
- Read operations (check_vector_cache, check_embedding_similarity): Thread-safe, use read locks
- Write operations (store_in_vector_cache, add_threat_pattern): Thread-safe, use write locks
- Global access: Always access via GloAI, which manages locks
// Safe pattern
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (converter) {
// Method handles locking internally
NL2SQLResult result = converter->convert(req);
}
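The read/write locking described in the thread-safety list can be sketched with a reader/writer lock. This assumes C++17 `std::shared_mutex`; ProxySQL's own locking primitives may differ, and `GuardedCache` is a hypothetical class. Lookups take a shared lock, so many readers can proceed at once, while stores and clears take an exclusive lock.

```cpp
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Hypothetical cache guarded by a reader/writer lock.
class GuardedCache {
public:
    bool lookup(const std::string& key, std::string& value) const {
        std::shared_lock<std::shared_mutex> lock(mutex_);  // read lock
        auto it = entries_.find(key);
        if (it == entries_.end()) return false;
        value = it->second;
        return true;
    }

    void store(const std::string& key, const std::string& value) {
        std::unique_lock<std::shared_mutex> lock(mutex_);  // write lock
        entries_[key] = value;
    }

    void clear() {
        std::unique_lock<std::shared_mutex> lock(mutex_);  // write lock
        entries_.clear();
    }

private:
    mutable std::shared_mutex mutex_;
    std::unordered_map<std::string, std::string> entries_;
};
```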