Optimize transaction command parsing to avoid unnecessary tokenization

Previously, the parser always tokenized the full command, even when we only
needed to check whether it was a transaction command. Now, it first extracts
the first word to determine relevance and performs full tokenization only
when necessary.
v3.0_optimizations_and_stability
Rahim Kanji 5 months ago
parent 7c665b9f78
commit e744c2bbb7

@ -56,6 +56,9 @@ struct TxnCmd {
*/ */
class PgSQL_TxnCmdParser { class PgSQL_TxnCmdParser {
public: public:
PgSQL_TxnCmdParser() noexcept { tokens.reserve(16); }
~PgSQL_TxnCmdParser() noexcept = default;
TxnCmd parse(std::string_view input, bool in_transaction_mode) noexcept; TxnCmd parse(std::string_view input, bool in_transaction_mode) noexcept;
private: private:
@ -67,14 +70,20 @@ private:
TxnCmd parse_start(size_t& pos) noexcept; TxnCmd parse_start(size_t& pos) noexcept;
// Helpers // Helpers
static std::string to_lower(std::string_view s) noexcept { inline static bool iequals(std::string_view a, std::string_view b) noexcept {
std::string s_copy(s); if (a.size() != b.size()) return false;
std::transform(s_copy.begin(), s_copy.end(), s_copy.begin(), ::tolower); for (size_t i = 0; i < a.size(); ++i) {
return s_copy; char ca = a[i];
char cb = b[i];
if (ca >= 'A' && ca <= 'Z') ca += 32;
if (cb >= 'A' && cb <= 'Z') cb += 32;
if (ca != cb) return false;
}
return true;
} }
inline static bool contains(std::vector<std::string_view>&& list, std::string_view value) noexcept { inline static bool contains(std::vector<std::string_view>&& list, std::string_view value) noexcept {
for (const auto& item : list) if (item == value) return true; for (const auto& item : list) if (iequals(item, value)) return true;
return false; return false;
} }
}; };

@ -327,16 +327,97 @@ bool PgSQL_ExplicitTxnStateMgr::handle_transaction(std::string_view input) {
return true; return true;
} }
TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mode) noexcept { TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mode) noexcept {
tokens.clear();
TxnCmd cmd; TxnCmd cmd;
bool in_quote = false;
if (input.empty()) return cmd;
// Extract first word without full tokenization
size_t start = 0; size_t start = 0;
size_t end = 0;
while (start < input.size() && fast_isspace(input[start])) {
start++;
}
if (start >= input.size()) return cmd;
// Find end of first word
end = start;
bool in_quote = false;
char quote_char = 0; char quote_char = 0;
// Tokenize with quote handling while (end < input.size()) {
for (size_t i = 0; i <= input.size(); ++i) { char c = input[end];
if (!in_quote && (c == '"' || c == '\'')) {
// If we hit a quote at the start, this isn't a transaction command
return cmd;
}
if (fast_isspace(c) || c == ';') {
break;
}
end++;
}
std::string_view first_word = input.substr(start, end - start);
// Check if this is a transaction command we care about
TxnCmd::Type cmd_type = TxnCmd::UNKNOWN;
if (in_transaction_mode) {
if (iequals(first_word, "begin")) {
cmd.type = TxnCmd::BEGIN;
return cmd;
}
if (iequals(first_word, "start")) {
cmd_type = TxnCmd::BEGIN;
} else if (iequals(first_word, "savepoint")) {
cmd_type = TxnCmd::SAVEPOINT;
} else if (iequals(first_word, "release")) {
cmd_type = TxnCmd::RELEASE;
} else if (iequals(first_word, "rollback")) {
cmd_type = TxnCmd::ROLLBACK;
}
} else {
if (iequals(first_word, "commit") || iequals(first_word, "end")) {
cmd.type = TxnCmd::COMMIT;
return cmd;
}
if (iequals(first_word, "abort")) {
cmd.type = TxnCmd::ROLLBACK;
return cmd;
}
if (iequals(first_word, "rollback")) {
cmd_type = TxnCmd::ROLLBACK;
}
}
// If not a transaction command, return early
if (cmd_type == TxnCmd::UNKNOWN) {
return cmd;
}
// Continue tokenization from where we left off
tokens.clear();
// Continue tokenizing the rest of the input
in_quote = false;
quote_char = 0;
start = end; // Continue from after the first word
while (start < input.size() && fast_isspace(input[start])) {
start++;
}
// Tokenize the remaining input
for (size_t i = start; i <= input.size(); ++i) {
const bool at_end = i == input.size(); const bool at_end = i == input.size();
const char c = at_end ? 0 : input[i]; const char c = at_end ? 0 : input[i];
@ -344,6 +425,7 @@ TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mod
if (c == quote_char || at_end) { if (c == quote_char || at_end) {
tokens.emplace_back(input.substr(start + 1, i - start - 1)); tokens.emplace_back(input.substr(start + 1, i - start - 1));
in_quote = false; in_quote = false;
start = i + 1;
} }
continue; continue;
} }
@ -353,41 +435,50 @@ TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mod
quote_char = c; quote_char = c;
start = i; start = i;
} }
else if (isspace(c) || c == ';' || at_end) { else if (fast_isspace(c) || c == ';' || at_end) {
if (start < i) tokens.emplace_back(input.substr(start, i - start)); if (start < i) tokens.emplace_back(input.substr(start, i - start));
start = i + 1; start = i + 1;
} }
} }
if (tokens.empty()) return cmd;
size_t pos = 0; size_t pos = 0;
const std::string first = to_lower(tokens[pos++]);
if (in_transaction_mode) {
if (in_transaction_mode == true) {
if (first == "begin") cmd.type = TxnCmd::BEGIN; switch (cmd_type) {
else if (first == "start") cmd = parse_start(pos); case TxnCmd::BEGIN:
else if (first == "savepoint") cmd = parse_savepoint(pos); cmd = parse_start(pos);
else if (first == "release") cmd = parse_release(pos); break;
else if (first == "rollback") cmd = parse_rollback(pos); case TxnCmd::SAVEPOINT:
cmd = parse_savepoint(pos);
break;
case TxnCmd::RELEASE:
cmd = parse_release(pos);
break;
case TxnCmd::ROLLBACK:
cmd = parse_rollback(pos);
break;
default:
break;
}
} else { } else {
if (first == "commit" || first == "end") cmd.type = TxnCmd::COMMIT; if (cmd_type == TxnCmd::ROLLBACK)
else if (first == "abort") cmd.type = TxnCmd::ROLLBACK; cmd = parse_rollback(pos);
else if (first == "rollback") cmd = parse_rollback(pos);
} }
return cmd; return cmd;
} }
TxnCmd PgSQL_TxnCmdParser::parse_rollback(size_t& pos) noexcept { TxnCmd PgSQL_TxnCmdParser::parse_rollback(size_t& pos) noexcept {
TxnCmd cmd{ TxnCmd::ROLLBACK }; TxnCmd cmd{ TxnCmd::ROLLBACK };
while (pos < tokens.size() && contains({ "work", "transaction" }, to_lower(tokens[pos]))) pos++; while (pos < tokens.size() && contains({ "work", "transaction" }, tokens[pos])) pos++;
if (pos < tokens.size() && to_lower(tokens[pos]) == "to") { if (pos < tokens.size() && iequals(tokens[pos], "to")) {
cmd.type = TxnCmd::ROLLBACK_TO; cmd.type = TxnCmd::ROLLBACK_TO;
if (++pos < tokens.size() && to_lower(tokens[pos]) == "savepoint") pos++; if (++pos < tokens.size() && iequals(tokens[pos], "savepoint")) pos++;
if (pos < tokens.size()) cmd.savepoint = tokens[pos++]; if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
} else if (pos < tokens.size() && to_lower(tokens[pos]) == "and") { } else if (pos < tokens.size() && iequals(tokens[pos], "and")) {
if (++pos < tokens.size() && to_lower(tokens[pos]) == "chain") { if (++pos < tokens.size() && iequals(tokens[pos], "chain")) {
cmd.type = TxnCmd::ROLLBACK_AND_CHAIN; cmd.type = TxnCmd::ROLLBACK_AND_CHAIN;
pos++; pos++;
} }
@ -403,14 +494,14 @@ TxnCmd PgSQL_TxnCmdParser::parse_savepoint(size_t& pos) noexcept {
TxnCmd PgSQL_TxnCmdParser::parse_release(size_t& pos) noexcept { TxnCmd PgSQL_TxnCmdParser::parse_release(size_t& pos) noexcept {
TxnCmd cmd{ TxnCmd::RELEASE }; TxnCmd cmd{ TxnCmd::RELEASE };
if (pos < tokens.size() && to_lower(tokens[pos]) == "savepoint") pos++; if (pos < tokens.size() && iequals(tokens[pos], "savepoint")) pos++;
if (pos < tokens.size()) cmd.savepoint = tokens[pos++]; if (pos < tokens.size()) cmd.savepoint = tokens[pos++];
return cmd; return cmd;
} }
TxnCmd PgSQL_TxnCmdParser::parse_start(size_t& pos) noexcept { TxnCmd PgSQL_TxnCmdParser::parse_start(size_t& pos) noexcept {
TxnCmd cmd{ TxnCmd::UNKNOWN }; TxnCmd cmd{ TxnCmd::UNKNOWN };
if (pos < tokens.size() && to_lower(tokens[pos]) == "transaction") { if (pos < tokens.size() && iequals(tokens[pos], "transaction")) {
cmd.type = TxnCmd::BEGIN; cmd.type = TxnCmd::BEGIN;
pos++; pos++;
} }

@ -741,7 +741,7 @@ std::vector<std::pair<std::string, std::string>> PgSQL_Protocol::parse_options(c
while (pos < input.size()) { while (pos < input.size()) {
// Skip leading spaces // Skip leading spaces
while (pos < input.size() && std::isspace(input[pos])) { while (pos < input.size() && fast_isspace(input[pos])) {
++pos; ++pos;
} }
@ -751,7 +751,7 @@ std::vector<std::pair<std::string, std::string>> PgSQL_Protocol::parse_options(c
pos += 2; // Skip "-c", "--" pos += 2; // Skip "-c", "--"
} }
while (pos < input.size() && std::isspace(input[pos])) { while (pos < input.size() && fast_isspace(input[pos])) {
++pos; ++pos;
} }
@ -772,7 +772,7 @@ std::vector<std::pair<std::string, std::string>> PgSQL_Protocol::parse_options(c
bool last_was_escape = false; bool last_was_escape = false;
while (pos < input.size()) { while (pos < input.size()) {
char c = input[pos]; char c = input[pos];
if (std::isspace(c) && !last_was_escape) { if (fast_isspace(c) && !last_was_escape) {
break; break;
} }
if (c == '\\' && !last_was_escape) { if (c == '\\' && !last_was_escape) {

@ -6557,7 +6557,7 @@ std::vector<std::string> PgSQL_DateStyle_Util::split_datestyle(std::string_view
int* lastNonSpace = (currentToken == 1) ? &lastNonSpace1 : &lastNonSpace2; int* lastNonSpace = (currentToken == 1) ? &lastNonSpace1 : &lastNonSpace2;
// Cache is-space check. // Cache is-space check.
bool is_space = std::isspace(static_cast<unsigned char>(c)); bool is_space = fast_isspace(static_cast<unsigned char>(c));
// Skip leading whitespace for a new token. // Skip leading whitespace for a new token.
if (currentStr->empty() && is_space) { if (currentStr->empty() && is_space) {
continue; continue;

@ -183,7 +183,7 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param
if (transformed_value) *transformed_value = nullptr; if (transformed_value) *transformed_value = nullptr;
// Skip leading whitespace // Skip leading whitespace
while (isspace((unsigned char)*p)) p++; while (fast_isspace((unsigned char)*p)) p++;
// Parse numeric part // Parse numeric part
num = strtoll(p, &endptr, 10); num = strtoll(p, &endptr, 10);
@ -196,11 +196,11 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param
p = endptr; p = endptr;
// Skip whitespace after number // Skip whitespace after number
while (isspace((unsigned char)*p)) p++; while (fast_isspace((unsigned char)*p)) p++;
// Parse unit // Parse unit
if (*p != '\0') { if (*p != '\0') {
char tmp_unit = tolower(*p); char tmp_unit = ::tolower(*p);
switch (tmp_unit) { switch (tmp_unit) {
case 'k': case 'k':
case 'm': case 'm':
@ -210,7 +210,7 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param
unit = toupper(*p++); unit = toupper(*p++);
has_unit = true; has_unit = true;
// Check optional 'b'/'B' // Check optional 'b'/'B'
if (tolower(*p) == 'b') p++; if (::tolower(*p) == 'b') p++;
break; break;
default: default:
return false; return false;
@ -218,7 +218,7 @@ bool pgsql_variable_validate_maintenance_work_mem(const char* value, const param
} }
// Skip trailing whitespace // Skip trailing whitespace
while (isspace((unsigned char)*p)) p++; while (fast_isspace((unsigned char)*p)) p++;
// Validate entire string consumed // Validate entire string consumed
if (*p != '\0') return false; if (*p != '\0') return false;
@ -241,7 +241,7 @@ bool pgsql_variable_validate_maintenance_work_mem_v2(const char* value, const pa
const char* input = value; const char* input = value;
/* Trim leading whitespace */ /* Trim leading whitespace */
while (isspace((unsigned char)*input)) input++; while (fast_isspace((unsigned char)*input)) input++;
/* Parse numeric part */ /* Parse numeric part */
uint64_t number; uint64_t number;
@ -256,7 +256,7 @@ bool pgsql_variable_validate_maintenance_work_mem_v2(const char* value, const pa
//num_len = endptr - input; //num_len = endptr - input;
// Skip whitespace after number // Skip whitespace after number
while (isspace((unsigned char)*endptr)) endptr++; while (fast_isspace((unsigned char)*endptr)) endptr++;
/* Parse unit part */ /* Parse unit part */
const char* unit_ptr = endptr; const char* unit_ptr = endptr;
@ -273,7 +273,7 @@ bool pgsql_variable_validate_maintenance_work_mem_v2(const char* value, const pa
/* Convert unit to lowercase for validation */ /* Convert unit to lowercase for validation */
char u[3] = { 0 }; char u[3] = { 0 };
for (int i = 0; i < 2 && unit_ptr[i]; i++) for (int i = 0; i < 2 && unit_ptr[i]; i++)
u[i] = tolower((unsigned char)unit_ptr[i]); u[i] = ::tolower((unsigned char)unit_ptr[i]);
/* Validate unit and set multiplier */ /* Validate unit and set multiplier */
if (unit_len == 1 && u[0] == 'b') { if (unit_len == 1 && u[0] == 'b') {
@ -332,7 +332,7 @@ bool pgsql_variable_validate_maintenance_work_mem_v3(const char* value, const pa
(void)session; (void)session;
// Trim leading whitespace // Trim leading whitespace
while (isspace((unsigned char)*value)) value++; while (fast_isspace((unsigned char)*value)) value++;
char* endptr; char* endptr;
const char* num_start = value; const char* num_start = value;
@ -371,7 +371,7 @@ bool pgsql_variable_validate_maintenance_work_mem_v3(const char* value, const pa
// Convert unit to lowercase for validation // Convert unit to lowercase for validation
char u[3] = { 0 }; char u[3] = { 0 };
for (int i = 0; i < 2 && unit_ptr[i]; i++) for (int i = 0; i < 2 && unit_ptr[i]; i++)
u[i] = tolower((unsigned char)unit_ptr[i]); u[i] = ::tolower((unsigned char)unit_ptr[i]);
// Validate units and set multipliers // Validate units and set multipliers
if (unit_len == 1 && u[0] == 'b') { if (unit_len == 1 && u[0] == 'b') {
@ -471,7 +471,7 @@ bool pgsql_variable_validate_search_path(const char* value, const params_t* para
while (*token && result) { while (*token && result) {
/* skip leading whitespace */ /* skip leading whitespace */
while (*token && isspace((unsigned char)*token)) token++; while (*token && fast_isspace((unsigned char)*token)) token++;
if (*token == '\0') break; if (*token == '\0') break;
const char* part_start = token; const char* part_start = token;
@ -508,7 +508,7 @@ bool pgsql_variable_validate_search_path(const char* value, const params_t* para
} }
} else { } else {
// unquoted identifier or $user // unquoted identifier or $user
while (*token && *token != ',' && !isspace(*token)) token++; while (*token && *token != ',' && !fast_isspace(*token)) token++;
part_len = (size_t)(token - part_start); part_len = (size_t)(token - part_start);
if (part_len == 0 || part_len > 63) { if (part_len == 0 || part_len > 63) {
result = false; result = false;
@ -543,7 +543,7 @@ bool pgsql_variable_validate_search_path(const char* value, const params_t* para
normalized[norm_pos] = '\0'; normalized[norm_pos] = '\0';
// skip whitespace after part // skip whitespace after part
while (*token && isspace(*token)) token++; while (*token && fast_isspace(*token)) token++;
// expect comma or end // expect comma or end
if (*token == ',') { if (*token == ',') {

@ -92,7 +92,7 @@ int ProxySQL_Config::Read_Global_Variables_from_configfile(const char *prefix) {
char *query=(char *)malloc(strlen(q)+strlen(prefix)+strlen(n)+strlen(value_string.c_str())); char *query=(char *)malloc(strlen(q)+strlen(prefix)+strlen(n)+strlen(value_string.c_str()));
sprintf(query,q, prefix, n, value_string.c_str()); sprintf(query,q, prefix, n, value_string.c_str());
//fprintf(stderr, "%s\n", query); //fprintf(stderr, "%s\n", query);
admindb->execute(query); admindb->execute(query);
free(query); free(query);
} }
admindb->execute("PRAGMA foreign_keys = ON"); admindb->execute("PRAGMA foreign_keys = ON");

Loading…
Cancel
Save