You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
proxysql/include/RAG_Tool_Handler.h

512 lines
16 KiB

/**
* @file RAG_Tool_Handler.h
* @brief RAG Tool Handler for MCP protocol
*
* Provides RAG (Retrieval-Augmented Generation) tools via MCP protocol including:
* - FTS search over documents
* - Vector search over embeddings
* - Hybrid search combining FTS and vectors
* - Fetch tools for retrieving document/chunk content
* - Refetch tool for authoritative source data
* - Admin tools for operational visibility
*
* The RAG subsystem implements a complete retrieval system with:
* - Full-text search using SQLite FTS5
* - Semantic search using vector embeddings with sqlite3-vec
* - Hybrid search combining both approaches
* - Comprehensive filtering capabilities
* - Security features including input validation and limits
* - Performance optimizations
*
* @date 2026-01-19
* @author ProxySQL Team
* @copyright GNU GPL v3
* @ingroup mcp
* @ingroup rag
*/
#ifndef CLASS_RAG_TOOL_HANDLER_H
#define CLASS_RAG_TOOL_HANDLER_H
#include "MCP_Tool_Handler.h"
#include "sqlite3db.h"
#include "GenAI_Thread.h"
#include <string>
#include <vector>
#include <map>
#include <pthread.h>
// Forward declarations
class AI_Features_Manager;
class Discovery_Schema;
/**
* @brief RAG Tool Handler for MCP
*
* Provides RAG-powered tools through the MCP protocol:
* - rag.search_fts: Keyword search using FTS5
* - rag.search_vector: Semantic search using vector embeddings
* - rag.search_hybrid: Hybrid search combining FTS and vectors
* - rag.get_chunks: Fetch chunk content by chunk_id
* - rag.get_docs: Fetch document content by doc_id
* - rag.fetch_from_source: Refetch authoritative data from source
* - rag.admin.stats: Operational statistics
*
* The RAG subsystem implements a complete retrieval system with:
* - Full-text search using SQLite FTS5
* - Semantic search using vector embeddings with sqlite3-vec
* - Hybrid search combining both approaches with Reciprocal Rank Fusion
* - Comprehensive filtering capabilities by source, document, tags, dates, etc.
* - Security features including input validation, limits, and timeouts
* - Performance optimizations with prepared statements and connection management
*
* @ingroup mcp
* @ingroup rag
*/
class RAG_Tool_Handler : public MCP_Tool_Handler {
private:
    /// Vector database connection
    SQLite3DB* vector_db;

    /// AI features manager for shared resources
    AI_Features_Manager* ai_manager;

    /// Discovery catalog for logging
    Discovery_Schema* catalog;

    /// Catalog path for database initialization
    std::string catalog_path;

    /// @name Configuration Parameters
    /// @{

    /// Maximum number of search results (default: 50)
    int k_max;

    /// Maximum number of candidates for hybrid search (default: 500)
    int candidates_max;

    /// Maximum query length in bytes (default: 8192)
    int query_max_bytes;

    /// Maximum response size in bytes (default: 5000000)
    int response_max_bytes;

    /// Operation timeout in milliseconds (default: 2000)
    int timeout_ms;

    /// @}

    /**
     * @brief Statistics for a specific (tool, schema) pair
     *
     * All fields start at zero. For min_time and first_seen, zero doubles as
     * the "not recorded yet" marker (see add_timing()).
     *
     * NOTE(review): the time unit of the duration/timestamp values is decided
     * by the caller of add_timing() and is not visible in this header.
     */
    struct ToolUsageStats {
        unsigned long long count;      ///< Number of recorded invocations
        unsigned long long first_seen; ///< Timestamp of the first recorded invocation (0 = none yet)
        unsigned long long last_seen;  ///< Timestamp of the most recent invocation
        unsigned long long sum_time;   ///< Sum of all recorded durations
        unsigned long long min_time;   ///< Smallest non-zero duration seen (0 = none yet)
        unsigned long long max_time;   ///< Largest duration seen

        ToolUsageStats()
            : count(0), first_seen(0), last_seen(0),
              sum_time(0), min_time(0), max_time(0) {}

        /**
         * @brief Record one invocation
         *
         * @param duration Duration of the invocation
         * @param timestamp Time at which the invocation was observed
         */
        void add_timing(unsigned long long duration, unsigned long long timestamp) {
            count++;
            sum_time += duration;

            // min_time == 0 means "unset", so a zero duration is never stored
            // as the minimum; otherwise it would be indistinguishable from
            // the unset state.
            if (duration != 0 && (min_time == 0 || duration < min_time)) {
                min_time = duration;
            }
            if (duration > max_time) {
                max_time = duration;
            }

            // Only the first call with a non-zero timestamp sets first_seen.
            if (first_seen == 0) {
                first_seen = timestamp;
            }
            last_seen = timestamp;
        }
    };

