You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
proxysql/lib/RAG_Tool_Handler.cpp

2590 lines
88 KiB

#ifdef PROXYSQLGENAI
#include "proxysql.h"
/**
* @file RAG_Tool_Handler.cpp
* @brief Implementation of RAG Tool Handler for MCP protocol
*
* Implements RAG-powered tools through MCP protocol for retrieval operations.
* This file contains the complete implementation of all RAG functionality
* including search, fetch, and administrative tools.
*
* The RAG subsystem provides:
* - Full-text search using SQLite FTS5
* - Semantic search using vector embeddings with sqlite3-vec
* - Hybrid search combining both approaches with Reciprocal Rank Fusion
* - Comprehensive filtering capabilities
* - Security features including input validation and limits
* - Performance optimizations
*
* @see RAG_Tool_Handler.h
* @ingroup mcp
* @ingroup rag
*/
#include "RAG_Tool_Handler.h"
#include "AI_Features_Manager.h"
#include "Discovery_Schema.h"
#include "GenAI_Thread.h"
#include "LLM_Bridge.h"
#include "proxysql_debug.h"
#include "cpp.h"
#include <algorithm>
#include <chrono>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
// Forward declaration for GloGATH
extern GenAI_Threads_Handler *GloGATH;
// JSON library
#include "../deps/json/json.hpp"
using json = nlohmann::json;
#define PROXYJSON
// Forward declaration for GloGATH
extern GenAI_Threads_Handler *GloGATH;
// ============================================================================
// Tool Invocation Tracking
// ============================================================================
/**
 * @brief Record one timed invocation of a tool in the handler's usage stats.
 *
 * The stats entry is addressed by endpoint -> tool_name -> schema_name and
 * receives the duration plus the current monotonic timestamp. Access to
 * tool_usage_stats is serialized by handler->counters_lock.
 *
 * @param handler Handler owning the statistics (must be non-NULL)
 * @param endpoint Endpoint the tool was invoked through
 * @param tool_name Name of the invoked tool
 * @param schema_name Schema the invocation applied to
 * @param duration_us Invocation duration in microseconds
 */
void track_tool_invocation(
	RAG_Tool_Handler* handler,
	const std::string& endpoint,
	const std::string& tool_name,
	const std::string& schema_name,
	unsigned long long duration_us
) {
	pthread_mutex_t* const stats_mutex = &handler->counters_lock;
	pthread_mutex_lock(stats_mutex);
	auto& entry = handler->tool_usage_stats[endpoint][tool_name][schema_name];
	entry.add_timing(duration_us, monotonic_time());
	pthread_mutex_unlock(stats_mutex);
}
// ============================================================================
// Constructor/Destructor
// ============================================================================
/**
 * @brief Constructor
 *
 * Starts from built-in limits and, when both an AI manager was supplied and
 * the global GenAI thread handler (GloGATH) exists, replaces them with the
 * runtime genai_rag_* settings.
 *
 * Built-in limits:
 * - k_max: 50
 * - candidates_max: 500
 * - query_max_bytes: 8192
 * - response_max_bytes: 5000000
 * - timeout_ms: 2000
 *
 * No database is opened here; that happens in init().
 *
 * @param ai_mgr AI_Features_Manager used later by init() to obtain the vector DB
 * @param cat_path Path of the Discovery_Schema catalog used for logging;
 *        an empty path means no catalog is created by init()
 *
 * @see AI_Features_Manager
 * @see GenAI_Thread
 */
RAG_Tool_Handler::RAG_Tool_Handler(AI_Features_Manager* ai_mgr, const std::string& cat_path)
	: vector_db(NULL),
	  ai_manager(ai_mgr),
	  catalog(NULL),
	  catalog_path(cat_path),
	  k_max(50),
	  candidates_max(500),
	  query_max_bytes(8192),
	  response_max_bytes(5000000),
	  timeout_ms(2000)
{
	// Mutex guarding the tool usage counters
	pthread_mutex_init(&counters_lock, NULL);

	const bool runtime_config_available = (ai_manager != NULL) && (GloGATH != NULL);
	if (runtime_config_available) {
		const auto& vars = GloGATH->variables;
		k_max = vars.genai_rag_k_max;
		candidates_max = vars.genai_rag_candidates_max;
		query_max_bytes = vars.genai_rag_query_max_bytes;
		response_max_bytes = vars.genai_rag_response_max_bytes;
		timeout_ms = vars.genai_rag_timeout_ms;
	}

	proxy_debug(PROXY_DEBUG_GENAI, 3, "RAG_Tool_Handler created\n");
}
/**
 * @brief Destructor
 *
 * Calls close() to release the logging catalog, then destroys the counters
 * mutex. The vector database pointer obtained from AI_Features_Manager is
 * not released here.
 *
 * @see close()
 */
RAG_Tool_Handler::~RAG_Tool_Handler() {
	this->close();
	// Destroy the mutex only after all owned resources are released
	pthread_mutex_destroy(&counters_lock);
	proxy_debug(PROXY_DEBUG_GENAI, 3, "RAG_Tool_Handler destroyed\n");
}
// ============================================================================
// Lifecycle
// ============================================================================
/**
 * @brief Initialize the tool handler
 *
 * Obtains the vector database from the AI manager and, when catalog_path is
 * non-empty, creates and initializes a Discovery_Schema catalog for logging.
 * A catalog initialization failure is logged and tolerated: the handler keeps
 * working without a catalog.
 *
 * Calling init() again does not create a second catalog when one already
 * exists, so the previous instance is not leaked.
 *
 * @return 0 on success, -1 if no vector database is available
 *
 * @see close()
 * @see vector_db
 * @see ai_manager
 */
int RAG_Tool_Handler::init() {
	if (ai_manager) {
		vector_db = ai_manager->get_vector_db();
	}
	if (!vector_db) {
		proxy_error("RAG_Tool_Handler: Vector database not available\n");
		return -1;
	}
	// Initialize catalog for logging if a path is provided and no catalog exists yet.
	// Re-creating it unconditionally would overwrite (and leak) the previous pointer.
	if (!catalog_path.empty() && catalog == NULL) {
		catalog = new Discovery_Schema(catalog_path);
		if (catalog->init() != 0) {
			proxy_error("RAG_Tool_Handler: Failed to initialize catalog at %s\n", catalog_path.c_str());
			delete catalog;
			catalog = NULL;
			// Continue without catalog - logging will be skipped
		} else {
			proxy_info("RAG_Tool_Handler: Catalog initialized for logging\n");
		}
	}
	proxy_info("RAG_Tool_Handler initialized\n");
	return 0;
}
/**
 * @brief Release the logging catalog
 *
 * Deletes the Discovery_Schema catalog, if any, and resets the pointer so
 * repeated calls are harmless. Invoked by the destructor.
 *
 * @see init()
 * @see ~RAG_Tool_Handler()
 */
void RAG_Tool_Handler::close() {
	// delete on a NULL pointer is a no-op, so no explicit guard is needed
	delete catalog;
	catalog = NULL;
}
// ============================================================================
// Helper Functions
// ============================================================================
/**
* @brief Extract string parameter from JSON
*
* Safely extracts a string parameter from a JSON object, handling type
* conversion if necessary. Returns the default value if the key is not
* found or cannot be converted to a string.
*
* @param j JSON object to extract from
* @param key Parameter key to extract
* @param default_val Default value if key not found
* @return Extracted string value or default
*
* @see get_json_int()
* @see get_json_bool()
* @see get_json_string_array()
* @see get_json_int_array()
*/
std::string RAG_Tool_Handler::get_json_string(const json& j, const std::string& key,
const std::string& default_val) {
if (j.contains(key) && !j[key].is_null()) {
if (j[key].is_string()) {
return j[key].get<std::string>();
} else {
// Convert to string if not already
return j[key].dump();
}
}
return default_val;
}
/**
* @brief Extract int parameter from JSON
*
* Safely extracts an integer parameter from a JSON object, handling type
* conversion from string if necessary. Returns the default value if the
* key is not found or cannot be converted to an integer.
*
* @param j JSON object to extract from
* @param key Parameter key to extract
* @param default_val Default value if key not found
* @return Extracted int value or default
*
* @see get_json_string()
* @see get_json_bool()
* @see get_json_string_array()
* @see get_json_int_array()
*/
int RAG_Tool_Handler::get_json_int(const json& j, const std::string& key, int default_val) {
if (j.contains(key) && !j[key].is_null()) {
if (j[key].is_number()) {
return j[key].get<int>();
} else if (j[key].is_string()) {
try {
return std::stoi(j[key].get<std::string>());
} catch (const std::exception& e) {
proxy_error("RAG_Tool_Handler: Failed to convert string to int for key '%s': %s\n",
key.c_str(), e.what());
return default_val;
}
}
}
return default_val;
}
/**
* @brief Extract bool parameter from JSON
*
* Safely extracts a boolean parameter from a JSON object, handling type
* conversion from string or integer if necessary. Returns the default
* value if the key is not found or cannot be converted to a boolean.
*
* @param j JSON object to extract from
* @param key Parameter key to extract
* @param default_val Default value if key not found
* @return Extracted bool value or default
*
* @see get_json_string()
* @see get_json_int()
* @see get_json_string_array()
* @see get_json_int_array()
*/
bool RAG_Tool_Handler::get_json_bool(const json& j, const std::string& key, bool default_val) {
if (j.contains(key) && !j[key].is_null()) {
if (j[key].is_boolean()) {
return j[key].get<bool>();
} else if (j[key].is_string()) {
std::string val = j[key].get<std::string>();
return (val == "true" || val == "1");
} else if (j[key].is_number()) {
return j[key].get<int>() != 0;
}
}
return default_val;
}
/**
* @brief Extract string array from JSON
*
* Safely extracts a string array parameter from a JSON object, filtering
* out non-string elements. Returns an empty vector if the key is not
* found or is not an array.
*
* @param j JSON object to extract from
* @param key Parameter key to extract
* @return Vector of extracted strings
*
* @see get_json_string()
* @see get_json_int()
* @see get_json_bool()
* @see get_json_int_array()
*/
std::vector<std::string> RAG_Tool_Handler::get_json_string_array(const json& j, const std::string& key) {
std::vector<std::string> result;
if (j.contains(key) && j[key].is_array()) {
for (const auto& item : j[key]) {
if (item.is_string()) {
result.push_back(item.get<std::string>());
}
}
}
return result;
}
/**
* @brief Extract int array from JSON
*
* Safely extracts an integer array parameter from a JSON object, handling
* type conversion from string if necessary. Returns an empty vector if
* the key is not found or is not an array.
*
* @param j JSON object to extract from
* @param key Parameter key to extract
* @return Vector of extracted integers
*
* @see get_json_string()
* @see get_json_int()
* @see get_json_bool()
* @see get_json_string_array()
*/
std::vector<int> RAG_Tool_Handler::get_json_int_array(const json& j, const std::string& key) {
std::vector<int> result;
if (j.contains(key) && j[key].is_array()) {
for (const auto& item : j[key]) {
if (item.is_number()) {
result.push_back(item.get<int>());
} else if (item.is_string()) {
try {
result.push_back(std::stoi(item.get<std::string>()));
} catch (const std::exception& e) {
proxy_error("RAG_Tool_Handler: Failed to convert string to int in array: %s\n", e.what());
}
}
}
}
return result;
}
/**
 * @brief Validate and limit k parameter
 *
 * Non-positive values are replaced by the default of 10. The result is then
 * clamped to k_max. Clamping happens after substituting the default, so the
 * default can never exceed a configured k_max smaller than 10.
 *
 * @param k Requested number of results
 * @return Validated k value, at most k_max
 *
 * @see validate_candidates()
 * @see k_max
 */
