Implement batch embedding generation in rag_ingest

Previously, embeddings were generated one chunk at a time, with each
chunk requiring a separate HTTP API call. This was inefficient due to
HTTP call overhead (latency).

Changes:
- Add PendingEmbedding struct to hold chunk metadata for batching
- Add flush_embedding_batch() function to process multiple embeddings
  in a single API call
- Modify ingestion loop to collect chunks and flush batches when full
- Use the existing batch_size configuration from embedding_json
  (default: 16 chunks per batch)

Performance improvement:
- Before: 1000 chunks = 1000 HTTP calls (e.g., at 100 ms each = 100 seconds)
- After: 1000 chunks = 63 HTTP calls (ceil(1000 / 16) = 63 batches with the
  default batch size)

This significantly reduces ingestion time for sources with embeddings
enabled.
pull/5324/head
Rene Cannao 4 months ago
parent 6166adaf05
commit ff50e834f0

Binary file not shown.

@ -501,6 +501,19 @@ struct SyncCursor {
*/
typedef std::unordered_map<std::string, std::string> RowMap;
/**
 * @brief A single chunk queued for batched embedding generation
 *
 * Carries the identifiers and the text of one chunk so that several
 * chunks can be submitted to the embedding provider together.
 *
 * Kept as a plain aggregate: it is brace-initialized positionally, so the
 * member order (chunk_id, doc_id, source_id, input_text) must not change.
 */
struct PendingEmbedding {
    /// Unique chunk identifier (doc_id#index)
    std::string chunk_id;

    /// Parent document identifier
    std::string doc_id;

    /// Source identifier
    int source_id;

    /// Text to embed (already built via build_embedding_input)
    std::string input_text;
};
// ===========================================================================
// JSON Parsing Functions
// ===========================================================================
@ -1823,6 +1836,51 @@ static std::string build_embedding_input(const EmbeddingConfig& ecfg,
return chunk_body;
}
/**
 * @brief Process a batch of pending embeddings
 *
 * Passes the input text of every pending chunk to the embedding provider
 * in a single embed() call, then hands each returned vector to
 * sqlite_insert_vec() together with the chunk's identifiers.
 *
 * @param pending    Pending embeddings to process; cleared on success
 * @param embedder   Embedding provider instance (dereferenced, must not be null)
 * @param ecfg       Embedding configuration; only ecfg.dim is read here
 * @param ss         Prepared SQLite statements, forwarded to sqlite_insert_vec()
 * @param now_epoch  Epoch time forwarded to sqlite_insert_vec()
 * @return size_t    Number of embeddings processed (0 when pending is empty)
 *
 * @throws std::runtime_error if the provider returns a different number of
 *         vectors than inputs. The check runs before any insert, so a
 *         mismatched batch is never partially written by this function.
 *
 * @note Exceptions raised by embedder->embed() are not caught here and
 *       propagate to the caller.
 * @note The pending vector is cleared only after every insert call has
 *       returned; if an exception escapes, it is left unchanged.
 */
static size_t flush_embedding_batch(std::vector<PendingEmbedding>& pending,
                                    EmbeddingProvider* embedder,
                                    const EmbeddingConfig& ecfg,
                                    SqliteStmts& ss,
                                    std::int64_t now_epoch) {
    if (pending.empty()) return 0;

    // Build batch inputs (copied, so pending stays intact if embed() throws)
    std::vector<std::string> inputs;
    inputs.reserve(pending.size());
    for (const auto& p : pending) {
        inputs.push_back(p.input_text);
    }

    // Generate embeddings in a single API call
    std::vector<std::vector<float>> embeddings = embedder->embed(inputs, ecfg.dim);

    // Vectors are matched to chunks by position, so a short or long response
    // cannot be mapped safely. Fail loudly instead of silently dropping the
    // trailing chunks and still reporting them as processed.
    if (embeddings.size() != pending.size()) {
        throw std::runtime_error(
            "embedding batch size mismatch: expected " +
            std::to_string(pending.size()) + " vectors, got " +
            std::to_string(embeddings.size()));
    }

    // Insert all vectors into the database
    for (size_t i = 0; i < pending.size(); ++i) {
        const auto& p = pending[i];
        sqlite_insert_vec(ss, embeddings[i], p.chunk_id, p.doc_id, p.source_id, now_epoch);
    }

    const size_t count = pending.size();
    pending.clear();
    return count;
}
// ===========================================================================
// Statement Preparation
// ===========================================================================
@ -1947,6 +2005,13 @@ static void ingest_source(sqlite3* sdb, const RagSource& src) {
std::uint64_t ingested_docs = 0;
std::uint64_t skipped_docs = 0;
// Note: updated_at is set to 0 for v0. For accurate timestamps,
// query SELECT unixepoch() once at the start of main().
std::int64_t now_epoch = 0;
// Batch embeddings for efficiency
std::vector<PendingEmbedding> pending_embeddings;
// Track max watermark value (for next run)
MYSQL_ROW r;
bool max_set = false;
@ -2022,10 +2087,6 @@ static void ingest_source(sqlite3* sdb, const RagSource& src) {
// Chunk document body
std::vector<std::string> chunks = chunk_text_chars(doc.body, ccfg);
// Note: updated_at is set to 0 for v0. For accurate timestamps,
// query SELECT unixepoch() once at the start of main().
std::int64_t now_epoch = 0;
// Process each chunk
for (size_t i = 0; i < chunks.size(); i++) {
std::string chunk_id = doc.doc_id + "#" + std::to_string(i);
@ -2044,12 +2105,15 @@ static void ingest_source(sqlite3* sdb, const RagSource& src) {
// Insert into FTS index
sqlite_insert_fts(ss, chunk_id, chunk_title, chunks[i]);
// Generate and insert embedding (if enabled)
// Collect embedding for batched processing (if enabled)
if (ecfg.enabled) {
std::string emb_input = build_embedding_input(ecfg, row, chunks[i]);
std::vector<std::string> batch_inputs = {emb_input};
std::vector<std::vector<float>> vecs = embedder->embed(batch_inputs, ecfg.dim);
sqlite_insert_vec(ss, vecs[0], chunk_id, doc.doc_id, src.source_id, now_epoch);
pending_embeddings.push_back({chunk_id, doc.doc_id, src.source_id, emb_input});
// Flush batch if full
if ((int)pending_embeddings.size() >= ecfg.batch_size) {
flush_embedding_batch(pending_embeddings, embedder.get(), ecfg, ss, now_epoch);
}
}
}
@ -2060,6 +2124,11 @@ static void ingest_source(sqlite3* sdb, const RagSource& src) {
}
}
// Flush any remaining pending embeddings
if (ecfg.enabled && !pending_embeddings.empty()) {
flush_embedding_batch(pending_embeddings, embedder.get(), ecfg, ss, now_epoch);
}
// Cleanup
mysql_free_result(res);
mysql_close(mdb);

Loading…
Cancel
Save