    // Tool usage counters: endpoint -> tool_name -> schema_name -> ToolUsageStats
    using SchemaStatsMap = std::map<std::string, ToolUsageStats>;
    using ToolStatsMap = std::map<std::string, SchemaStatsMap>;
    using ToolUsageStatsMap = std::map<std::string, ToolStatsMap>;

    ToolUsageStatsMap tool_usage_stats;

    // NOTE(review): presumably guards tool_usage_stats; initialization and
    // destruction of the mutex happen outside this header - confirm in the
    // implementation file.
    pthread_mutex_t counters_lock;

    // Friend function for tracking tool invocations
    friend void track_tool_invocation(RAG_Tool_Handler*, const std::string&,
                                      const std::string&, const std::string&,
                                      unsigned long long);

    /**
     * @brief Helper to extract string parameter from JSON
     *
     * Safely extracts a string parameter from a JSON object, handling type
     * conversion if necessary. Returns the default value if the key is not
     * found or cannot be converted to a string.
     *
     * @param j JSON object to extract from
     * @param key Parameter key to extract
     * @param default_val Default value if key not found
     * @return Extracted string value or default
     *
     * @see get_json_int()
     * @see get_json_bool()
     * @see get_json_string_array()
     * @see get_json_int_array()
     */
    static std::string get_json_string(const json& j, const std::string& key,
                                       const std::string& default_val = "");

    /**
     * @brief Helper to extract int parameter from JSON
     *
     * Safely extracts an integer parameter from a JSON object, handling type
     * conversion from string if necessary. Returns the default value if the
     * key is not found or cannot be converted to an integer.
     *
     * @param j JSON object to extract from
     * @param key Parameter key to extract
     * @param default_val Default value if key not found
     * @return Extracted int value or default
     *
     * @see get_json_string()
     * @see get_json_bool()
     * @see get_json_string_array()
     * @see get_json_int_array()
     */
    static int get_json_int(const json& j, const std::string& key, int default_val = 0);

    /**
     * @brief Helper to extract bool parameter from JSON
     *
     * Safely extracts a boolean parameter from a JSON object, handling type
     * conversion from string or integer if necessary. Returns the default
     * value if the key is not found or cannot be converted to a boolean.
     *
     * @param j JSON object to extract from
     * @param key Parameter key to extract
     * @param default_val Default value if key not found
     * @return Extracted bool value or default
     *
     * @see get_json_string()
     * @see get_json_int()
     * @see get_json_string_array()
     * @see get_json_int_array()
     */
    static bool get_json_bool(const json& j, const std::string& key, bool default_val = false);

    /**
     * @brief Helper to extract string array from JSON
     *
     * Safely extracts a string array parameter from a JSON object, filtering
     * out non-string elements. Returns an empty vector if the key is not
     * found or is not an array.
     *
     * @param j JSON object to extract from
     * @param key Parameter key to extract
     * @return Vector of extracted strings
     *
     * @see get_json_string()
     * @see get_json_int()
     * @see get_json_bool()
     * @see get_json_int_array()
     */
    static std::vector<std::string> get_json_string_array(const json& j, const std::string& key);

    /**
     * @brief Helper to extract int array from JSON
     *
     * Safely extracts an integer array parameter from a JSON object, handling
     * type conversion from string if necessary. Returns an empty vector if
     * the key is not found or is not an array.
     *
     * @param j JSON object to extract from
     * @param key Parameter key to extract
     * @return Vector of extracted integers
     *
     * @see get_json_string()
     * @see get_json_int()
     * @see get_json_bool()
     * @see get_json_string_array()
     */
    static std::vector<int> get_json_int_array(const json& j, const std::string& key);

    /**
     * @brief Validate and limit k parameter
     *
     * Ensures the k parameter is within acceptable bounds (1 to k_max).
     * Returns default value of 10 if k is invalid.
     *
     * @param k Requested number of results
     * @return Validated k value within configured limits
     *
     * @see validate_candidates()
     * @see k_max
     */
    int validate_k(int k);

    /**
     * @brief Validate and limit candidates parameter
     *
     * Ensures the candidates parameter is within acceptable bounds (1 to candidates_max).
     * Returns default value of 50 if candidates is invalid.
     *
     * @param candidates Requested number of candidates
     * @return Validated candidates value within configured limits
     *
     * @see validate_k()
     * @see candidates_max
     */
    int validate_candidates(int candidates);