int RAG_Tool_Handler::validate_k(int k) {
	if (k <= 0) k = 10; // Default
	if (k > k_max) return k_max;
	return k;
}
/**
 * @brief Validate and limit candidates parameter
 *
 * Non-positive values are replaced by the default of 50. The result is then
 * clamped to candidates_max. Clamping happens after substituting the
 * default, so the default can never exceed a configured candidates_max
 * smaller than 50.
 *
 * @param candidates Requested number of candidates
 * @return Validated candidates value, at most candidates_max
 *
 * @see validate_k()
 * @see candidates_max
 */
int RAG_Tool_Handler::validate_candidates(int candidates) {
	if (candidates <= 0) candidates = 50; // Default
	if (candidates > candidates_max) return candidates_max;
	return candidates;
}
/**
 * @brief Validate query length
 *
 * Checks that the query length in bytes does not exceed query_max_bytes.
 * The comparison is done in the unsigned size domain: casting a length above
 * INT_MAX to int would wrap negative and wrongly pass the check. A negative
 * query_max_bytes rejects every query.
 *
 * @param query Query string to validate
 * @return true if query is within length limits, false otherwise
 *
 * @see query_max_bytes
 */
bool RAG_Tool_Handler::validate_query_length(const std::string& query) {
	if (query_max_bytes < 0) {
		return false;
	}
	return query.length() <= static_cast<std::string::size_type>(query_max_bytes);
}
/**
 * @brief Double single quotes in an FTS query string
 *
 * Replaces every single quote with two single quotes. The result can then be
 * embedded inside a single-quoted SQL string literal (e.g. MATCH '...')
 * without terminating the literal. No other characters are altered, so FTS5
 * query syntax in the input (operators, double quotes, '*', ...) is passed
 * through unchanged.
 *
 * @param query Raw FTS query string from user input
 * @return Copy of @p query with each ' replaced by ''
 *
 * @see execute_tool()
 */
std::string RAG_Tool_Handler::escape_fts_query(const std::string& query) {
	std::string out;
	out.reserve(query.size() * 2);
	std::string::size_type start = 0;
	std::string::size_type quote_pos;
	while ((quote_pos = query.find('\'', start)) != std::string::npos) {
		out.append(query, start, quote_pos - start);
		out.append("''");
		start = quote_pos + 1;
	}
	// Copy whatever follows the last quote (or the whole string if none)
	out.append(query, start, std::string::npos);
	return out;
}
/**
 * @brief Execute a raw SQL statement against the vector database
 *
 * Runs @p query through vector_db->execute_statement(). If an error message
 * is reported, it is logged, freed, and NULL is returned. Otherwise the
 * result set is returned, and the caller becomes responsible for deleting it.
 *
 * NOTE(review): the error string is released with proxy_sqlite3_free; confirm
 * this matches the allocator execute_statement uses for *error. Also, if
 * execute_statement can return a result set together with an error, that
 * result set is not freed here.
 *
 * @param query SQL query string to execute
 * @return SQLite3_result pointer, or NULL on error / missing database
 *
 * @see vector_db
 */
SQLite3_result* RAG_Tool_Handler::execute_query(const char* query) {
	if (vector_db == NULL) {
		proxy_error("RAG_Tool_Handler: Vector database not available\n");
		return NULL;
	}
	char* err_msg = NULL;
	int column_count = 0;
	int rows_affected = 0;
	SQLite3_result* rows = vector_db->execute_statement(query, &err_msg, &column_count, &rows_affected);
	if (err_msg == NULL) {
		return rows;
	}
	proxy_error("RAG_Tool_Handler: SQL error: %s\n", err_msg);
	(*proxy_sqlite3_free)(err_msg);
	return NULL;
}
/**
 * @brief Execute a parameterized SQL statement against the vector database
 *
 * Prepares @p query, binds the text parameters and then the integer
 * parameters at their 1-based positions, and runs the statement with
 * vector_db->execute_prepared(). Binding values instead of concatenating
 * them keeps user input out of the SQL text. Any prepare/bind/execute
 * failure is logged and NULL is returned. On success the caller owns the
 * returned result set.
 *
 * Text values are bound with SQLITE_STATIC. That is safe because the bound
 * strings live in @p text_bindings, which outlives execution in this
 * function.
 *
 * @param query SQL query string with placeholders to execute
 * @param text_bindings (position, value) pairs bound with sqlite3_bind_text
 * @param int_bindings (position, value) pairs bound with sqlite3_bind_int
 * @return SQLite3_result pointer or NULL on error
 *
 * @see vector_db
 */
SQLite3_result* RAG_Tool_Handler::execute_parameterized_query(const char* query, const std::vector<std::pair<int, std::string>>& text_bindings, const std::vector<std::pair<int, int>>& int_bindings) {
	if (!vector_db) {
		proxy_error("RAG_Tool_Handler: Vector database not available\n");
		return NULL;
	}
	// prepare_v2 yields (return code, statement holder); the holder
	// (presumably a smart pointer) keeps the statement alive for this scope
	auto prepared = vector_db->prepare_v2(query);
	const auto prepare_rc = prepared.first;
	if (prepare_rc != SQLITE_OK) {
		proxy_error("RAG_Tool_Handler: Failed to prepare statement: %s\n", (*proxy_sqlite3_errstr)(prepare_rc));
		return NULL;
	}
	sqlite3_stmt* const statement = prepared.second.get();
	if (statement == NULL) {
		proxy_error("RAG_Tool_Handler: Prepared statement is NULL\n");
		return NULL;
	}
	for (const auto& text_param : text_bindings) {
		const int bind_rc = (*proxy_sqlite3_bind_text)(statement, text_param.first, text_param.second.c_str(), -1, SQLITE_STATIC);
		if (bind_rc != SQLITE_OK) {
			proxy_error("RAG_Tool_Handler: Failed to bind text parameter at position %d: %s\n", text_param.first, (*proxy_sqlite3_errstr)(bind_rc));
			return NULL;
		}
	}
	for (const auto& int_param : int_bindings) {
		const int bind_rc = (*proxy_sqlite3_bind_int)(statement, int_param.first, int_param.second);
		if (bind_rc != SQLITE_OK) {
			proxy_error("RAG_Tool_Handler: Failed to bind integer parameter at position %d: %s\n", int_param.first, (*proxy_sqlite3_errstr)(bind_rc));
			return NULL;
		}
	}
	char* err_msg = NULL;
	int column_count = 0;
	int rows_affected = 0;
	SQLite3_result* rows = NULL;
	// Run the already-bound statement (not the raw query text)
	const bool executed = vector_db->execute_prepared(statement, &err_msg, &column_count, &rows_affected, &rows);
	if (executed) {
		return rows;
	}
	if (err_msg) {
		proxy_error("RAG_Tool_Handler: SQL error: %s\n", err_msg);
		(*proxy_sqlite3_free)(err_msg);
	}
	return NULL;
}
/**
 * @brief Returns true if @p value contains a character that build_sql_filters()
 *        refuses to embed in a single-quoted SQL literal: a single quote, a
 *        backslash or a semicolon.
 */
static bool rag_filter_value_is_unsafe(const std::string& value) {
	return value.find_first_of("'\\;") != std::string::npos;
}

/**
 * @brief Builds a comma-separated list of single-quoted SQL literals ("'a','b'").
 *
 * @param values Values to quote
 * @param out Receives the list (unspecified when false is returned)
 * @return false if any value fails rag_filter_value_is_unsafe()
 */
static bool rag_build_quoted_list(const std::vector<std::string>& values, std::string& out) {
	out.clear();
	for (size_t i = 0; i < values.size(); ++i) {
		if (rag_filter_value_is_unsafe(values[i])) {
			return false;
		}
		if (i > 0) out += ",";
		out += "'" + values[i] + "'";
	}
	return true;
}

/**
 * @brief Escapes the LIKE wildcards '%' and '_' with a backslash so that a
 *        value is matched literally by a LIKE ... ESCAPE '\' pattern.
 *
 * Backslashes and quotes must already have been rejected by
 * rag_filter_value_is_unsafe(), so no other escaping is required.
 */
static std::string rag_escape_like_literal(const std::string& value) {
	std::string out;
	out.reserve(value.size());
	for (char c : value) {
		if (c == '%' || c == '_') {
			out += '\\';
		}
		out += c;
	}
	return out;
}

/**
 * @brief Builds LIKE conditions matching "<tag>" inside the Tags metadata
 *        field, joined by @p joiner (e.g. " OR " or " AND ").
 *
 * @param tags Tags to match
 * @param joiner Boolean operator placed between conditions
 * @param out Receives the joined conditions (unspecified when false is returned)
 * @return false if any tag fails rag_filter_value_is_unsafe()
 */
static bool rag_build_tag_conditions(const std::vector<std::string>& tags, const char* joiner, std::string& out) {
	out.clear();
	for (size_t i = 0; i < tags.size(); ++i) {
		if (rag_filter_value_is_unsafe(tags[i])) {
			return false;
		}
		if (i > 0) out += joiner;
		out += "json_extract(d.metadata_json, '$.Tags') LIKE '%<" + rag_escape_like_literal(tags[i]) + ">%' ESCAPE '\\'";
	}
	return true;
}

/**
 * @brief Build SQL filter conditions from JSON filters
 *
 * Appends " AND ..." conditions derived from @p filters to @p sql.
 *
 * - Integer filters are rendered with std::to_string.
 * - String filters are embedded as single-quoted literals after rejecting
 *   values that contain a single quote, backslash or semicolon.
 * - Tag filters use LIKE with '%' and '_' escaped, so tags are matched
 *   literally rather than as wildcards.
 *
 * Handled keys: source_ids, source_names, doc_ids, post_type_ids, tags_any,
 * tags_all, created_after, created_before. Other keys are ignored here.
 *
 * The generated conditions reference the table aliases "c" (c.source_id,
 * c.doc_id) and "d" (d.metadata_json), which the caller's query must define.
 *
 * @param filters JSON object containing filter parameters
 * @param sql SQL string to append conditions to
 * @param add_where_clause If true, first append " WHERE 1=1" so the
 *        AND-ed conditions form a valid WHERE clause
 * @return true on success, false if a string filter value failed validation
 *         (@p sql may already contain conditions appended before the failure)
 *
 * @see execute_tool()
 */
