# Vector Features API Reference
## Overview
This document describes the C++ API for ProxySQL's Vector Features: the NL2SQL vector cache and the embedding-similarity checks used by Anomaly Detection.
## Table of Contents
- [NL2SQL_Converter API](#nl2sql_converter-api)
- [Anomaly_Detector API](#anomaly_detector-api)
- [Data Structures](#data-structures)
- [Error Handling](#error-handling)
- [Usage Examples](#usage-examples)
- [Integration Points](#integration-points)
- [Thread Safety](#thread-safety)
---
## NL2SQL_Converter API
### Class: NL2SQL_Converter
Location: `include/NL2SQL_Converter.h`

The NL2SQL_Converter class provides natural language to SQL conversion with vector-based semantic caching.

---
### Method: `get_query_embedding()`
Generate a vector embedding for a text query.
```cpp
std::vector<float> get_query_embedding(const std::string& text);
```
**Parameters:**
- `text`: The input text to generate an embedding for

**Returns:**
- `std::vector<float>`: 1536-dimensional embedding vector, or an empty vector on failure

**Description:**
Calls the GenAI module to generate a text embedding using llama-server. The embedding is a 1536-dimensional float array representing the semantic meaning of the text.

**Example:**
```cpp
NL2SQL_Converter* converter = GloAI->get_nl2sql();
std::vector<float> embedding = converter->get_query_embedding("Show all customers");
if (embedding.size() == 1536) {
    proxy_info("Generated embedding with %zu dimensions\n", embedding.size());
} else {
    proxy_error("Failed to generate embedding\n");
}
```
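Internally, the method copies the `malloc()`-allocated GenAI buffer into a `std::vector<float>` and frees it, as the Memory Management notes describe. A minimal sketch of that copy-then-free step, assuming GenAI hands back a raw `float*` plus a dimension count; `copy_and_free_embedding` is a hypothetical name, not part of the ProxySQL API:

```cpp
#include <cstddef>
#include <cstdlib>
#include <vector>

// Hypothetical helper (not ProxySQL API): copy a malloc()-allocated buffer of
// `dims` floats into a std::vector, then release the original buffer.
// A null buffer yields an empty vector, matching "empty vector on failure".
std::vector<float> copy_and_free_embedding(float* data, std::size_t dims) {
    if (data == nullptr) {
        return {};
    }
    std::vector<float> out(data, data + dims);  // deep copy into owned storage
    std::free(data);                            // GenAI buffer is no longer needed
    return out;                                 // caller owns the returned vector
}
```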
**Memory Management:**
- GenAI allocates embedding data with `malloc()`
- This method copies data to `std::vector<float>` and frees the original
- Caller owns the returned vector
---
### Method: `check_vector_cache()`
Search for semantically similar queries in the vector cache.
```cpp
NL2SQLResult check_vector_cache(const NL2SQLRequest& req);
```
**Parameters:**
- `req`: NL2SQL request containing the natural language query

**Returns:**
- `NL2SQLResult`: Result with the cached SQL if a match is found; `cached=false` otherwise

**Description:**
Performs a KNN search using cosine distance to find the most similar cached query. Returns the cached SQL if that query's similarity exceeds the configured similarity threshold.

**Algorithm:**
1. Generate embedding for query text
2. Convert embedding to JSON for sqlite-vec MATCH clause
3. Calculate distance threshold from similarity threshold
4. Execute KNN search: `WHERE embedding MATCH '[...]' AND distance < threshold ORDER BY distance LIMIT 1`
5. Return cached result if found

**Distance Calculation:**
```cpp
float distance_threshold = 2.0f - (similarity_threshold / 50.0f);
// Example: similarity=85 → distance=0.3
```
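The threshold mapping, together with the cosine distance that the KNN search ranks by, can be sketched in plain C++. This is a brute-force, in-memory stand-in for the query in step 4, not the sqlite-vec implementation; the function names are hypothetical, and the sketch assumes cosine distance is defined as 1 minus cosine similarity (range 0 to 2). With a similarity threshold of 85, only entries closer than distance 0.3 qualify.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Map a 0-100 similarity threshold onto a cosine-distance cutoff (2.0 down to 0.0).
float distance_threshold_for(int similarity_threshold) {
    return 2.0f - (similarity_threshold / 50.0f);
}

// Cosine distance = 1 - cosine similarity: 0 for the same direction, 1 for
// orthogonal vectors, 2 for opposite directions. Mismatched or zero-length
// vectors are treated as maximally distant (an assumption of this sketch).
float cosine_distance(const std::vector<float>& a, const std::vector<float>& b) {
    if (a.size() != b.size()) {
        return 2.0f;
    }
    double dot = 0.0, norm_a = 0.0, norm_b = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        dot += static_cast<double>(a[i]) * b[i];
        norm_a += static_cast<double>(a[i]) * a[i];
        norm_b += static_cast<double>(b[i]) * b[i];
    }
    if (norm_a == 0.0 || norm_b == 0.0) {
        return 2.0f;
    }
    return static_cast<float>(1.0 - dot / (std::sqrt(norm_a) * std::sqrt(norm_b)));
}

// Brute-force equivalent of "distance < threshold ORDER BY distance LIMIT 1":
// returns the index of the closest stored embedding, or -1 if none qualifies.
int nearest_within(const std::vector<std::vector<float>>& stored,
                   const std::vector<float>& query, float threshold) {
    int best = -1;
    float best_distance = threshold;
    for (std::size_t i = 0; i < stored.size(); ++i) {
        const float d = cosine_distance(stored[i], query);
        if (d < best_distance) {
            best_distance = d;
            best = static_cast<int>(i);
        }
    }
    return best;
}
```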
**Example:**
```cpp
NL2SQLRequest req;
req.natural_language = "Display USA customers";
req.allow_cache = true;
NL2SQLResult result = converter->check_vector_cache(req);
if (result.cached) {
    proxy_info("Cache hit! Score: %.2f\n", result.confidence);
    // Use result.sql_query
} else {
    proxy_info("Cache miss, calling LLM\n");
}
```
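Step 2 of the algorithm hands the query embedding to sqlite-vec as a JSON array such as `[0.5,1,-0.25]`. A minimal sketch of that conversion; `embedding_to_json` is a hypothetical helper, and printing with `max_digits10` significant digits is a design choice so each float survives the text round trip. Binding the resulting string as a statement parameter (`embedding MATCH ?`) rather than splicing it into the SQL text avoids quoting issues.

```cpp
#include <cstddef>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: serialize an embedding as a JSON array of numbers,
// e.g. {0.5f, 1.0f, -0.25f} -> "[0.5,1,-0.25]".
std::string embedding_to_json(const std::vector<float>& embedding) {
    std::ostringstream oss;
    // Enough significant digits to reconstruct every float exactly.
    oss.precision(std::numeric_limits<float>::max_digits10);
    oss << '[';
    for (std::size_t i = 0; i < embedding.size(); ++i) {
        if (i > 0) {
            oss << ',';
        }
        oss << embedding[i];
    }
    oss << ']';
    return oss.str();
}
```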
---
### Method: `store_in_vector_cache()`
Store an NL2SQL conversion in the vector cache.
```cpp
void store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result);
```
**Parameters:**
- `req`: Original NL2SQL request
- `result`: NL2SQL conversion result to cache

**Description:**
Stores the conversion with its embedding for future similarity search. Updates both the main table and the virtual vector table.

**Storage Process:**
1. Generate embedding for the natural language query
2. Insert into `nl2sql_cache` table with embedding BLOB
3. Get `rowid` from last insert
4. Insert `rowid` into `nl2sql_cache_vec` virtual table
5. Log cache entry

**Example:**
```cpp
NL2SQLRequest req;
req.natural_language = "Show all customers";
NL2SQLResult result;
result.sql_query = "SELECT * FROM customers";
result.confidence = 0.95f;
converter->store_in_vector_cache(req, result);
```
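Step 2 of the storage process writes the embedding as a BLOB. A minimal sketch of packing and unpacking such a BLOB, assuming a raw float32 layout (4 bytes per dimension, host byte order), which is the layout assumed here for sqlite-vec's compact BLOB format on little-endian hosts; the helper names are hypothetical. At 1536 dimensions, each embedding occupies 6144 bytes.

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical helper: copy the floats' bytes into a BLOB buffer.
std::vector<unsigned char> embedding_to_blob(const std::vector<float>& embedding) {
    std::vector<unsigned char> blob(embedding.size() * sizeof(float));
    if (!blob.empty()) {
        std::memcpy(blob.data(), embedding.data(), blob.size());
    }
    return blob;
}

// Hypothetical helper: rebuild the floats from a BLOB; trailing bytes that do
// not form a whole float are ignored.
std::vector<float> blob_to_embedding(const std::vector<unsigned char>& blob) {
    std::vector<float> embedding(blob.size() / sizeof(float));
    if (!embedding.empty()) {
        std::memcpy(embedding.data(), blob.data(), embedding.size() * sizeof(float));
    }
    return embedding;
}
```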
---
### Method: `convert()`
Convert natural language to SQL (main entry point).
```cpp
NL2SQLResult convert(const NL2SQLRequest& req);
```
**Parameters:**
- `req`: NL2SQL request with natural language query and context

**Returns:**
- `NL2SQLResult`: Generated SQL with confidence score and metadata

**Description:**
Complete conversion pipeline with vector caching:
1. Check vector cache for similar queries
2. If cache miss, build prompt with schema context
3. Select model provider (Ollama/OpenAI/Anthropic)
4. Call LLM API
5. Validate and clean SQL
6. Store result in vector cache

**Example:**
```cpp
NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.allow_cache = true;
NL2SQLResult result = converter->convert(req);
if (result.confidence > 0.7f) {
    execute_sql(result.sql_query);
    proxy_info("Generated by: %s\n", result.explanation.c_str());
}
```
---
### Method: `clear_cache()`
Clear the NL2SQL vector cache.
```cpp
void clear_cache();
```
**Description:**
Deletes all entries from both `nl2sql_cache` and `nl2sql_cache_vec` tables.

**Example:**
```cpp
converter->clear_cache();
proxy_info("NL2SQL cache cleared\n");
```
---
### Method: `get_cache_stats()`
Get cache statistics.
```cpp
std::string get_cache_stats();
```
**Returns:**
- `std::string`: JSON string with cache statistics

**Statistics Include:**
- Total entries
- Cache hits
- Cache misses
- Hit rate

**Example:**
```cpp
std::string stats = converter->get_cache_stats();
proxy_info("Cache stats: %s\n", stats.c_str());
// Output: {"entries": 150, "hits": 1200, "misses": 300, "hit_rate": 0.80}
```
---
## Anomaly_Detector API
### Class: Anomaly_Detector
Location: `include/Anomaly_Detector.h`

The Anomaly_Detector class provides SQL threat detection using embedding similarity.

---
### Method: `get_query_embedding()`
Generate a vector embedding for a SQL query.
```cpp
std::vector<float> get_query_embedding(const std::string& query);
```
**Parameters:**
- `query`: The SQL query to generate an embedding for

**Returns:**
- `std::vector<float>`: 1536-dimensional embedding vector, or an empty vector on failure

**Description:**
Normalizes the query (lowercase, remove extra whitespace) and generates an embedding via the GenAI module.

**Normalization Process:**
1. Convert to lowercase
2. Remove extra whitespace
3. Standardize SQL keywords
4. Generate embedding

**Example:**
```cpp
Anomaly_Detector* detector = GloAI->get_anomaly();
std::vector<float> embedding = detector->get_query_embedding(
    "SELECT * FROM users WHERE id = 1 OR 1=1--"
);
if (embedding.size() == 1536) {
    // Check similarity against threat patterns
}
```
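Steps 1 and 2 of the normalization can be sketched as a single pass that lowercases characters and collapses whitespace runs into single spaces. `normalize_query` is hypothetical; step 3 (keyword standardization) is not modeled, and trimming leading and trailing whitespace is an assumption. Lowercasing also folds string literals such as `'Admin'`, which is harmless for similarity matching as long as the normalized text is only used to compute the embedding.

```cpp
#include <cctype>
#include <string>

// Hypothetical normalizer: lowercase every character and replace each run of
// whitespace with a single space, dropping leading/trailing whitespace.
std::string normalize_query(const std::string& query) {
    std::string out;
    out.reserve(query.size());
    bool pending_space = false;
    for (unsigned char c : query) {
        if (std::isspace(c)) {
            pending_space = true;  // defer until the next non-space character
            continue;
        }
        if (pending_space && !out.empty()) {
            out.push_back(' ');
        }
        pending_space = false;
        out.push_back(static_cast<char>(std::tolower(c)));
    }
    return out;
}
```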
---
### Method: `check_embedding_similarity()`
Check whether a query is similar to known threat patterns.
```cpp
AnomalyResult check_embedding_similarity(const std::string& query);
```
**Parameters:**
- `query`: The SQL query to check

**Returns:**
- `AnomalyResult`: Detection result with risk score and matched pattern

**Detection Algorithm:**
1. Normalize and generate embedding for query
2. KNN search against `anomaly_patterns_vec`
3. For each match within threshold:
   - Calculate risk score: `(severity / 10) * (1 - distance / 2)`
4. Return highest risk match

**Risk Score Formula:**
```cpp
risk_score = (severity / 10.0f) * (1.0f - (distance / 2.0f));
// severity: 1-10 from threat pattern
// distance: 0-2 from cosine distance
// risk_score: 0-1 (multiply by 100 for percentage)
```
**Example:**
```cpp
AnomalyResult result = detector->check_embedding_similarity(
    "SELECT * FROM users WHERE id = 5 OR 2=2--"
);
if (result.risk_score > 0.7f) {
    proxy_warning("High risk query detected! Score: %.2f\n", result.risk_score);
    proxy_warning("Matched pattern: %s\n", result.matched_pattern.c_str());
    // Block query
}
if (result.detected) {
    proxy_info("Threat type: %s\n", result.threat_type.c_str());
}
```
---
### Method: `add_threat_pattern()`
Add a new threat pattern to the database.
```cpp
bool add_threat_pattern(
    const std::string& pattern_name,
    const std::string& query_example,
    const std::string& pattern_type,
    int severity
);
```
**Parameters:**
- `pattern_name`: Human-readable name for the pattern
- `query_example`: Example SQL query representing this threat
- `pattern_type`: Type of threat (`sql_injection`, `dos`, `privilege_escalation`, etc.)
- `severity`: Severity level (1-10, where 10 is most severe)

**Returns:**
- `bool`: `true` if the pattern was added successfully, `false` on error

**Description:**
Stores the threat pattern and its embedding in both the `anomaly_patterns` and `anomaly_patterns_vec` tables.

**Storage Process:**
1. Generate embedding for query example
2. Insert into `anomaly_patterns` with embedding BLOB
3. Get `rowid` from last insert
4. Insert `rowid` into `anomaly_patterns_vec` virtual table

**Example:**
```cpp
bool success = detector->add_threat_pattern(
    "OR 1=1 Tautology",
    "SELECT * FROM users WHERE username='admin' OR 1=1--'",
    "sql_injection",
    9  // high severity
);
if (success) {
    proxy_info("Threat pattern added\n");
} else {
    proxy_error("Failed to add threat pattern\n");
}
```
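Because `add_threat_pattern()` reports errors only through its `bool` return value, a caller-side pre-check can reject bad arguments before any embedding is generated. A minimal sketch; `validate_threat_pattern` is hypothetical, the 1-10 severity range comes from the parameter list, and requiring non-empty strings is an assumption.

```cpp
#include <string>

// Hypothetical pre-insert check for add_threat_pattern() arguments.
bool validate_threat_pattern(const std::string& pattern_name,
                             const std::string& query_example,
                             const std::string& pattern_type,
                             int severity) {
    // Assumption: all three text fields must be present.
    if (pattern_name.empty() || query_example.empty() || pattern_type.empty()) {
        return false;
    }
    // Documented range: 1 (least severe) to 10 (most severe).
    return severity >= 1 && severity <= 10;
}
```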
---
### Method: `list_threat_patterns()`
List all threat patterns in the database.
```cpp
std::string list_threat_patterns();
```
**Returns:**
- `std::string`: JSON array of threat patterns

**JSON Format:**
```json
[
  {
    "id": 1,
    "pattern_name": "OR 1=1 Tautology",
    "pattern_type": "sql_injection",
    "query_example": "SELECT * FROM users WHERE username='admin' OR 1=1--'",
    "severity": 9,
    "created_at": 1705334400
  }
]
```
**Example:**
```cpp
std::string patterns_json = detector->list_threat_patterns();
proxy_info("Threat patterns:\n%s\n", patterns_json.c_str());
// Parse with nlohmann/json (assumes `using json = nlohmann::json;`)
json patterns = json::parse(patterns_json);
for (const auto& pattern : patterns) {
    proxy_info("- %s (severity: %d)\n",
               pattern["pattern_name"].get<std::string>().c_str(),
               pattern["severity"].get<int>());
}
```
---
### Method: `remove_threat_pattern()`
Remove a threat pattern from the database.
```cpp
bool remove_threat_pattern(int pattern_id);
```
**Parameters:**
- `pattern_id`: ID of the pattern to remove

**Returns:**
- `bool`: `true` if removed successfully, `false` on error

**Description:**
Deletes from both `anomaly_patterns_vec` (virtual table) and `anomaly_patterns` (main table).

**Example:**
```cpp
bool success = detector->remove_threat_pattern(5);
if (success) {
    proxy_info("Threat pattern 5 removed\n");
} else {
    proxy_error("Failed to remove pattern\n");
}
```
---
### Method: `get_statistics()`
Get anomaly detection statistics.
```cpp
std::string get_statistics();
```
**Returns:**
- `std::string`: JSON string with detailed statistics

**Statistics Include:**
```json
{
  "total_checks": 1500,
  "detected_anomalies": 45,
  "blocked_queries": 12,
  "flagged_queries": 33,
  "threat_patterns_count": 10,
  "threat_patterns_by_type": {
    "sql_injection": 6,
    "dos": 2,
    "privilege_escalation": 1,
    "data_exfiltration": 1
  }
}
```
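In the sample output, the per-type counts sum to `threat_patterns_count` (6 + 2 + 1 + 1 = 10), and `detected_anomalies` equals `blocked_queries + flagged_queries` (12 + 33 = 45). A minimal sketch of the per-type tally, assuming the pattern types are already loaded into memory; `count_by_type` is hypothetical, and a real implementation could equally run a SQL `GROUP BY pattern_type` query.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical aggregation behind "threat_patterns_by_type": one counter per
// distinct pattern_type. std::map keeps the keys sorted for stable output.
std::map<std::string, int> count_by_type(const std::vector<std::string>& pattern_types) {
    std::map<std::string, int> counts;
    for (const std::string& type : pattern_types) {
        ++counts[type];  // operator[] value-initializes new keys to 0
    }
    return counts;
}
```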
**Example:**
```cpp
std::string stats = detector->get_statistics();
proxy_info("Anomaly stats: %s\n", stats.c_str());
```
---
## Data Structures
### NL2SQLRequest
```cpp
struct NL2SQLRequest {
    std::string natural_language;             // Input natural language query
    std::string schema_name;                  // Target schema name
    std::vector<std::string> context_tables;  // Relevant tables
    bool allow_cache;                         // Whether to check cache
    int max_latency_ms;                       // Max acceptable latency (0 = no limit)
};
```
### NL2SQLResult
```cpp
struct NL2SQLResult {
    std::string sql_query;    // Generated SQL query
    float confidence;         // Confidence score (0.0-1.0)
    std::string explanation;  // Which model was used
    bool cached;              // Whether from cache
};
```
### AnomalyResult
```cpp
struct AnomalyResult {
    bool detected;                // Whether anomaly was detected
    float risk_score;             // Risk score (0.0-1.0)
    std::string threat_type;      // Type of threat
    std::string matched_pattern;  // Name of matched pattern
    std::string action_taken;     // "blocked", "flagged", "allowed"
};
```
---
## Error Handling
### Return Values
- **`bool` functions**: return `false` on error
- **`std::vector<float>` functions**: return an empty vector on error
- **`std::string` functions**: return an empty string or a JSON error object
### Logging
Use ProxySQL logging macros:
```cpp
proxy_error("Failed to generate embedding: %s\n", error_msg);
proxy_warning("Low confidence result: %.2f\n", confidence);
proxy_info("Cache hit for query: %s\n", query.c_str());
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "Embedding generated with %zu dimensions\n", size);
```
### Error Checking Example
```cpp
std::vector<float> embedding = converter->get_query_embedding(text);
if (embedding.empty()) {
    proxy_error("Failed to generate embedding for: %s\n", text.c_str());
    // Handle error: return an error result or use a fallback
    // (error_result is a placeholder for the caller's own error value)
    return error_result;
}
if (embedding.size() != 1536) {
    proxy_warning("Unexpected embedding size: %zu (expected 1536)\n", embedding.size());
    // May still work, but log a warning
}
```
---
## Usage Examples
### Complete NL2SQL Conversion with Cache
```cpp
// Get converter
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (!converter) {
    proxy_error("NL2SQL converter not initialized\n");
    return;
}
// Prepare request
NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.context_tables = {"customers", "orders"};
req.allow_cache = true;
req.max_latency_ms = 0;  // No latency constraint
// Convert
NL2SQLResult result = converter->convert(req);
// Check result
if (result.confidence > 0.7f) {
    proxy_info("Generated SQL: %s\n", result.sql_query.c_str());
    proxy_info("Confidence: %.2f\n", result.confidence);
    proxy_info("Source: %s\n", result.explanation.c_str());
    if (result.cached) {
        proxy_info("Retrieved from semantic cache\n");
    }
    // Execute the SQL
    execute_sql(result.sql_query);
} else {
    proxy_warning("Low confidence conversion: %.2f\n", result.confidence);
}
```
### Complete Anomaly Detection Flow
```cpp
// Get detector
Anomaly_Detector* detector = GloAI->get_anomaly();
if (!detector) {
    proxy_error("Anomaly detector not initialized\n");
    return;
}
// Add threat pattern
detector->add_threat_pattern(
    "Sleep-based DoS",
    "SELECT * FROM users WHERE id=1 AND sleep(10)",
    "dos",
    6
);
// Check incoming query
std::string query = "SELECT * FROM users WHERE id=5 AND SLEEP(5)--";
AnomalyResult result = detector->check_embedding_similarity(query);
if (result.detected) {
    proxy_warning("Anomaly detected! Risk: %.2f\n", result.risk_score);
    // Get risk threshold from config (a percentage, normalized to 0.0-1.0 below)
    int risk_threshold = GloAI->variables.ai_anomaly_risk_threshold;
    float risk_threshold_normalized = risk_threshold / 100.0f;
    if (result.risk_score > risk_threshold_normalized) {
        proxy_error("Blocking high-risk query\n");
        // Block the query: report the error to the client, then stop
        error_response("Query blocked by anomaly detection");
        return;
    } else {
        proxy_warning("Flagging medium-risk query\n");
        // Flag but allow
        log_flagged_query(query, result);
    }
}
// Allow query to proceed
execute_query(query);
```
### Threat Pattern Management
```cpp
// Add multiple threat patterns (structured bindings require C++17)
std::vector<std::tuple<std::string, std::string, std::string, int>> patterns = {
    {"OR 1=1", "SELECT * FROM users WHERE id=1 OR 1=1--", "sql_injection", 9},
    {"UNION SELECT", "SELECT name FROM products WHERE id=1 UNION SELECT password FROM users", "sql_injection", 8},
    {"DROP TABLE", "SELECT * FROM users; DROP TABLE users--", "privilege_escalation", 10}
};
for (const auto& [name, example, type, severity] : patterns) {
    if (detector->add_threat_pattern(name, example, type, severity)) {
        proxy_info("Added pattern: %s\n", name.c_str());
    }
}
// List all patterns (the variable is not named `json` to avoid shadowing the type alias)
std::string patterns_json = detector->list_threat_patterns();
json patterns_data = json::parse(patterns_json);
proxy_info("Total patterns: %zu\n", patterns_data.size());
// Remove a pattern (guard against an empty list before indexing)
if (!patterns_data.empty()) {
    int pattern_id = patterns_data[0]["id"].get<int>();
    if (detector->remove_threat_pattern(pattern_id)) {
        proxy_info("Removed pattern %d\n", pattern_id);
    }
}
// Get statistics
std::string stats = detector->get_statistics();
proxy_info("Statistics: %s\n", stats.c_str());
```
---
## Integration Points
### From MySQL_Session
Query interception happens in `MySQL_Session::execute_query()`:
```cpp
// Check if this is an NL2SQL query
if (query.find("NL2SQL:") == 0) {
    NL2SQL_Converter* converter = GloAI->get_nl2sql();
    if (converter) {
        NL2SQLRequest req;
        req.natural_language = query.substr(7);  // Remove "NL2SQL:" prefix (7 characters)
        NL2SQLResult result = converter->convert(req);
        return result.sql_query;
    }
}
// Check for anomalies
Anomaly_Detector* detector = GloAI->get_anomaly();
if (detector) {
    AnomalyResult result = detector->check_embedding_similarity(query);
    // `threshold`: normalized risk threshold (0.0-1.0)
    if (result.detected && result.risk_score > threshold) {
        return error("Query blocked");
    }
}
```
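The prefix check and the hard-coded `substr(7)` can be folded into one helper that derives the offset from the marker itself. A minimal sketch using C++17 `std::string_view`; `extract_nl2sql_text` is hypothetical, and trimming whitespace after the marker is an assumption (with `substr(7)`, `NL2SQL: Show all customers` keeps its leading space).

```cpp
#include <cstddef>
#include <string>
#include <string_view>

// Hypothetical helper: if `query` starts with "NL2SQL:", store the remaining
// text (leading whitespace trimmed) in `natural_language` and return true.
bool extract_nl2sql_text(std::string_view query, std::string& natural_language) {
    constexpr std::string_view kPrefix = "NL2SQL:";
    if (query.substr(0, kPrefix.size()) != kPrefix) {
        return false;  // not an NL2SQL request
    }
    const std::string_view rest = query.substr(kPrefix.size());
    const std::size_t first = rest.find_first_not_of(" \t\r\n");
    natural_language = (first == std::string_view::npos)
                           ? std::string()
                           : std::string(rest.substr(first));
    return true;
}
```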
### From MCP Tools
MCP tools can call these methods via JSON-RPC:
```json
{
  "jsonrpc": "2.0",
  "method": "tools/call",
  "params": {
    "name": "ai_add_threat_pattern",
    "arguments": {
      "pattern_name": "...",
      "query_example": "...",
      "pattern_type": "sql_injection",
      "severity": 9
    }
  }
}
```
---
## Thread Safety
- **Read operations** (`check_vector_cache()`, `check_embedding_similarity()`): thread-safe; they take read locks
- **Write operations** (`store_in_vector_cache()`, `add_threat_pattern()`): thread-safe; they take write locks
- **Global access**: always obtain the converter and detector through `GloAI`, which manages the locks
```cpp
// Safe pattern
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (converter) {
    // Method handles locking internally
    NL2SQLResult result = converter->convert(req);
}
```
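The read/write split maps naturally onto a reader/writer lock. A minimal sketch using `std::shared_mutex`; `GuardedCache` is hypothetical, and the actual classes may use a different lock type (for example `pthread_rwlock_t`). Lookups take a shared lock so concurrent readers do not block each other, while stores take an exclusive lock.

```cpp
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Hypothetical cache illustrating shared (read) vs exclusive (write) locking.
class GuardedCache {
public:
    std::optional<std::string> lookup(const std::string& key) const {
        std::shared_lock<std::shared_mutex> lock(mutex_);  // many readers at once
        auto it = entries_.find(key);
        if (it == entries_.end()) {
            return std::nullopt;
        }
        return it->second;
    }

    void store(const std::string& key, const std::string& value) {
        std::unique_lock<std::shared_mutex> lock(mutex_);  // one writer, no readers
        entries_[key] = value;
    }

private:
    mutable std::shared_mutex mutex_;  // mutable so const lookups can lock it
    std::unordered_map<std::string, std::string> entries_;
};
```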