
Vector Features API Reference

Overview

This document describes the C++ API for Vector Features in ProxySQL: the NL2SQL vector cache and Anomaly Detection based on embedding similarity.



NL2SQL_Converter API

Class: NL2SQL_Converter

Location: include/NL2SQL_Converter.h

The NL2SQL_Converter class provides natural language to SQL conversion with vector-based semantic caching.


Method: get_query_embedding()

Generate vector embedding for a text query.

std::vector<float> get_query_embedding(const std::string& text);

Parameters:

  • text: The input text to generate embedding for

Returns:

  • std::vector<float>: 1536-dimensional embedding vector, or empty vector on failure

Description: Calls the GenAI module to generate a text embedding using llama-server. The embedding is a 1536-dimensional float array representing the semantic meaning of the text.

Example:

NL2SQL_Converter* converter = GloAI->get_nl2sql();
std::vector<float> embedding = converter->get_query_embedding("Show all customers");

if (embedding.size() == 1536) {
    proxy_info("Generated embedding with %zu dimensions\n", embedding.size());
} else {
    proxy_error("Failed to generate embedding\n");
}

Memory Management:

  • GenAI allocates embedding data with malloc()
  • This method copies data to std::vector<float> and frees the original
  • Caller owns the returned vector
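The ownership hand-off described above can be sketched as follows. This is a minimal sketch that assumes GenAI hands back a raw `float*` plus a dimension count. `take_embedding` is a hypothetical helper, not part of the ProxySQL API.

```cpp
#include <cstddef>
#include <cstdlib>
#include <vector>

// Hypothetical helper: takes ownership of a malloc()-allocated float buffer
// (as GenAI is described to return), copies it into a caller-owned
// std::vector<float>, and frees the original. A null pointer or zero length
// yields an empty vector, matching the "empty vector on failure" contract.
std::vector<float> take_embedding(float* data, std::size_t dim) {
    if (data == nullptr || dim == 0) {
        std::free(data);  // free(nullptr) is a no-op
        return {};
    }
    std::vector<float> out(data, data + dim);  // copy into caller-owned storage
    std::free(data);                            // release the GenAI allocation
    return out;
}
```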

Method: check_vector_cache()

Search for semantically similar queries in the vector cache.

NL2SQLResult check_vector_cache(const NL2SQLRequest& req);

Parameters:

  • req: NL2SQL request containing the natural language query

Returns:

  • NL2SQLResult: Result with cached SQL if found, cached=false if not

Description: Performs KNN search using cosine distance to find the most similar cached query. Returns cached SQL if similarity > threshold.

Algorithm:

  1. Generate embedding for query text
  2. Convert embedding to JSON for sqlite-vec MATCH clause
  3. Calculate distance threshold from similarity threshold
  4. Execute KNN search: WHERE embedding MATCH '[...]' AND distance < threshold ORDER BY distance LIMIT 1
  5. Return cached result if found
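Step 2 (serializing the embedding for the MATCH clause) might look like the sketch below. It assumes the JSON-array text form of a vector that sqlite-vec accepts. `embedding_to_json` is a hypothetical name, and ProxySQL's own serializer may format numbers differently.

```cpp
#include <cstddef>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: serializes an embedding as a JSON array string such as
// "[0.5,-1,2]". max_digits10 keeps enough precision for each float to
// round-trip. NaN/inf are not valid JSON and are assumed not to occur.
std::string embedding_to_json(const std::vector<float>& v) {
    std::ostringstream os;
    os.precision(std::numeric_limits<float>::max_digits10);
    os << '[';
    for (std::size_t i = 0; i < v.size(); ++i) {
        if (i > 0) os << ',';
        os << v[i];
    }
    os << ']';
    return os.str();
}
```

Where possible, binding this string as a statement parameter is safer than splicing it into the SQL text.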

Distance Calculation:

float distance_threshold = 2.0f - (similarity_threshold / 50.0f);
// Example: similarity=85 → distance=0.3
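The threshold formula maps a 0-100 similarity setting linearly onto the 0-2 range of cosine distance (1 - cos θ). The sketch below pairs it with a plain cosine-distance computation to illustrate the hit test. It assumes sqlite-vec computes the same cosine distance internally. All names are hypothetical.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Same formula as above: 100 -> 0.0 (identical only), 0 -> 2.0.
float distance_threshold(float similarity_threshold) {
    return 2.0f - (similarity_threshold / 50.0f);
}

// Cosine distance = 1 - cos(theta), in [0, 2]. Mismatched or zero-length
// inputs return 2.0 so they never count as a match.
float cosine_distance(const std::vector<float>& a, const std::vector<float>& b) {
    if (a.size() != b.size() || a.empty()) return 2.0f;
    double dot = 0, na = 0, nb = 0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        dot += double(a[i]) * b[i];
        na  += double(a[i]) * a[i];
        nb  += double(b[i]) * b[i];
    }
    if (na == 0 || nb == 0) return 2.0f;
    return static_cast<float>(1.0 - dot / (std::sqrt(na) * std::sqrt(nb)));
}

// A cached entry counts as a hit when its distance is below the cutoff.
bool is_cache_hit(const std::vector<float>& query, const std::vector<float>& cached,
                  float similarity_threshold) {
    return cosine_distance(query, cached) < distance_threshold(similarity_threshold);
}
```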

Example:

NL2SQLRequest req;
req.natural_language = "Display USA customers";
req.allow_cache = true;

NL2SQLResult result = converter->check_vector_cache(req);

if (result.cached) {
    proxy_info("Cache hit! Score: %.2f\n", result.confidence);
    // Use result.sql_query
} else {
    proxy_info("Cache miss, calling LLM\n");
}

Method: store_in_vector_cache()

Store an NL2SQL conversion in the vector cache.

void store_in_vector_cache(const NL2SQLRequest& req, const NL2SQLResult& result);

Parameters:

  • req: Original NL2SQL request
  • result: NL2SQL conversion result to cache

Description: Stores the conversion with its embedding for future similarity search. Updates both the main table and virtual vector table.

Storage Process:

  1. Generate embedding for the natural language query
  2. Insert into nl2sql_cache table with embedding BLOB
  3. Get rowid from last insert
  4. Insert rowid into nl2sql_cache_vec virtual table
  5. Log cache entry
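Step 2 stores the embedding as a BLOB. The sketch below shows one plausible packing. It assumes the column holds a contiguous float32 array in host byte order, which is the compact vector format sqlite-vec accepts. The helper names are hypothetical.

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical helpers: pack an embedding into raw float32 bytes for the
// embedding BLOB column, and unpack it again.
std::vector<unsigned char> embedding_to_blob(const std::vector<float>& v) {
    std::vector<unsigned char> blob(v.size() * sizeof(float));
    if (!v.empty()) std::memcpy(blob.data(), v.data(), blob.size());
    return blob;
}

std::vector<float> blob_to_embedding(const unsigned char* data, std::size_t bytes) {
    if (data == nullptr || bytes % sizeof(float) != 0) return {};  // malformed BLOB
    std::vector<float> v(bytes / sizeof(float));
    if (!v.empty()) std::memcpy(v.data(), data, bytes);
    return v;
}
```

With the plain SQLite C API, the bytes could be bound with `sqlite3_bind_blob(stmt, idx, blob.data(), (int)blob.size(), SQLITE_TRANSIENT)`. The rowid for step 3 could then come from `sqlite3_last_insert_rowid(db)`. ProxySQL may wrap these calls in its own database layer.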

Example:

NL2SQLRequest req;
req.natural_language = "Show all customers";

NL2SQLResult result;
result.sql_query = "SELECT * FROM customers";
result.confidence = 0.95f;

converter->store_in_vector_cache(req, result);

Method: convert()

Convert natural language to SQL (main entry point).

NL2SQLResult convert(const NL2SQLRequest& req);

Parameters:

  • req: NL2SQL request with natural language query and context

Returns:

  • NL2SQLResult: Generated SQL with confidence score and metadata

Description: Complete conversion pipeline with vector caching:

  1. Check vector cache for similar queries
  2. If cache miss, build prompt with schema context
  3. Select model provider (Ollama/OpenAI/Anthropic)
  4. Call LLM API
  5. Validate and clean SQL
  6. Store result in vector cache
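The pipeline above is a cache-aside pattern. In the sketch below, simplified stand-in types are used, and three hypothetical callbacks replace check_vector_cache(), the prompt/LLM/validation steps, and store_in_vector_cache(). Skipping the store when the SQL is empty is an added assumption.

```cpp
#include <functional>
#include <optional>
#include <string>

// Hypothetical, simplified stand-ins for NL2SQLRequest / NL2SQLResult.
struct Request { std::string natural_language; bool allow_cache = true; };
struct Result  { std::string sql_query; bool cached = false; };

Result convert_with_cache(
    const Request& req,
    const std::function<std::optional<Result>(const Request&)>& lookup,
    const std::function<Result(const Request&)>& generate,
    const std::function<void(const Request&, const Result&)>& store) {
    if (req.allow_cache) {
        if (auto hit = lookup(req)) {   // step 1: semantic cache hit
            hit->cached = true;
            return *hit;
        }
    }
    Result fresh = generate(req);       // steps 2-5: prompt, model, LLM, validate
    fresh.cached = false;
    if (!fresh.sql_query.empty()) {     // assumption: never cache empty SQL
        store(req, fresh);              // step 6: cache for next time
    }
    return fresh;
}
```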

Example:

NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.allow_cache = true;

NL2SQLResult result = converter->convert(req);

if (result.confidence > 0.7f) {
    execute_sql(result.sql_query);
    proxy_info("Generated by: %s\n", result.explanation.c_str());
}

Method: clear_cache()

Clear the NL2SQL vector cache.

void clear_cache();

Description: Deletes all entries from both nl2sql_cache and nl2sql_cache_vec tables.

Example:

converter->clear_cache();
proxy_info("NL2SQL cache cleared\n");

Method: get_cache_stats()

Get cache statistics.

std::string get_cache_stats();

Returns:

  • std::string: JSON string with cache statistics

Statistics Include:

  • Total entries
  • Cache hits
  • Cache misses
  • Hit rate

Example:

std::string stats = converter->get_cache_stats();
proxy_info("Cache stats: %s\n", stats.c_str());
// Output: {"entries": 150, "hits": 1200, "misses": 300, "hit_rate": 0.80}
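The hit rate in the sample output is hits / (hits + misses), so 1200 / 1500 = 0.80. The sketch below reproduces that output shape. It assumes simple in-memory counters, and `CacheStats` is a hypothetical type.

```cpp
#include <cstdint>
#include <cstdio>
#include <string>

// Hypothetical counters for the NL2SQL vector cache.
struct CacheStats {
    std::uint64_t entries = 0, hits = 0, misses = 0;

    // hits / (hits + misses); 0 when no lookups have happened yet.
    double hit_rate() const {
        const std::uint64_t lookups = hits + misses;
        return lookups ? static_cast<double>(hits) / lookups : 0.0;
    }

    // Renders the same JSON shape as the sample output above.
    std::string to_json() const {
        char buf[160];
        std::snprintf(buf, sizeof buf,
                      "{\"entries\": %llu, \"hits\": %llu, \"misses\": %llu, \"hit_rate\": %.2f}",
                      static_cast<unsigned long long>(entries),
                      static_cast<unsigned long long>(hits),
                      static_cast<unsigned long long>(misses), hit_rate());
        return buf;
    }
};
```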

Anomaly_Detector API

Class: Anomaly_Detector

Location: include/Anomaly_Detector.h

The Anomaly_Detector class provides SQL threat detection using embedding similarity.


Method: get_query_embedding()

Generate vector embedding for a SQL query.

std::vector<float> get_query_embedding(const std::string& query);

Parameters:

  • query: The SQL query to generate embedding for

Returns:

  • std::vector<float>: 1536-dimensional embedding vector, or empty vector on failure

Description: Normalizes the query (lowercase, remove extra whitespace) and generates embedding via GenAI module.

Normalization Process:

  1. Convert to lowercase
  2. Remove extra whitespace
  3. Standardize SQL keywords
  4. Generate embedding
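Steps 1 and 2 of the normalization can be sketched as below. This is an illustration, not ProxySQL's normalizer. Step 3 (keyword standardization) is omitted because its rules are not specified here, and `normalize_query` is a hypothetical name. Note that lowercasing also affects string literals in the query.

```cpp
#include <cctype>
#include <string>

// Lowercases the query and collapses each whitespace run into one space,
// trimming leading and trailing whitespace.
std::string normalize_query(const std::string& q) {
    std::string out;
    out.reserve(q.size());
    bool pending_space = false;
    for (unsigned char c : q) {
        if (std::isspace(c)) {
            pending_space = !out.empty();  // drop leading whitespace
            continue;
        }
        if (pending_space) { out.push_back(' '); pending_space = false; }
        out.push_back(static_cast<char>(std::tolower(c)));
    }
    return out;  // trailing whitespace never emits a space
}
```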

Example:

Anomaly_Detector* detector = GloAI->get_anomaly();
std::vector<float> embedding = detector->get_query_embedding(
    "SELECT * FROM users WHERE id = 1 OR 1=1--"
);

if (embedding.size() == 1536) {
    // Check similarity against threat patterns
}

Method: check_embedding_similarity()

Check if query is similar to known threat patterns.

AnomalyResult check_embedding_similarity(const std::string& query);

Parameters:

  • query: The SQL query to check

Returns:

  • AnomalyResult: Detection result with risk score and matched pattern

Detection Algorithm:

  1. Normalize and generate embedding for query
  2. KNN search against anomaly_patterns_vec
  3. For each match within threshold:
    • Calculate risk score: (severity / 10) * (1 - distance / 2)
  4. Return highest risk match

Risk Score Formula:

risk_score = (severity / 10.0f) * (1.0f - (distance / 2.0f));
// severity: 1-10 from threat pattern
// distance: 0-2 from cosine distance
// risk_score: 0-1 (multiply by 100 for percentage)
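Steps 3 and 4 of the detection algorithm are sketched below: every KNN match within the distance threshold is scored with the formula above, and the highest-scoring one is kept. For example, severity 9 at distance 0.2 gives 0.9 × 0.9 = 0.81. The `PatternMatch` shape and the strict `<` comparison are assumptions.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// One KNN hit against anomaly_patterns_vec (hypothetical shape).
struct PatternMatch {
    std::string pattern_name;
    int severity;    // 1-10
    float distance;  // cosine distance, 0-2
};

float risk_score(int severity, float distance) {
    return (severity / 10.0f) * (1.0f - (distance / 2.0f));
}

// Returns the index of the highest-risk match within max_distance, or -1.
int highest_risk_index(const std::vector<PatternMatch>& matches, float max_distance) {
    int best = -1;
    float best_score = -1.0f;
    for (std::size_t i = 0; i < matches.size(); ++i) {
        if (matches[i].distance >= max_distance) continue;  // outside threshold
        const float s = risk_score(matches[i].severity, matches[i].distance);
        if (s > best_score) { best_score = s; best = static_cast<int>(i); }
    }
    return best;
}
```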

Example:

AnomalyResult result = detector->check_embedding_similarity(
    "SELECT * FROM users WHERE id = 5 OR 2=2--"
);

if (result.risk_score > 0.7f) {
    proxy_warning("High risk query detected! Score: %.2f\n", result.risk_score);
    proxy_warning("Matched pattern: %s\n", result.matched_pattern.c_str());
    // Block query
}

if (result.detected) {
    proxy_info("Threat type: %s\n", result.threat_type.c_str());
}

Method: add_threat_pattern()

Add a new threat pattern to the database.

bool add_threat_pattern(
    const std::string& pattern_name,
    const std::string& query_example,
    const std::string& pattern_type,
    int severity
);

Parameters:

  • pattern_name: Human-readable name for the pattern
  • query_example: Example SQL query representing this threat
  • pattern_type: Type of threat (sql_injection, dos, privilege_escalation, etc.)
  • severity: Severity level (1-10, where 10 is most severe)

Returns:

  • bool: true if pattern added successfully, false on error

Description: Stores threat pattern with embedding in both anomaly_patterns and anomaly_patterns_vec tables.

Storage Process:

  1. Generate embedding for query example
  2. Insert into anomaly_patterns with embedding BLOB
  3. Get rowid from last insert
  4. Insert rowid into anomaly_patterns_vec virtual table

Example:

bool success = detector->add_threat_pattern(
    "OR 1=1 Tautology",
    "SELECT * FROM users WHERE username='admin' OR 1=1--'",
    "sql_injection",
    9  // high severity
);

if (success) {
    proxy_info("Threat pattern added\n");
} else {
    proxy_error("Failed to add threat pattern\n");
}

Method: list_threat_patterns()

List all threat patterns in the database.

std::string list_threat_patterns();

Returns:

  • std::string: JSON array of threat patterns

JSON Format:

[
  {
    "id": 1,
    "pattern_name": "OR 1=1 Tautology",
    "pattern_type": "sql_injection",
    "query_example": "SELECT * FROM users WHERE username='admin' OR 1=1--'",
    "severity": 9,
    "created_at": 1705334400
  }
]

Example:

std::string patterns_json = detector->list_threat_patterns();
proxy_info("Threat patterns:\n%s\n", patterns_json.c_str());

// Parse with nlohmann/json (#include <nlohmann/json.hpp>)
using json = nlohmann::json;
json patterns = json::parse(patterns_json);
for (const auto& pattern : patterns) {
    proxy_info("- %s (severity: %d)\n",
        pattern["pattern_name"].get<std::string>().c_str(),
        pattern["severity"].get<int>());
}

Method: remove_threat_pattern()

Remove a threat pattern from the database.

bool remove_threat_pattern(int pattern_id);

Parameters:

  • pattern_id: ID of the pattern to remove

Returns:

  • bool: true if removed successfully, false on error

Description: Deletes from both anomaly_patterns_vec (virtual table) and anomaly_patterns (main table).

Example:

bool success = detector->remove_threat_pattern(5);

if (success) {
    proxy_info("Threat pattern 5 removed\n");
} else {
    proxy_error("Failed to remove pattern\n");
}

Method: get_statistics()

Get anomaly detection statistics.

std::string get_statistics();

Returns:

  • std::string: JSON string with detailed statistics

Statistics Include:

{
  "total_checks": 1500,
  "detected_anomalies": 45,
  "blocked_queries": 12,
  "flagged_queries": 33,
  "threat_patterns_count": 10,
  "threat_patterns_by_type": {
    "sql_injection": 6,
    "dos": 2,
    "privilege_escalation": 1,
    "data_exfiltration": 1
  }
}

Example:

std::string stats = detector->get_statistics();
proxy_info("Anomaly stats: %s\n", stats.c_str());

Data Structures

NL2SQLRequest

struct NL2SQLRequest {
    std::string natural_language;  // Input natural language query
    std::string schema_name;       // Target schema name
    std::vector<std::string> context_tables;  // Relevant tables
    bool allow_cache;              // Whether to check cache
    int max_latency_ms;            // Max acceptable latency (0 = no limit)
};

NL2SQLResult

struct NL2SQLResult {
    std::string sql_query;         // Generated SQL query
    float confidence;              // Confidence score (0.0-1.0)
    std::string explanation;       // Which model was used
    bool cached;                   // Whether from cache
};

AnomalyResult

struct AnomalyResult {
    bool detected;                 // Whether anomaly was detected
    float risk_score;              // Risk score (0.0-1.0)
    std::string threat_type;       // Type of threat
    std::string matched_pattern;   // Name of matched pattern
    std::string action_taken;      // "blocked", "flagged", "allowed"
};

Error Handling

Return Values

  • bool functions: Return false on error
  • std::vector functions: Return an empty vector on error
  • string functions: Return empty string or JSON error object

Logging

Use ProxySQL logging macros:

proxy_error("Failed to generate embedding: %s\n", error_msg);
proxy_warning("Low confidence result: %.2f\n", confidence);
proxy_info("Cache hit for query: %s\n", query.c_str());
proxy_debug(PROXY_DEBUG_NL2SQL, 3, "Embedding generated with %zu dimensions\n", size);

Error Checking Example

std::vector<float> embedding = converter->get_query_embedding(text);

if (embedding.empty()) {
    proxy_error("Failed to generate embedding for: %s\n", text.c_str());
    // Handle error - return error or use fallback
    return error_result;
}

if (embedding.size() != 1536) {
    proxy_warning("Unexpected embedding size: %zu (expected 1536)\n", embedding.size());
    // May still work, but log warning
}

Usage Examples

Complete NL2SQL Conversion with Cache

// Get converter
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (!converter) {
    proxy_error("NL2SQL converter not initialized\n");
    return;
}

// Prepare request
NL2SQLRequest req;
req.natural_language = "Find customers from USA with orders > $1000";
req.schema_name = "sales";
req.context_tables = {"customers", "orders"};
req.allow_cache = true;
req.max_latency_ms = 0;  // No latency constraint

// Convert
NL2SQLResult result = converter->convert(req);

// Check result
if (result.confidence > 0.7f) {
    proxy_info("Generated SQL: %s\n", result.sql_query.c_str());
    proxy_info("Confidence: %.2f\n", result.confidence);
    proxy_info("Source: %s\n", result.explanation.c_str());

    if (result.cached) {
        proxy_info("Retrieved from semantic cache\n");
    }

    // Execute the SQL
    execute_sql(result.sql_query);
} else {
    proxy_warning("Low confidence conversion: %.2f\n", result.confidence);
}

Complete Anomaly Detection Flow

// Get detector
Anomaly_Detector* detector = GloAI->get_anomaly();
if (!detector) {
    proxy_error("Anomaly detector not initialized\n");
    return;
}

// Add threat pattern
detector->add_threat_pattern(
    "Sleep-based DoS",
    "SELECT * FROM users WHERE id=1 AND sleep(10)",
    "dos",
    6
);

// Check incoming query
std::string query = "SELECT * FROM users WHERE id=5 AND SLEEP(5)--";
AnomalyResult result = detector->check_embedding_similarity(query);

if (result.detected) {
    proxy_warning("Anomaly detected! Risk: %.2f\n", result.risk_score);

    // Get risk threshold from config
    int risk_threshold = GloAI->variables.ai_anomaly_risk_threshold;
    float risk_threshold_normalized = risk_threshold / 100.0f;

    if (result.risk_score > risk_threshold_normalized) {
        proxy_error("Blocking high-risk query\n");
        // Block the query
        return error_response("Query blocked by anomaly detection");
    } else {
        proxy_warning("Flagging medium-risk query\n");
        // Flag but allow
        log_flagged_query(query, result);
    }
}

// Allow query to proceed
execute_query(query);
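The blocked/flagged decision in this flow can be isolated into a small function that yields the `action_taken` values listed in AnomalyResult. This sketch assumes the 0-100 integer threshold is normalized as shown above. `decide_action` is a hypothetical helper.

```cpp
#include <string>

// Maps a detection result to "allowed", "flagged", or "blocked".
// risk_threshold is the 0-100 setting; risk_score is 0-1.
std::string decide_action(bool detected, float risk_score, int risk_threshold) {
    if (!detected) return "allowed";
    const float cutoff = risk_threshold / 100.0f;
    return risk_score > cutoff ? "blocked" : "flagged";
}
```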

Threat Pattern Management

// Add multiple threat patterns
std::vector<std::tuple<std::string, std::string, std::string, int>> patterns = {
    {"OR 1=1", "SELECT * FROM users WHERE id=1 OR 1=1--", "sql_injection", 9},
    {"UNION SELECT", "SELECT name FROM products WHERE id=1 UNION SELECT password FROM users", "sql_injection", 8},
    {"DROP TABLE", "SELECT * FROM users; DROP TABLE users--", "privilege_escalation", 10}
};

for (const auto& [name, example, type, severity] : patterns) {
    if (detector->add_threat_pattern(name, example, type, severity)) {
        proxy_info("Added pattern: %s\n", name.c_str());
    }
}

// List all patterns
std::string patterns_json = detector->list_threat_patterns();
auto patterns_data = nlohmann::json::parse(patterns_json);
proxy_info("Total patterns: %zu\n", patterns_data.size());

// Remove a pattern (guard against an empty list before indexing)
if (!patterns_data.empty()) {
    int pattern_id = patterns_data[0]["id"].get<int>();
    if (detector->remove_threat_pattern(pattern_id)) {
        proxy_info("Removed pattern %d\n", pattern_id);
    }
}

// Get statistics
std::string stats = detector->get_statistics();
proxy_info("Statistics: %s\n", stats.c_str());

Integration Points

From MySQL_Session

Query interception happens in MySQL_Session::execute_query():

// Check if this is an NL2SQL query (prefix match at position 0)
if (query.rfind("NL2SQL:", 0) == 0) {
    NL2SQL_Converter* converter = GloAI->get_nl2sql();
    if (converter) {
        NL2SQLRequest req;
        req.natural_language = query.substr(7);  // Remove "NL2SQL:" prefix (7 chars)
        NL2SQLResult result = converter->convert(req);
        return result.sql_query;
    }
}

// Check for anomalies (threshold: normalized 0-1 risk cutoff)
Anomaly_Detector* detector = GloAI->get_anomaly();
if (detector) {
    AnomalyResult result = detector->check_embedding_similarity(query);
    if (result.detected && result.risk_score > threshold) {
        return error("Query blocked");
    }
}
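The prefix handling above can be factored into a helper that either returns the natural-language payload or reports that the query is not an NL2SQL request. This is a sketch, and `extract_nl2sql` is hypothetical. Unlike the snippet above, it also trims spaces after the prefix, which is an added assumption.

```cpp
#include <optional>
#include <string>
#include <string_view>

// Returns the text after "NL2SQL:" (leading spaces trimmed), or std::nullopt
// if the query does not start with the prefix. Matching is case-sensitive.
std::optional<std::string> extract_nl2sql(std::string_view query) {
    constexpr std::string_view prefix = "NL2SQL:";
    if (query.substr(0, prefix.size()) != prefix) return std::nullopt;
    std::string_view rest = query.substr(prefix.size());
    const auto start = rest.find_first_not_of(' ');
    if (start == std::string_view::npos) return std::string{};
    return std::string(rest.substr(start));
}
```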

From MCP Tools

MCP tools can call these methods via JSON-RPC:

{
  "jsonrpc": "2.0",
  "method": "tools/call",
  "params": {
    "name": "ai_add_threat_pattern",
    "arguments": {
      "pattern_name": "...",
      "query_example": "...",
      "pattern_type": "sql_injection",
      "severity": 9
    }
  }
}

Thread Safety

  • Read operations (check_vector_cache, check_embedding_similarity): Thread-safe, use read locks
  • Write operations (store_in_vector_cache, add_threat_pattern): Thread-safe, use write locks
  • Global access: Always access via GloAI which manages locks

// Safe pattern
NL2SQL_Converter* converter = GloAI->get_nl2sql();
if (converter) {
    // Method handles locking internally
    NL2SQLResult result = converter->convert(req);
}
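The read-lock/write-lock split described above can be illustrated with `std::shared_mutex`: lookups take a shared lock and can run concurrently, while stores take an exclusive lock. This is a minimal sketch under that assumption. ProxySQL's actual lock types may differ, and `GuardedCache` is hypothetical.

```cpp
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

class GuardedCache {
public:
    // Read path: many threads may hold the shared lock at once.
    bool lookup(const std::string& key, std::string& value_out) const {
        std::shared_lock<std::shared_mutex> lock(mu_);
        auto it = map_.find(key);
        if (it == map_.end()) return false;
        value_out = it->second;
        return true;
    }

    // Write path: exclusive lock blocks readers and other writers.
    void store(const std::string& key, const std::string& value) {
        std::unique_lock<std::shared_mutex> lock(mu_);
        map_[key] = value;
    }

private:
    mutable std::shared_mutex mu_;
    std::unordered_map<std::string, std::string> map_;
};
```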