    /**
     * @brief Validate query length
     *
     * Checks if the query string length is within the configured query_max_bytes limit.
     *
     * @param query Query string to validate
     * @return true if query is within length limits, false otherwise
     *
     * @see query_max_bytes
     */
    bool validate_query_length(const std::string& query);

    /**
     * @brief Escape FTS query string for safe use in MATCH clause
     *
     * Escapes single quotes in FTS query strings by doubling them, so the
     * query can be embedded in a single-quoted SQL string literal while
     * still allowing legitimate single quotes in queries.
     *
     * NOTE(review): doubling single quotes is SQL string-literal escaping.
     * It does not neutralize FTS5 query-syntax elements (double-quoted
     * strings, boolean/NEAR operators, column filters). Confirm in the
     * implementation whether those need handling, or whether the MATCH
     * argument could be bound via execute_parameterized_query() instead.
     *
     * @param query Raw FTS query string from user input
     * @return Escaped query string for use in MATCH clause
     *
     * @see execute_tool()
     */
    std::string escape_fts_query(const std::string& query);

    /**
     * @brief Execute database query and return results
     *
     * Executes a SQL query against the vector database and returns the results.
     * Handles error checking and logging. The caller is responsible for freeing
     * the returned SQLite3_result.
     *
     * @param query SQL query string to execute
     * @return SQLite3_result pointer or NULL on error
     *
     * @see vector_db
     */
    SQLite3_result* execute_query(const char* query);

    /**
     * @brief Execute parameterized database query with bindings
     *
     * Executes a parameterized SQL query against the vector database with bound parameters
     * and returns the results. This prevents SQL injection vulnerabilities.
     * Handles error checking and logging. The caller is responsible for freeing
     * the returned SQLite3_result.
     *
     * Only text and integer bindings are supported by this signature.
     *
     * @param query SQL query string with placeholders to execute
     * @param text_bindings Text parameter bindings as (int, string) pairs; the int is
     *        presumably the placeholder index - confirm the indexing base in the implementation
     * @param int_bindings Integer parameter bindings as (int, int) pairs; presumably
     *        (placeholder index, value) - confirm in the implementation
     * @return SQLite3_result pointer or NULL on error
     *
     * @see vector_db
     */
    SQLite3_result* execute_parameterized_query(
        const char* query,
        const std::vector<std::pair<int, std::string>>& text_bindings = {},
        const std::vector<std::pair<int, int>>& int_bindings = {});

    /**
     * @brief Build SQL filter conditions from JSON filters
     *
     * Builds SQL WHERE conditions from JSON filter parameters with proper input validation
     * to prevent SQL injection. This consolidates the duplicated filter building logic
     * across different search tools.
     *
     * @param filters JSON object containing filter parameters
     * @param sql Reference to SQL string to append conditions to
     * @return true on success, false on validation error
     *
     * @see execute_tool()
     */
    bool build_sql_filters(const json& filters, std::string& sql);

    /**
     * @brief Compute Reciprocal Rank Fusion score
     *
     * Computes the Reciprocal Rank Fusion score for hybrid search ranking.
     * Formula: weight / (k0 + rank)
     *
     * @param rank Rank position (1-based)
     * @param k0 Smoothing parameter
     * @param weight Weight factor for this ranking
     * @return RRF score
     *
     * @see rag.search_hybrid
     */
    double compute_rrf_score(int rank, int k0, double weight);

    /**
     * @brief Normalize scores to 0-1 range (higher is better)
     *
     * Normalizes various types of scores to a consistent 0-1 range where
     * higher values indicate better matches. Different score types may
     * require different normalization approaches.
     *
     * @param score Raw score to normalize
     * @param score_type Type of score being normalized
     * @return Normalized score in 0-1 range
     */
    double normalize_score(double score, const std::string& score_type);

public:
    /**
     * @brief Constructor
     *
     * Initializes the RAG tool handler with configuration parameters from GenAI_Thread
     * if available, otherwise uses default values.
     *
     * Configuration parameters:
     * - k_max: Maximum number of search results (default: 50)
     * - candidates_max: Maximum number of candidates for hybrid search (default: 500)
     * - query_max_bytes: Maximum query length in bytes (default: 8192)
     * - response_max_bytes: Maximum response size in bytes (default: 5000000)
     * - timeout_ms: Operation timeout in milliseconds (default: 2000)
     *
     * @param ai_mgr Pointer to AI_Features_Manager for database access and configuration
     * @param cat_path Path to the catalog database (for logging)
     *
     * @see AI_Features_Manager
     * @see Discovery_Schema
     * @see GenAI_Thread
     */
    RAG_Tool_Handler(AI_Features_Manager* ai_mgr, const std::string& cat_path = "");