bool RAG_Tool_Handler::build_sql_filters(const json& filters, std::string& sql, bool add_where_clause) {
	// Add WHERE clause base for filter conditions if requested
	if (add_where_clause) {
		sql += " WHERE 1=1";
	}
	// Integer filters cannot carry SQL syntax once rendered via std::to_string
	if (filters.contains("source_ids") && filters["source_ids"].is_array()) {
		const std::vector<int> source_ids = get_json_int_array(filters, "source_ids");
		if (!source_ids.empty()) {
			std::string id_list;
			for (size_t i = 0; i < source_ids.size(); ++i) {
				if (i > 0) id_list += ",";
				id_list += std::to_string(source_ids[i]);
			}
			sql += " AND c.source_id IN (" + id_list + ")";
		}
	}
	if (filters.contains("source_names") && filters["source_names"].is_array()) {
		const std::vector<std::string> source_names = get_json_string_array(filters, "source_names");
		if (!source_names.empty()) {
			std::string name_list;
			if (!rag_build_quoted_list(source_names, name_list)) {
				return false;
			}
			sql += " AND c.source_id IN (SELECT source_id FROM rag_sources WHERE name IN (" + name_list + "))";
		}
	}
	if (filters.contains("doc_ids") && filters["doc_ids"].is_array()) {
		const std::vector<std::string> doc_ids = get_json_string_array(filters, "doc_ids");
		if (!doc_ids.empty()) {
			std::string doc_list;
			if (!rag_build_quoted_list(doc_ids, doc_list)) {
				return false;
			}
			sql += " AND c.doc_id IN (" + doc_list + ")";
		}
	}
	// Metadata filters
	if (filters.contains("post_type_ids") && filters["post_type_ids"].is_array()) {
		const std::vector<int> post_type_ids = get_json_int_array(filters, "post_type_ids");
		if (!post_type_ids.empty()) {
			std::string post_type_conditions;
			for (size_t i = 0; i < post_type_ids.size(); ++i) {
				if (i > 0) post_type_conditions += " OR ";
				post_type_conditions += "json_extract(d.metadata_json, '$.PostTypeId') = " + std::to_string(post_type_ids[i]);
			}
			sql += " AND (" + post_type_conditions + ")";
		}
	}
	if (filters.contains("tags_any") && filters["tags_any"].is_array()) {
		const std::vector<std::string> tags_any = get_json_string_array(filters, "tags_any");
		if (!tags_any.empty()) {
			std::string tag_conditions;
			if (!rag_build_tag_conditions(tags_any, " OR ", tag_conditions)) {
				return false;
			}
			sql += " AND (" + tag_conditions + ")";
		}
	}
	if (filters.contains("tags_all") && filters["tags_all"].is_array()) {
		const std::vector<std::string> tags_all = get_json_string_array(filters, "tags_all");
		if (!tags_all.empty()) {
			std::string tag_conditions;
			if (!rag_build_tag_conditions(tags_all, " AND ", tag_conditions)) {
				return false;
			}
			sql += " AND (" + tag_conditions + ")";
		}
	}
	// Date filters compare against CreationDate in metadata_json as strings
	if (filters.contains("created_after") && filters["created_after"].is_string()) {
		const std::string created_after = filters["created_after"].get<std::string>();
		if (rag_filter_value_is_unsafe(created_after)) {
			return false;
		}
		sql += " AND json_extract(d.metadata_json, '$.CreationDate') >= '" + created_after + "'";
	}
	if (filters.contains("created_before") && filters["created_before"].is_string()) {
		const std::string created_before = filters["created_before"].get<std::string>();
		if (rag_filter_value_is_unsafe(created_before)) {
			return false;
		}
		sql += " AND json_extract(d.metadata_json, '$.CreationDate') <= '" + created_before + "'";
	}
	return true;
}
/**
 * @brief Compute Reciprocal Rank Fusion score
 *
 * Computes the Reciprocal Rank Fusion score for hybrid search ranking.
 * Formula: weight / (k0 + rank)
 *
 * k0 can originate from caller input (the "rrf_k0" tool parameter). A
 * negative value would make the denominator zero (division by zero) or
 * negative (inverting the ranking), so it is clamped to 0. The sum is
 * computed in double precision so a very large k0 cannot overflow int.
 *
 * @param rank Rank position (1-based); non-positive ranks score 0.0
 * @param k0 Smoothing parameter (negative values are treated as 0)
 * @param weight Weight factor for this ranking
 * @return RRF score
 *
 * @see rag.search_hybrid
 */
double RAG_Tool_Handler::compute_rrf_score(int rank, int k0, double weight) {
	if (rank <= 0) return 0.0;
	if (k0 < 0) k0 = 0;
	// rank >= 1 here, so the denominator is always >= 1
	return weight / (static_cast<double>(k0) + static_cast<double>(rank));
}
/**
 * @brief Normalize a score (currently a pass-through)
 *
 * Placeholder hook for mapping different score types onto a common
 * "higher is better" scale. At present the score is returned unchanged,
 * whatever @p score_type is.
 *
 * @param score Raw score
 * @param score_type Type of score (currently unused)
 * @return @p score, unmodified
 */
double RAG_Tool_Handler::normalize_score(double score, const std::string& score_type) {
	(void)score_type; // reserved for future per-type normalization
	return score;
}
// ============================================================================
// Tool List
// ============================================================================
/**
 * @brief Get list of available RAG tools
 *
 * Returns a list of all available RAG tools with their input schemas and
 * descriptions. Tools include:
 * - rag.search_fts: Keyword search using FTS5
 * - rag.search_vector: Semantic search using vector embeddings
 * - rag.search_hybrid: Hybrid search combining FTS and vectors
 * - rag.get_chunks: Fetch chunk content by chunk_id
 * - rag.get_docs: Fetch document content by doc_id
 * - rag.fetch_from_source: Refetch authoritative data from source
 * - rag.admin.stats: Operational statistics
 *
 * The "k" parameter descriptions report the configured k_max (the clamp
 * applied by validate_k()) instead of a hard-coded maximum, so the
 * advertised limit follows runtime configuration.
 *
 * @return JSON object {"tools": [...]} containing tool definitions and schemas
 *
 * @see get_tool_description()
 * @see execute_tool()
 */
json RAG_Tool_Handler::get_tool_list() {
	json tools = json::array();
	// Shared description for every search tool's "k" parameter
	const std::string k_description =
		"Number of results to return (default: 10, max: " + std::to_string(k_max) + ")";
	// FTS search tool
	json fts_params = json::object();
	fts_params["type"] = "object";
	fts_params["properties"] = json::object();
	fts_params["properties"]["query"] = {
		{"type", "string"},
		{"description", "Keyword search query"}
	};
	fts_params["properties"]["k"] = {
		{"type", "integer"},
		{"description", k_description}
	};
	fts_params["properties"]["offset"] = {
		{"type", "integer"},
		{"description", "Offset for pagination (default: 0)"}
	};
	// Filters object (shared by all search tools)
	json filters_obj = json::object();
	filters_obj["type"] = "object";
	filters_obj["properties"] = json::object();
	filters_obj["properties"]["source_ids"] = {
		{"type", "array"},
		{"items", {{"type", "integer"}}},
		{"description", "Filter by source IDs"}
	};
	filters_obj["properties"]["source_names"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "Filter by source names"}
	};
	filters_obj["properties"]["doc_ids"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "Filter by document IDs"}
	};
	filters_obj["properties"]["min_score"] = {
		{"type", "number"},
		{"description", "Minimum score threshold"}
	};
	filters_obj["properties"]["post_type_ids"] = {
		{"type", "array"},
		{"items", {{"type", "integer"}}},
		{"description", "Filter by post type IDs"}
	};
	filters_obj["properties"]["tags_any"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "Filter by any of these tags"}
	};
	filters_obj["properties"]["tags_all"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "Filter by all of these tags"}
	};
	filters_obj["properties"]["created_after"] = {
		{"type", "string"},
		{"format", "date-time"},
		{"description", "Filter by creation date (after)"}
	};
	filters_obj["properties"]["created_before"] = {
		{"type", "string"},
		{"format", "date-time"},
		{"description", "Filter by creation date (before)"}
	};
	fts_params["properties"]["filters"] = filters_obj;
	// Return object
	json return_obj = json::object();
	return_obj["type"] = "object";
	return_obj["properties"] = json::object();
	return_obj["properties"]["include_title"] = {
		{"type", "boolean"},
		{"description", "Include title in results (default: true)"}
	};
	return_obj["properties"]["include_metadata"] = {
		{"type", "boolean"},
		{"description", "Include metadata in results (default: true)"}
	};
	return_obj["properties"]["include_snippets"] = {
		{"type", "boolean"},
		{"description", "Include snippets in results (default: false)"}
	};
	fts_params["properties"]["return"] = return_obj;
	fts_params["required"] = json::array({"query"});
	tools.push_back({
		{"name", "rag.search_fts"},
		{"description", "Keyword search over documents using FTS5"},
		{"inputSchema", fts_params}
	});
	// Vector search tool
	json vec_params = json::object();
	vec_params["type"] = "object";
	vec_params["properties"] = json::object();
	vec_params["properties"]["query_text"] = {
		{"type", "string"},
		{"description", "Text to search semantically"}
	};
	vec_params["properties"]["k"] = {
		{"type", "integer"},
		{"description", k_description}
	};
	// Filters object (same as FTS)
	vec_params["properties"]["filters"] = filters_obj;
	// Return object (same as FTS)
	vec_params["properties"]["return"] = return_obj;
	// Embedding object for precomputed vectors
	json embedding_obj = json::object();
	embedding_obj["type"] = "object";
	embedding_obj["properties"] = json::object();
	embedding_obj["properties"]["model"] = {
		{"type", "string"},
		{"description", "Embedding model to use"}
	};
	vec_params["properties"]["embedding"] = embedding_obj;
	// Query embedding object for precomputed vectors
	json query_embedding_obj = json::object();
	query_embedding_obj["type"] = "object";
	query_embedding_obj["properties"] = json::object();
	query_embedding_obj["properties"]["dim"] = {
		{"type", "integer"},
		{"description", "Dimension of the embedding"}
	};
	query_embedding_obj["properties"]["values_b64"] = {
		{"type", "string"},
		{"description", "Base64 encoded float32 array"}
	};
	vec_params["properties"]["query_embedding"] = query_embedding_obj;
	vec_params["required"] = json::array({"query_text"});
	tools.push_back({
		{"name", "rag.search_vector"},
		{"description", "Semantic search over documents using vector embeddings"},
		{"inputSchema", vec_params}
	});
	// Hybrid search tool
	json hybrid_params = json::object();
	hybrid_params["type"] = "object";
	hybrid_params["properties"] = json::object();
	hybrid_params["properties"]["query"] = {
		{"type", "string"},
		{"description", "Search query for both FTS and vector"}
	};
	hybrid_params["properties"]["k"] = {
		{"type", "integer"},
		{"description", k_description}
	};
	hybrid_params["properties"]["mode"] = {
		{"type", "string"},
		{"description", "Search mode: 'fuse' or 'fts_then_vec'"}
	};
	// Filters object (same as FTS and vector)
	hybrid_params["properties"]["filters"] = filters_obj;
	// Fuse object for mode "fuse"
	json fuse_obj = json::object();
	fuse_obj["type"] = "object";
	fuse_obj["properties"] = json::object();
	fuse_obj["properties"]["fts_k"] = {
		{"type", "integer"},
		{"description", "Number of FTS results to retrieve for fusion (default: 50)"}
	};
	fuse_obj["properties"]["vec_k"] = {
		{"type", "integer"},
		{"description", "Number of vector results to retrieve for fusion (default: 50)"}
	};
	fuse_obj["properties"]["rrf_k0"] = {
		{"type", "integer"},
		{"description", "RRF smoothing parameter (default: 60)"}
	};
	fuse_obj["properties"]["w_fts"] = {
		{"type", "number"},
		{"description", "Weight for FTS scores in fusion (default: 1.0)"}
	};
	fuse_obj["properties"]["w_vec"] = {
		{"type", "number"},
		{"description", "Weight for vector scores in fusion (default: 1.0)"}
	};
	hybrid_params["properties"]["fuse"] = fuse_obj;
	// Fts_then_vec object for mode "fts_then_vec"
	json fts_then_vec_obj = json::object();
	fts_then_vec_obj["type"] = "object";
	fts_then_vec_obj["properties"] = json::object();
	fts_then_vec_obj["properties"]["candidates_k"] = {
		{"type", "integer"},
		{"description", "Number of FTS candidates to generate (default: 200)"}
	};
	fts_then_vec_obj["properties"]["rerank_k"] = {
		{"type", "integer"},
		{"description", "Number of candidates to rerank with vector search (default: 50)"}
	};
	fts_then_vec_obj["properties"]["vec_metric"] = {
		{"type", "string"},
		{"description", "Vector similarity metric (default: 'cosine')"}
	};
	hybrid_params["properties"]["fts_then_vec"] = fts_then_vec_obj;
	hybrid_params["required"] = json::array({"query"});
	tools.push_back({
		{"name", "rag.search_hybrid"},
		{"description", "Hybrid search combining FTS and vector"},
		{"inputSchema", hybrid_params}
	});
	// Get chunks tool
	json chunks_params = json::object();
	chunks_params["type"] = "object";
	chunks_params["properties"] = json::object();
	chunks_params["properties"]["chunk_ids"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "List of chunk IDs to fetch"}
	};
	json return_params = json::object();
	return_params["type"] = "object";
	return_params["properties"] = json::object();
	return_params["properties"]["include_title"] = {
		{"type", "boolean"},
		{"description", "Include title in response (default: true)"}
	};
	return_params["properties"]["include_doc_metadata"] = {
		{"type", "boolean"},
		{"description", "Include document metadata in response (default: true)"}
	};
	return_params["properties"]["include_chunk_metadata"] = {
		{"type", "boolean"},
		{"description", "Include chunk metadata in response (default: true)"}
	};
	chunks_params["properties"]["return"] = return_params;
	chunks_params["required"] = json::array({"chunk_ids"});
	tools.push_back({
		{"name", "rag.get_chunks"},
		{"description", "Fetch chunk content by chunk_id"},
		{"inputSchema", chunks_params}
	});
	// Get docs tool
	json docs_params = json::object();
	docs_params["type"] = "object";
	docs_params["properties"] = json::object();
	docs_params["properties"]["doc_ids"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "List of document IDs to fetch"}
	};
	json docs_return_params = json::object();
	docs_return_params["type"] = "object";
	docs_return_params["properties"] = json::object();
	docs_return_params["properties"]["include_body"] = {
		{"type", "boolean"},
		{"description", "Include body in response (default: true)"}
	};
	docs_return_params["properties"]["include_metadata"] = {
		{"type", "boolean"},
		{"description", "Include metadata in response (default: true)"}
	};
	docs_params["properties"]["return"] = docs_return_params;
	docs_params["required"] = json::array({"doc_ids"});
	tools.push_back({
		{"name", "rag.get_docs"},
		{"description", "Fetch document content by doc_id"},
		{"inputSchema", docs_params}
	});
	// Fetch from source tool
	json fetch_params = json::object();
	fetch_params["type"] = "object";
	fetch_params["properties"] = json::object();
	fetch_params["properties"]["doc_ids"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "List of document IDs to refetch"}
	};
	fetch_params["properties"]["columns"] = {
		{"type", "array"},
		{"items", {{"type", "string"}}},
		{"description", "List of columns to fetch"}
	};
	// Limits object
	json limits_obj = json::object();
	limits_obj["type"] = "object";
	limits_obj["properties"] = json::object();
	limits_obj["properties"]["max_rows"] = {
		{"type", "integer"},
		{"description", "Maximum number of rows to return (default: 10, max: 100)"}
	};
	limits_obj["properties"]["max_bytes"] = {
		{"type", "integer"},
		{"description", "Maximum number of bytes to return (default: 200000, max: 1000000)"}
	};
	fetch_params["properties"]["limits"] = limits_obj;
	fetch_params["required"] = json::array({"doc_ids"});
	tools.push_back({
		{"name", "rag.fetch_from_source"},
		{"description", "Refetch authoritative data from source database"},
		{"inputSchema", fetch_params}
	});
	// Admin stats tool (takes no parameters)
	json stats_params = json::object();
	stats_params["type"] = "object";
	stats_params["properties"] = json::object();
	tools.push_back({
		{"name", "rag.admin.stats"},
		{"description", "Get operational statistics for RAG system"},
		{"inputSchema", stats_params}
	});
	json result;
	result["tools"] = tools;
	return result;
}
/**
 * @brief Look up the schema and description of a single RAG tool by name
 *
 * Rebuilds the full tool catalogue via get_tool_list() and searches its
 * "tools" array for the first entry whose "name" field equals @p tool_name.
 * The catalogue is regenerated on every call; nothing is cached here.
 *
 * @param tool_name Name of the tool to describe (e.g. "rag.get_chunks")
 * @return A copy of the matching tool object as produced by get_tool_list()
 *         (name, description, inputSchema). If no entry matches, the JSON
 *         returned by create_error_response("Tool not found: <tool_name>").
 *
 * @see get_tool_list()
 * @see execute_tool()
 */
