diff --git a/include/query_processor.h b/include/query_processor.h index a288b10fe..741830589 100644 --- a/include/query_processor.h +++ b/include/query_processor.h @@ -111,6 +111,9 @@ struct _Query_Processor_rule_t { void *regex_engine2; uint64_t hits; struct _Query_Processor_rule_t *parent; // pointer to parent, to speed up parent update + std::vector * flagOUT_ids; + std::vector * flagOUT_weights; + int flagOUT_weights_total; }; typedef struct _Query_Processor_rule_t QP_rule_t; diff --git a/lib/Query_Processor.cpp b/lib/Query_Processor.cpp index 7826fa30a..c09b8edbc 100644 --- a/lib/Query_Processor.cpp +++ b/lib/Query_Processor.cpp @@ -425,6 +425,16 @@ static void __delete_query_rule(QP_rule_t *qr) { if (r->re2) { delete r->re2; r->re2=NULL; } free(qr->regex_engine2); } + if (qr->flagOUT_ids != NULL) { + qr->flagOUT_ids->clear(); + delete qr->flagOUT_ids; + qr->flagOUT_ids = NULL; + } + if (qr->flagOUT_weights != NULL) { + qr->flagOUT_weights->clear(); + delete qr->flagOUT_weights; + qr->flagOUT_weights = NULL; + } free(qr); }; @@ -732,6 +742,40 @@ QP_rule_t * Query_Processor::new_query_rule(int rule_id, bool active, char *user proxy_error("Incorrect digest for rule_id %d : %s\n" , rule_id, digest); } } + newQR->flagOUT_weights_total = 0; + newQR->flagOUT_ids = NULL; + newQR->flagOUT_weights = NULL; + if (newQR->attributes != NULL) { + if (strlen(newQR->attributes)) { + nlohmann::json j_attributes = nlohmann::json::parse(newQR->attributes); + if ( j_attributes.find("flagOUTs") != j_attributes.end() ) { + newQR->flagOUT_ids = new vector; + newQR->flagOUT_weights = new vector; + const nlohmann::json& flagOUTs = j_attributes["flagOUTs"]; + if (flagOUTs.type() == nlohmann::json::value_t::array) { + for (auto it = flagOUTs.begin(); it != flagOUTs.end(); it++) { + bool parsed = false; + const nlohmann::json& j = *it; + if (j.find("id") != j.end() && j.find("weight") != j.end()) { + if (j["id"].type() == nlohmann::json::value_t::number_unsigned && j["weight"].type() == nlohmann::json::value_t::number_unsigned) { + int id = j["id"]; + int weight = j["weight"]; + newQR->flagOUT_ids->push_back(id); + newQR->flagOUT_weights->push_back(weight); + newQR->flagOUT_weights_total += weight; + parsed = true; + } + } + if (parsed == false) { + proxy_error("Failed to parse flagOUTs in JSON on attributes for rule_id %d : %s\n" , newQR->rule_id, j.dump().c_str()); + } + } + } else { + proxy_error("Failed to parse flagOUTs attributes for rule_id %d : %s\n" , newQR->rule_id, flagOUTs.dump().c_str()); + } + } + } + } proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 5, "Creating new rule in %p : rule_id:%d, active:%d, username=%s, schemaname=%s, flagIN:%d, %smatch_digest=\"%s\", %smatch_pattern=\"%s\", flagOUT:%d replace_pattern=\"%s\", destination_hostgroup:%d, apply:%d\n", newQR, newQR->rule_id, newQR->active, newQR->username, newQR->schemaname, newQR->flagIN, (newQR->negate_match_pattern ? "(!)" : "") , newQR->match_digest, (newQR->negate_match_pattern ? "(!)" : "") , newQR->match_pattern, newQR->flagOUT, newQR->replace_pattern, newQR->destination_hostgroup, newQR->apply); return newQR; }; @@ -1938,12 +1982,26 @@ __internal_loop: // if we arrived here, we have a match qr->hits++; // this is done without atomic function because it updates only the local variables bool set_flagOUT=false; + if (qr->flagOUT_weights_total > 0) { + int rnd = random() % qr->flagOUT_weights_total; + for (unsigned int i=0; i< qr->flagOUT_weights->size(); i++) { + int w = qr->flagOUT_weights->at(i); + if (rnd < w) { + flagIN= qr->flagOUT_ids->at(i); + proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 5, "query rule %d has changed flagOUT based on weight\n", qr->rule_id); + set_flagOUT=true; + break; + } else { + rnd -= w; + } + } + } if (qr->flagOUT >= 0) { proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 5, "query rule %d has changed flagOUT\n", qr->rule_id); flagIN=qr->flagOUT; set_flagOUT=true; //sess->query_info.flagOUT=flagIN; - } + } if (qr->reconnect >= 0) { // Note: negative reconnect means this rule doesn't change proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 5, "query rule %d has set reconnect: %d. Query will%s be rexecuted if connection is lost\n", qr->rule_id, qr->reconnect, (qr->reconnect == 0 ? " NOT" : "" ));