    /**
     * @brief Destructor
     *
     * Cleans up resources and closes database connections.
     *
     * @see close()
     */
    ~RAG_Tool_Handler();

    // Non-copyable: the handler embeds a pthread_mutex_t (copying a mutex
    // object has undefined results under POSIX) and holds raw pointers to
    // resources that the destructor is documented to close, so a copy would
    // alias them.
    RAG_Tool_Handler(const RAG_Tool_Handler&) = delete;
    RAG_Tool_Handler& operator=(const RAG_Tool_Handler&) = delete;

    /**
     * @brief Initialize the tool handler
     *
     * Initializes the RAG tool handler by establishing database connections
     * and preparing internal state. Must be called before executing any tools.
     *
     * @return 0 on success, -1 on error
     *
     * @see close()
     * @see vector_db
     * @see ai_manager
     */
    int init() override;

    /**
     * @brief Close and cleanup
     *
     * Cleans up resources and closes database connections. Called automatically
     * by the destructor.
     *
     * @see init()
     * @see ~RAG_Tool_Handler()
     */
    void close() override;

    /**
     * @brief Get handler name
     *
     * Returns the name of this tool handler for identification purposes.
     *
     * @return Handler name as string ("rag")
     *
     * @see MCP_Tool_Handler
     */
    std::string get_handler_name() const override { return "rag"; }

    /**
     * @brief Get list of available tools
     *
     * Returns a comprehensive list of all available RAG tools with their
     * input schemas and descriptions. Tools include:
     * - rag.search_fts: Keyword search using FTS5
     * - rag.search_vector: Semantic search using vector embeddings
     * - rag.search_hybrid: Hybrid search combining FTS and vectors
     * - rag.get_chunks: Fetch chunk content by chunk_id
     * - rag.get_docs: Fetch document content by doc_id
     * - rag.fetch_from_source: Refetch authoritative data from source
     * - rag.admin.stats: Operational statistics
     *
     * @return JSON object containing tool definitions and schemas
     *
     * @see get_tool_description()
     * @see execute_tool()
     */
    json get_tool_list() override;

    /**
     * @brief Get description of a specific tool
     *
     * Returns the schema and description for a specific RAG tool.
     *
     * @param tool_name Name of the tool to describe
     * @return JSON object with tool description or error response
     *
     * @see get_tool_list()
     * @see execute_tool()
     */
    json get_tool_description(const std::string& tool_name) override;

    /**
     * @brief Execute a tool with arguments
     *
     * Executes the specified RAG tool with the provided arguments. Handles
     * input validation, parameter processing, database queries, and result
     * formatting according to MCP specifications.
     *
     * Supported tools:
     * - rag.search_fts: Full-text search over documents
     * - rag.search_vector: Vector similarity search
     * - rag.search_hybrid: Hybrid search with two modes (fuse, fts_then_vec)
     * - rag.get_chunks: Retrieve chunk content by ID
     * - rag.get_docs: Retrieve document content by ID
     * - rag.fetch_from_source: Refetch data from authoritative source
     * - rag.admin.stats: Get operational statistics
     *
     * @param tool_name Name of the tool to execute
     * @param arguments JSON object containing tool arguments
     * @return JSON response with results or error information
     *
     * @see get_tool_list()
     * @see get_tool_description()
     */
    json execute_tool(const std::string& tool_name, const json& arguments) override;

    /**
     * @brief Set the vector database
     *
     * Sets the vector database connection for this tool handler. Only the
     * pointer is stored; nothing else is done with the previous value here.
     *
     * @param db Pointer to SQLite3DB vector database
     *
     * @see vector_db
     * @see init()
     */
    void set_vector_db(SQLite3DB* db) { vector_db = db; }

    /**
     * @brief Get tool usage statistics (thread-safe copy)
     * @return ToolUsageStatsMap copy with endpoint -> tool_name -> schema_name -> ToolUsageStats
     */
    ToolUsageStatsMap get_tool_usage_stats();

    /**
     * @brief Get tool usage statistics as SQLite3_result* with optional reset
     * @param reset If true, resets internal counters after capturing data
     * @return SQLite3_result* with columns: endpoint, tool, schema, count, first_seen, last_seen, sum_time, min_time, max_time. Caller must delete.
     */
    SQLite3_result* get_tool_usage_stats_resultset(bool reset = false);
};
#endif /* CLASS_RAG_TOOL_HANDLER_H */