json RAG_Tool_Handler::get_tool_description(const std::string& tool_name) {
	// Non-const so operator[] behaves exactly as a plain lookup/insert would;
	// get_tool_list() always populates "tools", so nothing is inserted in practice.
	json catalogue = get_tool_list();
	const json& entries = catalogue["tools"];

	const auto match = std::find_if(entries.begin(), entries.end(),
		[&tool_name](const json& entry) {
			return entry["name"] == tool_name;
		});

	// Guard clause: unknown tool names produce an error payload instead of a schema.
	if (match == entries.end()) {
		return create_error_response("Tool not found: " + tool_name);
	}
	return *match;
}
// ============================================================================
// Tool Execution
// ============================================================================
/**
* @brief Execute a RAG tool
*
* Executes the specified RAG tool with the provided arguments. Handles
* input validation, parameter processing, database queries, and result
* formatting according to MCP specifications.
*
* Supported tools:
* - rag.search_fts: Full-text search over documents
* - rag.search_vector: Vector similarity search
* - rag.search_hybrid: Hybrid search with two modes (fuse, fts_then_vec)
* - rag.get_chunks: Retrieve chunk content by ID
* - rag.get_docs: Retrieve document content by ID
* - rag.fetch_from_source: Refetch data from authoritative source
* - rag.admin.stats: Get operational statistics
*
* @param tool_name Name of the tool to execute
* @param arguments JSON object containing tool arguments
* @return JSON response with results or error information
*
* @see get_tool_list()
* @see get_tool_description()
*/
json RAG_Tool_Handler::execute_tool(const std::string& tool_name, const json& arguments) {
proxy_debug(PROXY_DEBUG_GENAI, 3, "RAG_Tool_Handler: execute_tool(%s)\n", tool_name.c_str());
// Record start time for timing stats
auto start_time = std::chrono::high_resolution_clock::now();
try {
json result;
if (tool_name == "rag.search_fts") {
// FTS search implementation
std::string query = get_json_string(arguments, "query");
int k = validate_k(get_json_int(arguments, "k", 10));
int offset = get_json_int(arguments, "offset", 0);
// Get filters
json filters = json::object();
if (arguments.contains("filters") && arguments["filters"].is_object()) {
filters = arguments["filters"];
// Validate filter parameters
if (filters.contains("source_ids") && !filters["source_ids"].is_array()) {
return create_error_response("Invalid source_ids filter: must be an array of integers");
}
if (filters.contains("source_names") && !filters["source_names"].is_array()) {
return create_error_response("Invalid source_names filter: must be an array of strings");
}
if (filters.contains("doc_ids") && !filters["doc_ids"].is_array()) {
return create_error_response("Invalid doc_ids filter: must be an array of strings");
}
if (filters.contains("post_type_ids") && !filters["post_type_ids"].is_array()) {
return create_error_response("Invalid post_type_ids filter: must be an array of integers");
}
if (filters.contains("tags_any") && !filters["tags_any"].is_array()) {
return create_error_response("Invalid tags_any filter: must be an array of strings");
}
if (filters.contains("tags_all") && !filters["tags_all"].is_array()) {
return create_error_response("Invalid tags_all filter: must be an array of strings");
}
if (filters.contains("created_after") && !filters["created_after"].is_string()) {
return create_error_response("Invalid created_after filter: must be a string in ISO 8601 format");
}
if (filters.contains("created_before") && !filters["created_before"].is_string()) {
return create_error_response("Invalid created_before filter: must be a string in ISO 8601 format");
}
if (filters.contains("min_score") && !(filters["min_score"].is_number() || filters["min_score"].is_string())) {
return create_error_response("Invalid min_score filter: must be a number or numeric string");
}
}
// Get return parameters
bool include_title = true;
bool include_metadata = true;
bool include_snippets = false;
if (arguments.contains("return") && arguments["return"].is_object()) {
const json& return_params = arguments["return"];
include_title = get_json_bool(return_params, "include_title", true);
include_metadata = get_json_bool(return_params, "include_metadata", true);
include_snippets = get_json_bool(return_params, "include_snippets", false);
}
if (!validate_query_length(query)) {
return create_error_response("Query too long");
}
// Validate FTS query for SQL injection patterns
// This is a basic validation - in production, more robust validation should be used
if (query.find(';') != std::string::npos ||
query.find("--") != std::string::npos ||
query.find("/*") != std::string::npos ||
query.find("DROP") != std::string::npos ||
query.find("DELETE") != std::string::npos ||
query.find("INSERT") != std::string::npos ||
query.find("UPDATE") != std::string::npos) {
return create_error_response("Invalid characters in query");
}
// Log the RAG FTS search
if (catalog) {
std::string filters_str = filters.empty() ? "" : filters.dump();
catalog->log_rag_search_fts(query, k, filters_str);
}
// Build FTS query with filters
std::string sql = "SELECT c.chunk_id, c.doc_id, c.source_id, "
"(SELECT name FROM rag_sources WHERE source_id = c.source_id) as source_name, "
"c.title, bm25(rag_fts_chunks) as score_fts_raw, "
"c.metadata_json, c.body "
"FROM rag_fts_chunks f "
"JOIN rag_chunks c ON c.chunk_id = f.chunk_id "
"JOIN rag_documents d ON d.doc_id = c.doc_id "
"WHERE rag_fts_chunks MATCH '" + escape_fts_query(query) + "'";
// Apply filters using consolidated filter building function
if (!build_sql_filters(filters, sql, false)) {
return create_error_response("Invalid filter parameters");
}
sql += " ORDER BY score_fts_raw "
"LIMIT " + std::to_string(k) + " OFFSET " + std::to_string(offset);
SQLite3_result* db_result = execute_query(sql.c_str());
if (!db_result) {
return create_error_response("Database query failed");
}
// Build result array
json results = json::array();
double min_score = 0.0;
bool has_min_score = false;
if (filters.contains("min_score") && (filters["min_score"].is_number() || filters["min_score"].is_string())) {
min_score = filters["min_score"].is_number() ?
filters["min_score"].get<double>() :
std::stod(filters["min_score"].get<std::string>());
has_min_score = true;
}
for (const auto& row : db_result->rows) {
if (row->fields) {
json item;
item["chunk_id"] = row->fields[0] ? row->fields[0] : "";
item["doc_id"] = row->fields[1] ? row->fields[1] : "";
item["source_id"] = row->fields[2] ? std::stoi(row->fields[2]) : 0;
item["source_name"] = row->fields[3] ? row->fields[3] : "";
// Normalize FTS score (bm25 - lower is better, so we invert it)
double score_fts_raw = row->fields[5] ? std::stod(row->fields[5]) : 0.0;
// Convert to 0-1 scale where higher is better
double score_fts = 1.0 / (1.0 + std::abs(score_fts_raw));
// Apply min_score filter
if (has_min_score && score_fts < min_score) {
continue; // Skip this result
}
item["score_fts"] = score_fts;
if (include_title) {
item["title"] = row->fields[4] ? row->fields[4] : "";
}
if (include_metadata && row->fields[6]) {
try {
item["metadata"] = json::parse(row->fields[6]);
} catch (...) {
item["metadata"] = json::object();
}
}
if (include_snippets && row->fields[7]) {
// For now, just include the first 200 characters as a snippet
std::string body = row->fields[7];
if (body.length() > 200) {
item["snippet"] = body.substr(0, 200) + "...";
} else {
item["snippet"] = body;
}
}
results.push_back(item);
}
}
delete db_result;
result["results"] = results;
result["truncated"] = false;
// Add timing stats
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
json stats;
stats["k_requested"] = k;
stats["k_returned"] = static_cast<int>(results.size());
stats["ms"] = static_cast<int>(duration.count());
result["stats"] = stats;
} else if (tool_name == "rag.search_vector") {
// Vector search implementation
std::string query_text = get_json_string(arguments, "query_text");
int k = validate_k(get_json_int(arguments, "k", 10));
// Get filters
json filters = json::object();
if (arguments.contains("filters") && arguments["filters"].is_object()) {
filters = arguments["filters"];
// Validate filter parameters
if (filters.contains("source_ids") && !filters["source_ids"].is_array()) {
return create_error_response("Invalid source_ids filter: must be an array of integers");
}
if (filters.contains("source_names") && !filters["source_names"].is_array()) {
return create_error_response("Invalid source_names filter: must be an array of strings");
}
if (filters.contains("doc_ids") && !filters["doc_ids"].is_array()) {
return create_error_response("Invalid doc_ids filter: must be an array of strings");
}
if (filters.contains("post_type_ids") && !filters["post_type_ids"].is_array()) {
return create_error_response("Invalid post_type_ids filter: must be an array of integers");
}
if (filters.contains("tags_any") && !filters["tags_any"].is_array()) {
return create_error_response("Invalid tags_any filter: must be an array of strings");
}
if (filters.contains("tags_all") && !filters["tags_all"].is_array()) {
return create_error_response("Invalid tags_all filter: must be an array of strings");
}
if (filters.contains("created_after") && !filters["created_after"].is_string()) {
return create_error_response("Invalid created_after filter: must be a string in ISO 8601 format");
}
if (filters.contains("created_before") && !filters["created_before"].is_string()) {
return create_error_response("Invalid created_before filter: must be a string in ISO 8601 format");
}
if (filters.contains("min_score") && !(filters["min_score"].is_number() || filters["min_score"].is_string())) {
return create_error_response("Invalid min_score filter: must be a number or numeric string");
}
}
// Get return parameters
bool include_title = true;
bool include_metadata = true;
bool include_snippets = false;
if (arguments.contains("return") && arguments["return"].is_object()) {
const json& return_params = arguments["return"];
include_title = get_json_bool(return_params, "include_title", true);
include_metadata = get_json_bool(return_params, "include_metadata", true);
include_snippets = get_json_bool(return_params, "include_snippets", false);
}
if (!validate_query_length(query_text)) {
return create_error_response("Query text too long");
}
// Get embedding for query text
std::vector<float> query_embedding;
if (ai_manager && GloGATH) {
GenAI_EmbeddingResult result = GloGATH->embed_documents({query_text});
if (result.data && result.count > 0) {
// Convert to std::vector<float>
query_embedding.assign(result.data, result.data + result.embedding_size);
// Free the result data (GenAI allocates with malloc)
free(result.data);
}
}
if (query_embedding.empty()) {
return create_error_response("Failed to generate embedding for query");
}
// Convert embedding to JSON array format for sqlite-vec
std::string embedding_json = "[";
for (size_t i = 0; i < query_embedding.size(); ++i) {
if (i > 0) embedding_json += ",";
embedding_json += std::to_string(query_embedding[i]);
}
embedding_json += "]";
// Build vector search query using sqlite-vec syntax with filters
// Must use subquery approach: LIMIT must be at same query level as MATCH
std::string sql = "SELECT v.chunk_id, c.doc_id, c.source_id, "
"(SELECT name FROM rag_sources WHERE source_id = c.source_id) as source_name, "
"c.title, v.distance as score_vec_raw, "
"c.metadata_json, c.body "
"FROM ("
" SELECT chunk_id, distance "
" FROM rag_vec_chunks "
" WHERE embedding MATCH '" + escape_fts_query(embedding_json) + "' "
" ORDER BY distance "
" LIMIT " + std::to_string(k) + " "
") v "
"JOIN rag_chunks c ON c.chunk_id = v.chunk_id "
"JOIN rag_documents d ON d.doc_id = c.doc_id";
// Apply filters using consolidated filter building function
if (!build_sql_filters(filters, sql)) {
return create_error_response("Invalid filter parameters");
}
sql += " ORDER BY v.distance";
SQLite3_result* db_result = execute_query(sql.c_str());
if (!db_result) {
return create_error_response("Database query failed");
}
// Build result array
json results = json::array();
double min_score = 0.0;
bool has_min_score = false;
if (filters.contains("min_score") && (filters["min_score"].is_number() || filters["min_score"].is_string())) {
min_score = filters["min_score"].is_number() ?
filters["min_score"].get<double>() :
std::stod(filters["min_score"].get<std::string>());
has_min_score = true;
}
for (const auto& row : db_result->rows) {
if (row->fields) {
json item;
item["chunk_id"] = row->fields[0] ? row->fields[0] : "";
item["doc_id"] = row->fields[1] ? row->fields[1] : "";
item["source_id"] = row->fields[2] ? std::stoi(row->fields[2]) : 0;
item["source_name"] = row->fields[3] ? row->fields[3] : "";
// Normalize vector score (distance - lower is better, so we invert it)
double score_vec_raw = row->fields[5] ? std::stod(row->fields[5]) : 0.0;
// Convert to 0-1 scale where higher is better
double score_vec = 1.0 / (1.0 + score_vec_raw);
// Apply min_score filter
if (has_min_score && score_vec < min_score) {
continue; // Skip this result
}
item["score_vec"] = score_vec;
if (include_title) {
item["title"] = row->fields[4] ? row->fields[4] : "";
}
if (include_metadata && row->fields[6]) {
try {
item["metadata"] = json::parse(row->fields[6]);
} catch (...) {
item["metadata"] = json::object();
}
}
if (include_snippets && row->fields[7]) {
// For now, just include the first 200 characters as a snippet
std::string body = row->fields[7];
if (body.length() > 200) {
item["snippet"] = body.substr(0, 200) + "...";
} else {
item["snippet"] = body;
}
}
results.push_back(item);
}
}
delete db_result;
result["results"] = results;
result["truncated"] = false;
// Add timing stats
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
json stats;
stats["k_requested"] = k;
stats["k_returned"] = static_cast<int>(results.size());
stats["ms"] = static_cast<int>(duration.count());
result["stats"] = stats;
} else if (tool_name == "rag.search_hybrid") {
// Hybrid search implementation
std::string query = get_json_string(arguments, "query");
int k = validate_k(get_json_int(arguments, "k", 10));
std::string mode = get_json_string(arguments, "mode", "fuse");
// Get filters
json filters = json::object();
if (arguments.contains("filters") && arguments["filters"].is_object()) {
filters = arguments["filters"];
// Validate filter parameters
if (filters.contains("source_ids") && !filters["source_ids"].is_array()) {
return create_error_response("Invalid source_ids filter: must be an array of integers");
}
if (filters.contains("source_names") && !filters["source_names"].is_array()) {
return create_error_response("Invalid source_names filter: must be an array of strings");
}
if (filters.contains("doc_ids") && !filters["doc_ids"].is_array()) {
return create_error_response("Invalid doc_ids filter: must be an array of strings");
}
if (filters.contains("post_type_ids") && !filters["post_type_ids"].is_array()) {
return create_error_response("Invalid post_type_ids filter: must be an array of integers");
}
if (filters.contains("tags_any") && !filters["tags_any"].is_array()) {
return create_error_response("Invalid tags_any filter: must be an array of strings");
}
if (filters.contains("tags_all") && !filters["tags_all"].is_array()) {
return create_error_response("Invalid tags_all filter: must be an array of strings");
}
if (filters.contains("created_after") && !filters["created_after"].is_string()) {
return create_error_response("Invalid created_after filter: must be a string in ISO 8601 format");
}
if (filters.contains("created_before") && !filters["created_before"].is_string()) {
return create_error_response("Invalid created_before filter: must be a string in ISO 8601 format");
}
if (filters.contains("min_score") && !(filters["min_score"].is_number() || filters["min_score"].is_string())) {
return create_error_response("Invalid min_score filter: must be a number or numeric string");
}
}
if (!validate_query_length(query)) {
return create_error_response("Query too long");
}
json results = json::array();
if (mode == "fuse") {
// Mode A: parallel FTS + vector, fuse results (RRF recommended)
// Get FTS parameters from fuse object
int fts_k = 50;
int vec_k = 50;
int rrf_k0 = 60;
double w_fts = 1.0;
double w_vec = 1.0;
if (arguments.contains("fuse") && arguments["fuse"].is_object()) {
const json& fuse_params = arguments["fuse"];
fts_k = validate_k(get_json_int(fuse_params, "fts_k", 50));
vec_k = validate_k(get_json_int(fuse_params, "vec_k", 50));
rrf_k0 = get_json_int(fuse_params, "rrf_k0", 60);
w_fts = get_json_int(fuse_params, "w_fts", 1.0);
w_vec = get_json_int(fuse_params, "w_vec", 1.0);
} else {
// Fallback to top-level parameters for backward compatibility
fts_k = validate_k(get_json_int(arguments, "fts_k", 50));
vec_k = validate_k(get_json_int(arguments, "vec_k", 50));
rrf_k0 = get_json_int(arguments, "rrf_k0", 60);
w_fts = get_json_int(arguments, "w_fts", 1.0);
w_vec = get_json_int(arguments, "w_vec", 1.0);
}
// Run FTS search with filters
std::string fts_sql = "SELECT c.chunk_id, c.doc_id, c.source_id, "
"(SELECT name FROM rag_sources WHERE source_id = c.source_id) as source_name, "
"c.title, bm25(rag_fts_chunks) as score_fts_raw, "
"c.metadata_json "
"FROM rag_fts_chunks f "
"JOIN rag_chunks c ON c.chunk_id = f.chunk_id "
"JOIN rag_documents d ON d.doc_id = c.doc_id "
"WHERE rag_fts_chunks MATCH '" + escape_fts_query(query) + "'";
// Apply filters using consolidated filter building function
if (!build_sql_filters(filters, fts_sql, false)) {
return create_error_response("Invalid filter parameters");
}
if (filters.contains("source_names") && filters["source_names"].is_array()) {
std::vector<std::string> source_names = get_json_string_array(filters, "source_names");
if (!source_names.empty()) {
std::string source_list = "";
for (size_t i = 0; i < source_names.size(); ++i) {
if (i > 0) source_list += ",";
source_list += "'" + source_names[i] + "'";
}
fts_sql += " AND c.source_id IN (SELECT source_id FROM rag_sources WHERE name IN (" + source_list + "))";
}
}
if (filters.contains("doc_ids") && filters["doc_ids"].is_array()) {
std::vector<std::string> doc_ids = get_json_string_array(filters, "doc_ids");
if (!doc_ids.empty()) {
std::string doc_list = "";
for (size_t i = 0; i < doc_ids.size(); ++i) {
if (i > 0) doc_list += ",";
doc_list += "'" + doc_ids[i] + "'";
}
fts_sql += " AND c.doc_id IN (" + doc_list + ")";
}
}
// Metadata filters
if (filters.contains("post_type_ids") && filters["post_type_ids"].is_array()) {
std::vector<int> post_type_ids = get_json_int_array(filters, "post_type_ids");
if (!post_type_ids.empty()) {
// Filter by PostTypeId in metadata_json
std::string post_type_conditions = "";
for (size_t i = 0; i < post_type_ids.size(); ++i) {
if (i > 0) post_type_conditions += " OR ";
post_type_conditions += "json_extract(d.metadata_json, '$.PostTypeId') = " + std::to_string(post_type_ids[i]);
}
fts_sql += " AND (" + post_type_conditions + ")";
}
}
if (filters.contains("tags_any") && filters["tags_any"].is_array()) {
std::vector<std::string> tags_any = get_json_string_array(filters, "tags_any");
if (!tags_any.empty()) {
// Filter by any of the tags in metadata_json Tags field
std::string tag_conditions = "";
for (size_t i = 0; i < tags_any.size(); ++i) {
if (i > 0) tag_conditions += " OR ";
tag_conditions += "json_extract(d.metadata_json, '$.Tags') LIKE '%<" + tags_any[i] + ">%'";
}
fts_sql += " AND (" + tag_conditions + ")";
}
}
if (filters.contains("tags_all") && filters["tags_all"].is_array()) {
std::vector<std::string> tags_all = get_json_string_array(filters, "tags_all");
if (!tags_all.empty()) {
// Filter by all of the tags in metadata_json Tags field
std::string tag_conditions = "";
for (size_t i = 0; i < tags_all.size(); ++i) {
if (i > 0) tag_conditions += " AND ";
tag_conditions += "json_extract(d.metadata_json, '$.Tags') LIKE '%<" + tags_all[i] + ">%'";
}
fts_sql += " AND (" + tag_conditions + ")";
}
}
if (filters.contains("created_after") && filters["created_after"].is_string()) {
std::string created_after = filters["created_after"].get<std::string>();
// Filter by CreationDate in metadata_json
fts_sql += " AND json_extract(d.metadata_json, '$.CreationDate') >= '" + created_after + "'";
}
if (filters.contains("created_before") && filters["created_before"].is_string()) {
std::string created_before = filters["created_before"].get<std::string>();
// Filter by CreationDate in metadata_json
fts_sql += " AND json_extract(d.metadata_json, '$.CreationDate') <= '" + created_before + "'";
}
fts_sql += " ORDER BY score_fts_raw "
"LIMIT " + std::to_string(fts_k);
SQLite3_result* fts_result = execute_query(fts_sql.c_str());
if (!fts_result) {
return create_error_response("FTS database query failed");
}
// Run vector search with filters
std::vector<float> query_embedding;
if (ai_manager && GloGATH) {
GenAI_EmbeddingResult result = GloGATH->embed_documents({query});
if (result.data && result.count > 0) {
// Convert to std::vector<float>
query_embedding.assign(result.data, result.data + result.embedding_size);
// Free the result data (GenAI allocates with malloc)
free(result.data);
}
}
if (query_embedding.empty()) {
delete fts_result;
return create_error_response("Failed to generate embedding for query");
}
// Convert embedding to JSON array format for sqlite-vec
std::string embedding_json = "[";
for (size_t i = 0; i < query_embedding.size(); ++i) {
if (i > 0) embedding_json += ",";
embedding_json += std::to_string(query_embedding[i]);
}
embedding_json += "]";
// Build vector search query using sqlite-vec syntax with filters
// Must use subquery approach: LIMIT must be at same query level as MATCH
std::string vec_sql = "SELECT v.chunk_id, c.doc_id, c.source_id, "
"(SELECT name FROM rag_sources WHERE source_id = c.source_id) as source_name, "
"c.title, v.distance as score_vec_raw, "
"c.metadata_json "
"FROM ("
" SELECT chunk_id, distance "
" FROM rag_vec_chunks "
" WHERE embedding MATCH '" + escape_fts_query(embedding_json) + "' "
" ORDER BY distance "
" LIMIT " + std::to_string(vec_k) + " "
") v "
"JOIN rag_chunks c ON c.chunk_id = v.chunk_id "
"JOIN rag_documents d ON d.doc_id = c.doc_id";
// Apply filters using consolidated filter building function
// These filters are applied to the outer query after JOINs
if (!build_sql_filters(filters, vec_sql)) {
return create_error_response("Invalid filter parameters");
}
vec_sql += " ORDER BY v.distance";
SQLite3_result* vec_result = execute_query(vec_sql.c_str());
if (!vec_result) {
delete fts_result;
return create_error_response("Vector database query failed");
}
// Merge candidates by chunk_id and compute fused scores
std::map<std::string, json> fused_results;
// Process FTS results
int fts_rank = 1;
for (const auto& row : fts_result->rows) {
if (row->fields) {
std::string chunk_id = row->fields[0] ? row->fields[0] : "";
if (!chunk_id.empty()) {
json item;
item["chunk_id"] = chunk_id;
item["doc_id"] = row->fields[1] ? row->fields[1] : "";
item["source_id"] = row->fields[2] ? std::stoi(row->fields[2]) : 0;
item["source_name"] = row->fields[3] ? row->fields[3] : "";
item["title"] = row->fields[4] ? row->fields[4] : "";
double score_fts_raw = row->fields[5] ? std::stod(row->fields[5]) : 0.0;
// Normalize FTS score (bm25 - lower is better, so we invert it)
double score_fts = 1.0 / (1.0 + std::abs(score_fts_raw));
item["score_fts"] = score_fts;
item["rank_fts"] = fts_rank;
item["rank_vec"] = 0; // Will be updated if found in vector results
item["score_vec"] = 0.0;
// Add metadata if available
if (row->fields[6]) {
try {
item["metadata"] = json::parse(row->fields[6]);
} catch (...) {
item["metadata"] = json::object();
}
}
fused_results[chunk_id] = item;
fts_rank++;
}
}
}
// Process vector results
int vec_rank = 1;
for (const auto& row : vec_result->rows) {
if (row->fields) {
std::string chunk_id = row->fields[0] ? row->fields[0] : "";
if (!chunk_id.empty()) {
double score_vec_raw = row->fields[5] ? std::stod(row->fields[5]) : 0.0;
// For vector search, lower distance is better, so we invert it
double score_vec = 1.0 / (1.0 + score_vec_raw);
auto it = fused_results.find(chunk_id);
if (it != fused_results.end()) {
// Chunk already in FTS results, update vector info
it->second["rank_vec"] = vec_rank;
it->second["score_vec"] = score_vec;
} else {
// New chunk from vector results
json item;
item["chunk_id"] = chunk_id;
item["doc_id"] = row->fields[1] ? row->fields[1] : "";
item["source_id"] = row->fields[2] ? std::stoi(row->fields[2]) : 0;
item["source_name"] = row->fields[3] ? row->fields[3] : "";
item["title"] = row->fields[4] ? row->fields[4] : "";
item["score_vec"] = score_vec;
item["rank_vec"] = vec_rank;
item["rank_fts"] = 0; // Not found in FTS
item["score_fts"] = 0.0;
// Add metadata if available
if (row->fields[6]) {
try {
item["metadata"] = json::parse(row->fields[6]);
} catch (...) {
item["metadata"] = json::object();
}
}
fused_results[chunk_id] = item;
}
vec_rank++;
}
}
}
// Compute fused scores using RRF
std::vector<std::pair<double, json>> scored_results;
double min_score = 0.0;
bool has_min_score = false;
if (filters.contains("min_score") && (filters["min_score"].is_number() || filters["min_score"].is_string())) {
min_score = filters["min_score"].is_number() ?
filters["min_score"].get<double>() :
std::stod(filters["min_score"].get<std::string>());
has_min_score = true;
}
for (auto& pair : fused_results) {
json& item = pair.second;
int rank_fts = item["rank_fts"].get<int>();
int rank_vec = item["rank_vec"].get<int>();
double score_fts = item["score_fts"].get<double>();
double score_vec = item["score_vec"].get<double>();
// Compute fused score using weighted RRF
double fused_score = 0.0;
if (rank_fts > 0) {
fused_score += w_fts / (rrf_k0 + rank_fts);
}
if (rank_vec > 0) {
fused_score += w_vec / (rrf_k0 + rank_vec);
}
// Apply min_score filter
if (has_min_score && fused_score < min_score) {
continue; // Skip this result
}
item["score"] = fused_score;
item["score_fts"] = score_fts;
item["score_vec"] = score_vec;
// Add debug info
json debug;
debug["rank_fts"] = rank_fts;
debug["rank_vec"] = rank_vec;
item["debug"] = debug;
scored_results.push_back({fused_score, item});
}
// Sort by fused score descending
std::sort(scored_results.begin(), scored_results.end(),
[](const std::pair<double, json>& a, const std::pair<double, json>& b) {
return a.first > b.first;
});
// Take top k results
for (size_t i = 0; i < scored_results.size() && i < static_cast<size_t>(k); ++i) {
results.push_back(scored_results[i].second);
}
delete fts_result;
delete vec_result;
} else if (mode == "fts_then_vec") {
// Mode B: broad FTS candidate generation, then vector rerank
// Get parameters from fts_then_vec object
int candidates_k = 200;
int rerank_k = 50;
if (arguments.contains("fts_then_vec") && arguments["fts_then_vec"].is_object()) {
const json& fts_then_vec_params = arguments["fts_then_vec"];
candidates_k = validate_candidates(get_json_int(fts_then_vec_params, "candidates_k", 200));
rerank_k = validate_k(get_json_int(fts_then_vec_params, "rerank_k", 50));
} else {
// Fallback to top-level parameters for backward compatibility
candidates_k = validate_candidates(get_json_int(arguments, "candidates_k", 200));
rerank_k = validate_k(get_json_int(arguments, "rerank_k", 50));
}
// Run FTS search to get candidates with filters
std::string fts_sql = "SELECT c.chunk_id "
"FROM rag_fts_chunks f "
"JOIN rag_chunks c ON c.chunk_id = f.chunk_id "
"JOIN rag_documents d ON d.doc_id = c.doc_id "
"WHERE rag_fts_chunks MATCH '" + escape_fts_query(query) + "'";
// Apply filters using consolidated filter building function
if (!build_sql_filters(filters, fts_sql, false)) {
return create_error_response("Invalid filter parameters");
}
if (filters.contains("source_names") && filters["source_names"].is_array()) {
std::vector<std::string> source_names = get_json_string_array(filters, "source_names");
if (!source_names.empty()) {
std::string source_list = "";
for (size_t i = 0; i < source_names.size(); ++i) {
if (i > 0) source_list += ",";
source_list += "'" + source_names[i] + "'";
}
fts_sql += " AND c.source_id IN (SELECT source_id FROM rag_sources WHERE name IN (" + source_list + "))";
}
}
if (filters.contains("doc_ids") && filters["doc_ids"].is_array()) {
std::vector<std::string> doc_ids = get_json_string_array(filters, "doc_ids");
if (!doc_ids.empty()) {
std::string doc_list = "";
for (size_t i = 0; i < doc_ids.size(); ++i) {
if (i > 0) doc_list += ",";
doc_list += "'" + doc_ids[i] + "'";
}
fts_sql += " AND c.doc_id IN (" + doc_list + ")";
}
}
// Metadata filters
if (filters.contains("post_type_ids") && filters["post_type_ids"].is_array()) {
std::vector<int> post_type_ids = get_json_int_array(filters, "post_type_ids");
if (!post_type_ids.empty()) {
// Filter by PostTypeId in metadata_json
std::string post_type_conditions = "";
for (size_t i = 0; i < post_type_ids.size(); ++i) {
if (i > 0) post_type_conditions += " OR ";
post_type_conditions += "json_extract(d.metadata_json, '$.PostTypeId') = " + std::to_string(post_type_ids[i]);
}
fts_sql += " AND (" + post_type_conditions + ")";
}
}
if (filters.contains("tags_any") && filters["tags_any"].is_array()) {
std::vector<std::string> tags_any = get_json_string_array(filters, "tags_any");
if (!tags_any.empty()) {
// Filter by any of the tags in metadata_json Tags field
std::string tag_conditions = "";
for (size_t i = 0; i < tags_any.size(); ++i) {
if (i > 0) tag_conditions += " OR ";
tag_conditions += "json_extract(d.metadata_json, '$.Tags') LIKE '%<" + tags_any[i] + ">%'";
}
fts_sql += " AND (" + tag_conditions + ")";
}
}
if (filters.contains("tags_all") && filters["tags_all"].is_array()) {
std::vector<std::string> tags_all = get_json_string_array(filters, "tags_all");
if (!tags_all.empty()) {
// Filter by all of the tags in metadata_json Tags field
std::string tag_conditions = "";
for (size_t i = 0; i < tags_all.size(); ++i) {
if (i > 0) tag_conditions += " AND ";
tag_conditions += "json_extract(d.metadata_json, '$.Tags') LIKE '%<" + tags_all[i] + ">%'";
}
fts_sql += " AND (" + tag_conditions + ")";
}
}
if (filters.contains("created_after") && filters["created_after"].is_string()) {
std::string created_after = filters["created_after"].get<std::string>();
// Filter by CreationDate in metadata_json
fts_sql += " AND json_extract(d.metadata_json, '$.CreationDate') >= '" + created_after + "'";
}
if (filters.contains("created_before") && filters["created_before"].is_string()) {
std::string created_before = filters["created_before"].get<std::string>();
// Filter by CreationDate in metadata_json
fts_sql += " AND json_extract(d.metadata_json, '$.CreationDate') <= '" + created_before + "'";
}
fts_sql += " ORDER BY bm25(rag_fts_chunks) "
"LIMIT " + std::to_string(candidates_k);
SQLite3_result* fts_result = execute_query(fts_sql.c_str());
if (!fts_result) {
return create_error_response("FTS database query failed");
}
// Build candidate list
std::vector<std::string> candidate_ids;
for (const auto& row : fts_result->rows) {
if (row->fields && row->fields[0]) {
candidate_ids.push_back(row->fields[0]);
}
}
delete fts_result;
if (candidate_ids.empty()) {
// No candidates found
} else {
// Run vector search on candidates with filters
std::vector<float> query_embedding;
if (ai_manager && GloGATH) {
GenAI_EmbeddingResult result = GloGATH->embed_documents({query});
if (result.data && result.count > 0) {
// Convert to std::vector<float>
query_embedding.assign(result.data, result.data + result.embedding_size);
// Free the result data (GenAI allocates with malloc)
free(result.data);
}
}
if (query_embedding.empty()) {
return create_error_response("Failed to generate embedding for query");
}
// Convert embedding to JSON array format for sqlite-vec
std::string embedding_json = "[";
for (size_t i = 0; i < query_embedding.size(); ++i) {
if (i > 0) embedding_json += ",";
embedding_json += std::to_string(query_embedding[i]);
}
embedding_json += "]";
// Build candidate ID list for SQL
std::string candidate_list = "'";
for (size_t i = 0; i < candidate_ids.size(); ++i) {
if (i > 0) candidate_list += "','";
candidate_list += candidate_ids[i];
}
candidate_list += "'";
// Build vector search query using sqlite-vec syntax with filters
// Must use subquery approach: LIMIT must be at same query level as MATCH
std::string vec_sql = "SELECT v.chunk_id, c.doc_id, c.source_id, "
"(SELECT name FROM rag_sources WHERE source_id = c.source_id) as source_name, "
"c.title, v.distance as score_vec_raw, "
"c.metadata_json "
"FROM ("
" SELECT chunk_id, distance "
" FROM rag_vec_chunks "
" WHERE embedding MATCH '" + escape_fts_query(embedding_json) + "' "
" AND chunk_id IN (" + candidate_list + ") "
" ORDER BY distance "
" LIMIT " + std::to_string(rerank_k) + " "
") v "
"JOIN rag_chunks c ON c.chunk_id = v.chunk_id "
"JOIN rag_documents d ON d.doc_id = c.doc_id";
// Apply filters using consolidated filter building function
// These filters are applied to the outer query after JOINs
if (!build_sql_filters(filters, vec_sql)) {
return create_error_response("Invalid filter parameters");
}
vec_sql += " ORDER BY v.distance";
SQLite3_result* vec_result = execute_query(vec_sql.c_str());
if (!vec_result) {
return create_error_response("Vector database query failed");
}
// Build results with min_score filtering
int rank = 1;
double min_score = 0.0;
bool has_min_score = false;
if (filters.contains("min_score") && (filters["min_score"].is_number() || filters["min_score"].is_string())) {
min_score = filters["min_score"].is_number() ?
filters["min_score"].get<double>() :
std::stod(filters["min_score"].get<std::string>());
has_min_score = true;
}
for (const auto& row : vec_result->rows) {
if (row->fields) {
double score_vec_raw = row->fields[5] ? std::stod(row->fields[5]) : 0.0;
// For vector search, lower distance is better, so we invert it
double score_vec = 1.0 / (1.0 + score_vec_raw);
// Apply min_score filter
if (has_min_score && score_vec < min_score) {
continue; // Skip this result
}
json item;
item["chunk_id"] = row->fields[0] ? row->fields[0] : "";
item["doc_id"] = row->fields[1] ? row->fields[1] : "";
item["source_id"] = row->fields[2] ? std::stoi(row->fields[2]) : 0;
item["source_name"] = row->fields[3] ? row->fields[3] : "";
item["title"] = row->fields[4] ? row->fields[4] : "";
item["score"] = score_vec;
item["score_vec"] = score_vec;
item["rank"] = rank;
// Add metadata if available
if (row->fields[6]) {
try {
item["metadata"] = json::parse(row->fields[6]);
} catch (...) {
item["metadata"] = json::object();
}
}
results.push_back(item);
rank++;
}
}
delete vec_result;
}
}
result["results"] = results;
result["truncated"] = false;
// Add timing stats
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
json stats;
stats["mode"] = mode;
stats["k_requested"] = k;
stats["k_returned"] = static_cast<int>(results.size());
stats["ms"] = static_cast<int>(duration.count());
result["stats"] = stats;
} else if (tool_name == "rag.get_chunks") {
// Get chunks implementation
std::vector<std::string> chunk_ids = get_json_string_array(arguments, "chunk_ids");
if (chunk_ids.empty()) {
return create_error_response("No chunk_ids provided");
}
// Validate chunk_ids to prevent SQL injection
for (const std::string& chunk_id : chunk_ids) {
if (chunk_id.find('\'') != std::string::npos ||
chunk_id.find('\\') != std::string::npos ||
chunk_id.find(';') != std::string::npos) {
return create_error_response("Invalid characters in chunk_ids");
}
}
// Get return parameters
bool include_title = true;
bool include_doc_metadata = true;
bool include_chunk_metadata = true;
if (arguments.contains("return")) {
const json& return_params = arguments["return"];
include_title = get_json_bool(return_params, "include_title", true);
include_doc_metadata = get_json_bool(return_params, "include_doc_metadata", true);
include_chunk_metadata = get_json_bool(return_params, "include_chunk_metadata", true);
}
// Build chunk ID list for SQL with proper escaping
std::string chunk_list = "";
for (size_t i = 0; i < chunk_ids.size(); ++i) {
if (i > 0) chunk_list += ",";
// Properly escape single quotes in chunk IDs
std::string escaped_chunk_id = chunk_ids[i];
size_t pos = 0;
while ((pos = escaped_chunk_id.find("'", pos)) != std::string::npos) {
escaped_chunk_id.replace(pos, 1, "''");
pos += 2;
}
chunk_list += "'" + escaped_chunk_id + "'";
}
// Build query with proper joins to get metadata
std::string sql = "SELECT c.chunk_id, c.doc_id, c.title, c.body, "
"d.metadata_json as doc_metadata, c.metadata_json as chunk_metadata "
"FROM rag_chunks c "
"LEFT JOIN rag_documents d ON d.doc_id = c.doc_id "
"WHERE c.chunk_id IN (" + chunk_list + ")";
SQLite3_result* db_result = execute_query(sql.c_str());
if (!db_result) {
return create_error_response("Database query failed");
}
// Build chunks array
json chunks = json::array();
for (const auto& row : db_result->rows) {
if (row->fields) {
json chunk;
chunk["chunk_id"] = row->fields[0] ? row->fields[0] : "";
chunk["doc_id"] = row->fields[1] ? row->fields[1] : "";
if (include_title) {
chunk["title"] = row->fields[2] ? row->fields[2] : "";
}
// Always include body for get_chunks
chunk["body"] = row->fields[3] ? row->fields[3] : "";
if (include_doc_metadata && row->fields[4]) {
try {
chunk["doc_metadata"] = json::parse(row->fields[4]);
} catch (...) {
chunk["doc_metadata"] = json::object();
}
}
if (include_chunk_metadata && row->fields[5]) {
try {
chunk["chunk_metadata"] = json::parse(row->fields[5]);
} catch (...) {
chunk["chunk_metadata"] = json::object();
}
}
chunks.push_back(chunk);
}
}
delete db_result;
result["chunks"] = chunks;
result["truncated"] = false;
// Add timing stats
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
json stats;
stats["ms"] = static_cast<int>(duration.count());
result["stats"] = stats;
} else if (tool_name == "rag.get_docs") {
// Get docs implementation
std::vector<std::string> doc_ids = get_json_string_array(arguments, "doc_ids");
if (doc_ids.empty()) {
return create_error_response("No doc_ids provided");
}
// Get return parameters
bool include_body = true;
bool include_metadata = true;
if (arguments.contains("return")) {
const json& return_params = arguments["return"];
include_body = get_json_bool(return_params, "include_body", true);
include_metadata = get_json_bool(return_params, "include_metadata", true);
}
// Build doc ID list for SQL
std::string doc_list = "'";
for (size_t i = 0; i < doc_ids.size(); ++i) {
if (i > 0) doc_list += "','";
doc_list += doc_ids[i];
}
doc_list += "'";
// Build query
std::string sql = "SELECT doc_id, source_id, "
"(SELECT name FROM rag_sources WHERE source_id = rag_documents.source_id) as source_name, "
"pk_json, title, body, metadata_json "
"FROM rag_documents "
"WHERE doc_id IN (" + doc_list + ")";
SQLite3_result* db_result = execute_query(sql.c_str());
if (!db_result) {
return create_error_response("Database query failed");
}
// Build docs array
json docs = json::array();
for (const auto& row : db_result->rows) {
if (row->fields) {
json doc;
doc["doc_id"] = row->fields[0] ? row->fields[0] : "";
doc["source_id"] = row->fields[1] ? std::stoi(row->fields[1]) : 0;
doc["source_name"] = row->fields[2] ? row->fields[2] : "";
doc["pk_json"] = row->fields[3] ? row->fields[3] : "{}";
// Always include title
doc["title"] = row->fields[4] ? row->fields[4] : "";
if (include_body) {
doc["body"] = row->fields[5] ? row->fields[5] : "";
}
if (include_metadata && row->fields[6]) {
try {
doc["metadata"] = json::parse(row->fields[6]);
} catch (...) {
doc["metadata"] = json::object();
}
}
docs.push_back(doc);
}
}
delete db_result;
result["docs"] = docs;
result["truncated"] = false;
// Add timing stats
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
json stats;
stats["ms"] = static_cast<int>(duration.count());
result["stats"] = stats;
} else if (tool_name == "rag.fetch_from_source") {
// Fetch from source implementation
std::vector<std::string> doc_ids = get_json_string_array(arguments, "doc_ids");
std::vector<std::string> columns = get_json_string_array(arguments, "columns");
// Get limits
int max_rows = 10;
int max_bytes = 200000;
if (arguments.contains("limits")) {
const json& limits = arguments["limits"];
max_rows = get_json_int(limits, "max_rows", 10);
max_bytes = get_json_int(limits, "max_bytes", 200000);
}
if (doc_ids.empty()) {
return create_error_response("No doc_ids provided");
}
// Validate limits
if (max_rows > 100) max_rows = 100;
if (max_bytes > 1000000) max_bytes = 1000000;
// Build doc ID list for SQL
std::string doc_list = "'";
for (size_t i = 0; i < doc_ids.size(); ++i) {
if (i > 0) doc_list += "','";
doc_list += doc_ids[i];
}
doc_list += "'";
// Look up documents to get source connection info
std::string doc_sql = "SELECT d.doc_id, d.source_id, d.pk_json, d.source_name, "
"s.backend_type, s.backend_host, s.backend_port, s.backend_user, s.backend_pass, s.backend_db, "
"s.table_name, s.pk_column "
"FROM rag_documents d "
"JOIN rag_sources s ON s.source_id = d.source_id "
"WHERE d.doc_id IN (" + doc_list + ")";
SQLite3_result* doc_result = execute_query(doc_sql.c_str());
if (!doc_result) {
return create_error_response("Database query failed");
}
// Build rows array
json rows = json::array();
int total_bytes = 0;
bool truncated = false;
// Process each document
for (const auto& row : doc_result->rows) {
if (row->fields && rows.size() < static_cast<size_t>(max_rows) && total_bytes < max_bytes) {
std::string doc_id = row->fields[0] ? row->fields[0] : "";
// int source_id = row->fields[1] ? std::stoi(row->fields[1]) : 0;
std::string pk_json = row->fields[2] ? row->fields[2] : "{}";
std::string source_name = row->fields[3] ? row->fields[3] : "";
// std::string backend_type = row->fields[4] ? row->fields[4] : "";
// std::string backend_host = row->fields[5] ? row->fields[5] : "";
// int backend_port = row->fields[6] ? std::stoi(row->fields[6]) : 0;
// std::string backend_user = row->fields[7] ? row->fields[7] : "";
// std::string backend_pass = row->fields[8] ? row->fields[8] : "";
// std::string backend_db = row->fields[9] ? row->fields[9] : "";
// std::string table_name = row->fields[10] ? row->fields[10] : "";
std::string pk_column = row->fields[11] ? row->fields[11] : "";
// For now, we'll return a simplified response since we can't actually connect to external databases
// In a full implementation, this would connect to the source database and fetch the data
json result_row;
result_row["doc_id"] = doc_id;
result_row["source_name"] = source_name;
// Parse pk_json to get the primary key value
try {
json pk_data = json::parse(pk_json);
json row_data = json::object();
// If specific columns are requested, only include those
if (!columns.empty()) {
for (const std::string& col : columns) {
// For demo purposes, we'll just echo back some mock data
if (col == "Id" && pk_data.contains("Id")) {
row_data["Id"] = pk_data["Id"];
} else if (col == pk_column) {
// This would be the actual primary key value
row_data[col] = "mock_value";
} else {
// For other columns, provide mock data
row_data[col] = "mock_" + col + "_value";
}
}
} else {
// If no columns specified, include basic info
row_data["Id"] = pk_data.contains("Id") ? pk_data["Id"] : json(0);
row_data[pk_column] = "mock_pk_value";
}
result_row["row"] = row_data;
// Check size limits
std::string row_str = result_row.dump();
if (total_bytes + static_cast<int>(row_str.length()) > max_bytes) {
truncated = true;
break;
}
total_bytes += static_cast<int>(row_str.length());
rows.push_back(result_row);
} catch (...) {
// Skip malformed pk_json
continue;
}
} else if (rows.size() >= static_cast<size_t>(max_rows) || total_bytes >= max_bytes) {
truncated = true;
break;
}
}
delete doc_result;
result["rows"] = rows;
result["truncated"] = truncated;
// Add timing stats
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
json stats;
stats["ms"] = static_cast<int>(duration.count());
result["stats"] = stats;
} else if (tool_name == "rag.admin.stats") {
// Admin stats implementation
// Build query to get source statistics
std::string sql = "SELECT s.source_id, s.name, "
"COUNT(d.doc_id) as docs, "
"COUNT(c.chunk_id) as chunks "
"FROM rag_sources s "
"LEFT JOIN rag_documents d ON d.source_id = s.source_id "
"LEFT JOIN rag_chunks c ON c.source_id = s.source_id "
"GROUP BY s.source_id, s.name";
SQLite3_result* db_result = execute_query(sql.c_str());
if (!db_result) {
return create_error_response("Database query failed");
}
// Build sources array
json sources = json::array();
for (const auto& row : db_result->rows) {
if (row->fields) {
json source;
source["source_id"] = row->fields[0] ? std::stoi(row->fields[0]) : 0;
source["source_name"] = row->fields[1] ? row->fields[1] : "";
source["docs"] = row->fields[2] ? std::stoi(row->fields[2]) : 0;
source["chunks"] = row->fields[3] ? std::stoi(row->fields[3]) : 0;
source["last_sync"] = nullptr; // Placeholder
sources.push_back(source);
}
}
delete db_result;
result["sources"] = sources;
// Add timing stats
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
json stats;
stats["ms"] = static_cast<int>(duration.count());
result["stats"] = stats;
} else {
// Unknown tool
auto end_time = std::chrono::high_resolution_clock::now();
auto duration_us = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
track_tool_invocation(this, "RAG", tool_name, "rag", duration_us);
return create_error_response("Unknown tool: " + tool_name);
}
// Track invocation with timing
auto end_time = std::chrono::high_resolution_clock::now();
auto duration_us = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
track_tool_invocation(this, "RAG", tool_name, "rag", duration_us);
return create_success_response(result);
} catch (const std::exception& e) {
proxy_error("RAG_Tool_Handler: Exception in execute_tool: %s\n", e.what());
return create_error_response(std::string("Exception: ") + e.what());
} catch (...) {
proxy_error("RAG_Tool_Handler: Unknown exception in execute_tool\n");
return create_error_response("Unknown exception");
}
}
// ============================================================================
// Tool Usage Statistics
// ============================================================================
/**
 * @brief Return a copy of the per-endpoint/tool/schema usage counters.
 *
 * The copy is taken while counters_lock is held, so it is consistent with
 * concurrent updates made by track_tool_invocation(), which takes the same
 * lock. The lock is released by a scope guard, so it is also released if
 * copying the nested maps throws (e.g. std::bad_alloc). With the previous
 * manual lock/unlock pair, such a throw left the mutex locked forever.
 *
 * @return Snapshot of tool_usage_stats at the time of the call.
 */
RAG_Tool_Handler::ToolUsageStatsMap RAG_Tool_Handler::get_tool_usage_stats() {
	// Minimal RAII wrapper for the pthread mutex guarding the counters.
	struct CountersLockGuard {
		pthread_mutex_t* mutex;
		explicit CountersLockGuard(pthread_mutex_t* m) : mutex(m) { pthread_mutex_lock(mutex); }
		~CountersLockGuard() { pthread_mutex_unlock(mutex); }
		CountersLockGuard(const CountersLockGuard&) = delete;
		CountersLockGuard& operator=(const CountersLockGuard&) = delete;
	} guard(&counters_lock);
	// The return value is copy-initialized before `guard` is destroyed,
	// so the copy happens entirely under the lock.
	return tool_usage_stats;
}
/**
 * @brief Export the tool usage counters as a 9-column SQLite3_result.
 *
 * Columns (all SQLITE_TEXT): endpoint, tool, schema, count, first_seen,
 * last_seen, sum_time, min_time, max_time. One row is emitted per
 * (endpoint, tool, schema) entry. The timing values are those recorded by
 * track_tool_invocation(), presumably microseconds; confirm against
 * ToolUsageStats::add_timing.
 *
 * The counters are snapshotted while counters_lock is held. When @p reset
 * is true they are moved out instead, which leaves the live map empty. The
 * result set is then built outside the lock, so the hot-path
 * track_tool_invocation() is not blocked behind row allocation and copying.
 *
 * SQLite3_result::add_row() copies every field into its own row storage.
 * Fields are therefore borrowed from stack buffers and from the snapshot's
 * strings. The previous per-row new[] and strdup() allocations were never
 * freed and leaked on every call.
 *
 * @param reset If true, the live counters are cleared as part of the snapshot.
 * @return Newly allocated result set; the caller takes ownership.
 */
SQLite3_result* RAG_Tool_Handler::get_tool_usage_stats_resultset(bool reset) {
	SQLite3_result* result = new SQLite3_result(9);
	result->add_column_definition(SQLITE_TEXT, "endpoint");
	result->add_column_definition(SQLITE_TEXT, "tool");
	result->add_column_definition(SQLITE_TEXT, "schema");
	result->add_column_definition(SQLITE_TEXT, "count");
	result->add_column_definition(SQLITE_TEXT, "first_seen");
	result->add_column_definition(SQLITE_TEXT, "last_seen");
	result->add_column_definition(SQLITE_TEXT, "sum_time");
	result->add_column_definition(SQLITE_TEXT, "min_time");
	result->add_column_definition(SQLITE_TEXT, "max_time");

	// Take the snapshot under the lock. The guard releases the mutex even if
	// copying the nested maps throws.
	ToolUsageStatsMap snapshot;
	{
		struct CountersLockGuard {
			pthread_mutex_t* mutex;
			explicit CountersLockGuard(pthread_mutex_t* m) : mutex(m) { pthread_mutex_lock(mutex); }
			~CountersLockGuard() { pthread_mutex_unlock(mutex); }
			CountersLockGuard(const CountersLockGuard&) = delete;
			CountersLockGuard& operator=(const CountersLockGuard&) = delete;
		} guard(&counters_lock);
		if (reset) {
			// Swapping with the empty local both captures and clears the counters.
			snapshot.swap(tool_usage_stats);
		} else {
			snapshot = tool_usage_stats;
		}
	}

	for (const auto& endpoint_entry : snapshot) {
		for (const auto& tool_entry : endpoint_entry.second) {
			for (const auto& schema_entry : tool_entry.second) {
				const ToolUsageStats& stats = schema_entry.second;
				char numbers[6][32];
				snprintf(numbers[0], sizeof(numbers[0]), "%llu", stats.count);
				snprintf(numbers[1], sizeof(numbers[1]), "%llu", stats.first_seen);
				snprintf(numbers[2], sizeof(numbers[2]), "%llu", stats.last_seen);
				snprintf(numbers[3], sizeof(numbers[3]), "%llu", stats.sum_time);
				snprintf(numbers[4], sizeof(numbers[4]), "%llu", stats.min_time);
				snprintf(numbers[5], sizeof(numbers[5]), "%llu", stats.max_time);
				// add_row() copies each field, so borrowing these pointers is safe;
				// they only need to stay valid for the duration of the call.
				char* row[9] = {
					const_cast<char*>(endpoint_entry.first.c_str()),
					const_cast<char*>(tool_entry.first.c_str()),
					const_cast<char*>(schema_entry.first.c_str()),
					numbers[0], numbers[1], numbers[2],
					numbers[3], numbers[4], numbers[5]
				};
				result->add_row(row);
			}
		}
	}
	return result;
}
#endif /* PROXYSQLGENAI */