You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
proxysql/lib/PgSQL_ExplicitTxnStateMgr.cpp

510 lines
17 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "PgSQL_ExplicitTxnStateMgr.h"
#include "proxysql.h"
#include "PgSQL_Session.h"
#include "PgSQL_Data_Stream.h"
#include "PgSQL_Connection.h"
extern class PgSQL_Variables pgsql_variables;
PgSQL_ExplicitTxnStateMgr::PgSQL_ExplicitTxnStateMgr(PgSQL_Session* sess) : session(sess) {
}
/**
 * @brief Releases every stored variable snapshot and forgets all savepoints.
 *
 * Snapshot values are heap strings (created with strdup() when a snapshot is
 * captured), so each snapshot must be reset before the container is emptied.
 */
PgSQL_ExplicitTxnStateMgr::~PgSQL_ExplicitTxnStateMgr() {
	for (size_t pos = 0; pos < transaction_state.size(); ++pos) {
		reset_variable_snapshot(transaction_state[pos]);
	}

	transaction_state.clear();
	savepoint.clear();
}
/**
 * @brief Frees all values held by a snapshot and zeroes its hashes.
 *
 * After the call every slot in [0, PGSQL_NAME_LAST_HIGH_WM) has a null value
 * pointer and a hash of 0, so the snapshot can be reset again safely.
 *
 * @param var_snapshot Snapshot to clear in place.
 */
void PgSQL_ExplicitTxnStateMgr::reset_variable_snapshot(PgSQL_Variable_Snapshot& var_snapshot) noexcept {
	for (int slot = 0; slot < PGSQL_NAME_LAST_HIGH_WM; ++slot) {
		// free(nullptr) is a no-op, so empty slots need no special casing
		free(var_snapshot.var_value[slot]);
		var_snapshot.var_value[slot] = nullptr;
		var_snapshot.var_hash[slot] = 0;
	}
}
/**
 * @brief DEBUG-only sanity check of backend variable tracking.
 *
 * For every variable index below PGSQL_NAME_LAST_LOW_WM, compares the
 * parameter status reported by the backend connection with the value tracked
 * in the connection's `variables` array and logs a warning on mismatch.
 * Compiles to a no-op when DEBUG is not defined.
 *
 * The function returns silently when the session has no backend, data stream
 * or connection attached, so callers do not need to guard the call.
 *
 * @param session Session whose backend connection is inspected; may be null.
 */
void verify_server_variables(PgSQL_Session* session) {
#ifdef DEBUG
	// Nothing to verify without a live backend connection.
	if (session == nullptr || session->mybe == nullptr ||
		session->mybe->server_myds == nullptr ||
		session->mybe->server_myds->myconn == nullptr) {
		return;
	}

	for (int idx = 0; idx < PGSQL_NAME_LAST_LOW_WM; idx++) {
		const char* conn_param_status = session->mybe->server_myds->myconn->get_pg_parameter_status(pgsql_tracked_variables[idx].set_variable_name);
		const char* param_value = session->mybe->server_myds->myconn->variables[idx].value;
		if (conn_param_status && param_value) {
			//assert(strcmp(conn_param_status, param_value) == 0);
			if (strcmp(conn_param_status, param_value) != 0) {
				// This isn't actually a bug, but it can occur in an edge case - for example, when a COPY FROM STDIN fails.
				// In that situation, the ParameterStatus message sent from the server is received and forwarded to the client
				// via fast-forwarding, so the internal ParameterStatus in libpq isn't updated.
				proxy_warning("Server variable '%s' mismatch. Parameter status value: '%s', Expected value: '%s'\n",
					pgsql_tracked_variables[idx].set_variable_name,
					conn_param_status,
					param_value);
			}
		}
	}
#else
	(void)session; // unused outside DEBUG builds
#endif
}
void PgSQL_ExplicitTxnStateMgr::start_transaction() {
if (transaction_state.empty() == false) {
// Transaction already started, do nothing and return
proxy_warning("Received BEGIN command. There is already a transaction in progress\n");
assert(session->NumActiveTransactions() > 0);
return;
}
assert(session->client_myds && session->client_myds->myconn);
PgSQL_Variable_Snapshot var_snapshot{};
// check if already in transaction, if yes then do nothing
for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
uint32_t hash = pgsql_variables.client_get_hash(session, idx);
if (hash != 0) {
var_snapshot.var_hash[idx] = hash;
var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx));
} else {
assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
// no need to store the value
//var_snapshot.var_hash[idx] = 0;
//var_snapshot.var_value[idx] = NULL;
}
}
transaction_state.emplace_back(std::move(var_snapshot));
}
/**
 * @brief Discards all transaction snapshots after a COMMIT.
 *
 * The current variable values are kept as they are; only the stored
 * snapshots and savepoint names are released. When no snapshot exists a
 * warning is logged and nothing else happens.
 */
void PgSQL_ExplicitTxnStateMgr::commit() {
	if (transaction_state.empty()) {
		proxy_warning("Received COMMIT command. There is no transaction in progress\n");
		assert(session->NumActiveTransactions() == 0);
		return;
	}
	assert(session->client_myds && session->client_myds->myconn);

	for (auto& tran_state : transaction_state) {
		reset_variable_snapshot(tran_state);
	}
	transaction_state.clear();
	savepoint.clear();

	// Same guard as rollback()/rollback_to_savepoint(): the verification
	// dereferences the backend, so skip it when no backend is attached.
	if (session->mybe) {
		verify_server_variables(session);
	}
}
/**
 * @brief Restores the variables captured at transaction start after a ROLLBACK.
 *
 * Every tracked variable is compared with the initial snapshot
 * (transaction_state.front()). Variables whose hash differs are set back to
 * the snapshot value on both client and server side; variables that had no
 * value in the snapshot are reset.
 *
 * All savepoints are always forgotten. For a plain ROLLBACK the initial
 * snapshot is released too. For ROLLBACK AND CHAIN the initial snapshot is
 * kept, while the snapshots that belonged to the discarded savepoints are
 * released so that transaction_state.size() - 1 == savepoint.size() keeps
 * holding.
 *
 * @param rollback_and_chain true for ROLLBACK AND CHAIN, false for plain ROLLBACK.
 */
void PgSQL_ExplicitTxnStateMgr::rollback(bool rollback_and_chain) {
	if (transaction_state.empty()) {
		proxy_warning("Received ROLLBACK command. There is no transaction in progress\n");
		assert(session->NumActiveTransactions() == 0);
		return;
	}
	assert(session->client_myds && session->client_myds->myconn);

	const PgSQL_Variable_Snapshot& var_snapshot = transaction_state.front();
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		uint32_t hash = var_snapshot.var_hash[idx];
		if (hash != 0) {
			uint32_t client_hash = pgsql_variables.client_get_hash(session, idx);
			uint32_t server_hash = pgsql_variables.server_get_hash(session, idx);
			assert(client_hash == server_hash);
			// Value unchanged since the snapshot: nothing to restore
			if (hash == client_hash)
				continue;
			pgsql_variables.client_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
			pgsql_variables.server_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
		} else {
			assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
			pgsql_variables.client_reset_value(session, idx, false);
			pgsql_variables.server_reset_value(session, idx, false);
		}
	}

	// reuse of connection that has extra param set in connection
	session->client_myds->myconn->reorder_dynamic_variables_idx();
	if (session->mybe) {
		session->mybe->server_myds->myconn->reorder_dynamic_variables_idx();
		verify_server_variables(session);
	}

	if (rollback_and_chain) {
		// ROLLBACK AND CHAIN: keep only the initial snapshot. The savepoint
		// snapshots must go together with the savepoint names cleared below,
		// otherwise later savepoints would be paired with stale snapshots.
		for (size_t idx = 1; idx < transaction_state.size(); idx++) {
			reset_variable_snapshot(transaction_state[idx]);
		}
		transaction_state.resize(1);
	} else {
		// Plain ROLLBACK: release every snapshot, including the initial one
		for (auto& tran_state : transaction_state) {
			reset_variable_snapshot(tran_state);
		}
		transaction_state.clear();
	}
	savepoint.clear();
}
/**
 * @brief Restores the variables captured when the named savepoint was created.
 *
 * The savepoint is looked up case-insensitively. Its snapshot is stored at
 * transaction_state[index + 1] (slot 0 holds the transaction-start snapshot).
 * Variables that differ from the snapshot are restored on client and server
 * side, then the snapshot of this savepoint and of every later savepoint is
 * released together with their names.
 *
 * NOTE(review): this also forgets the named savepoint itself, whereas the
 * PostgreSQL documentation states it stays valid after ROLLBACK TO SAVEPOINT.
 * Changing that here would also need add_savepoint() to accept an existing
 * name - confirm the intended behavior before touching it.
 *
 * @param name Savepoint name; not required to be NUL-terminated.
 * @return true if the savepoint was found and the state restored, false when
 *         no transaction is tracked or the savepoint is unknown.
 */
bool PgSQL_ExplicitTxnStateMgr::rollback_to_savepoint(std::string_view name) {
	if (transaction_state.empty()) {
		// %.*s: a string_view gives no NUL-termination guarantee
		proxy_warning("Received ROLLBACK TO SAVEPOINT '%.*s' command. There is no transaction in progress\n",
			static_cast<int>(name.size()), name.data());
		assert(session->NumActiveTransactions() == 0);
		return false;
	}
	assert(session->client_myds && session->client_myds->myconn);

	int tran_state_idx = -1;
	for (size_t idx = 0; idx < savepoint.size(); idx++) {
		if (savepoint[idx].size() == name.size() &&
			strncasecmp(savepoint[idx].c_str(), name.data(), name.size()) == 0) {
			tran_state_idx = static_cast<int>(idx);
			break;
		}
	}
	if (tran_state_idx == -1) {
		proxy_warning("Savepoint '%.*s' not found.\n", static_cast<int>(name.size()), name.data());
		return false;
	}

	assert(tran_state_idx + 1 < static_cast<int>(transaction_state.size()));
	PgSQL_Variable_Snapshot& var_snapshot = transaction_state[tran_state_idx + 1];
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		uint32_t hash = var_snapshot.var_hash[idx];
		if (hash != 0) {
			uint32_t client_hash = pgsql_variables.client_get_hash(session, idx);
			uint32_t server_hash = pgsql_variables.server_get_hash(session, idx);
			assert(client_hash == server_hash);
			// Value unchanged since the savepoint: nothing to restore
			if (hash == client_hash)
				continue;
			pgsql_variables.client_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
			pgsql_variables.server_set_hash_and_value(session, idx, var_snapshot.var_value[idx], hash);
		} else {
			assert(idx >= PGSQL_NAME_LAST_LOW_WM); // Critical parameters/variables cannot be null
			pgsql_variables.client_reset_value(session, idx, false);
			pgsql_variables.server_reset_value(session, idx, false);
		}
	}

	session->client_myds->myconn->reorder_dynamic_variables_idx();
	if (session->mybe) {
		session->mybe->server_myds->myconn->reorder_dynamic_variables_idx();
		verify_server_variables(session);
	}

	// Release the snapshot of this savepoint and of every later one
	for (size_t idx = tran_state_idx + 1; idx < transaction_state.size(); idx++) {
		reset_variable_snapshot(transaction_state[idx]);
	}
	transaction_state.resize(tran_state_idx + 1);
	savepoint.resize(tran_state_idx);
	return true;
}
/**
 * @brief Forgets the named savepoint and every savepoint created after it.
 *
 * The savepoint is looked up case-insensitively. Current variable values are
 * left untouched; only the stored snapshots and names from the matching
 * savepoint onwards are released.
 *
 * @param name Savepoint name; not required to be NUL-terminated.
 * @return true if the savepoint was found and released, false when no
 *         transaction is tracked or the savepoint is unknown.
 */
bool PgSQL_ExplicitTxnStateMgr::release_savepoint(std::string_view name) {
	if (transaction_state.empty()) {
		// %.*s: a string_view gives no NUL-termination guarantee
		proxy_warning("Received RELEASE SAVEPOINT '%.*s' command. There is no transaction in progress\n",
			static_cast<int>(name.size()), name.data());
		assert(session->NumActiveTransactions() == 0);
		return false;
	}
	assert(session->client_myds && session->client_myds->myconn);

	int tran_state_idx = -1;
	for (size_t idx = 0; idx < savepoint.size(); idx++) {
		if (savepoint[idx].size() == name.size() &&
			strncasecmp(savepoint[idx].c_str(), name.data(), name.size()) == 0) {
			tran_state_idx = static_cast<int>(idx);
			break;
		}
	}
	if (tran_state_idx == -1) {
		proxy_warning("SAVEPOINT '%.*s' not found.\n", static_cast<int>(name.size()), name.data());
		return false;
	}

	// Snapshot of savepoint i lives at transaction_state[i + 1]
	for (size_t idx = tran_state_idx + 1; idx < transaction_state.size(); idx++) {
		reset_variable_snapshot(transaction_state[idx]);
	}
	transaction_state.resize(tran_state_idx + 1);
	savepoint.resize(tran_state_idx);
	return true;
}
/**
 * @brief Registers a savepoint and snapshots the client's tracked variables.
 *
 * The hash and a private copy of the value of every variable with a non-zero
 * client hash are appended to transaction_state, and the name is appended to
 * savepoint, so the snapshot of savepoint i is transaction_state[i + 1].
 *
 * @param name Savepoint name; not required to be NUL-terminated.
 * @return true if the savepoint was recorded; false when no transaction is
 *         tracked or a savepoint with the same name (case-insensitive)
 *         already exists, in which case nothing is changed.
 */
bool PgSQL_ExplicitTxnStateMgr::add_savepoint(std::string_view name) {
	if (transaction_state.empty()) {
		// %.*s: a string_view gives no NUL-termination guarantee
		proxy_warning("Received SAVEPOINT '%.*s' command. There is no transaction in progress\n",
			static_cast<int>(name.size()), name.data());
		assert(session->NumActiveTransactions() == 0);
		return false;
	}
	assert(session->client_myds && session->client_myds->myconn);

	auto it = std::find_if(savepoint.begin(), savepoint.end(), [name](std::string_view sp) {
		return sp.size() == name.size() &&
			strncasecmp(sp.data(), name.data(), name.size()) == 0;
		});
	// Duplicate names are rejected without taking a new snapshot
	if (it != savepoint.end()) return false;

	PgSQL_Variable_Snapshot var_snapshot{};
	for (int idx = 0; idx < PGSQL_NAME_LAST_HIGH_WM; idx++) {
		uint32_t hash = pgsql_variables.client_get_hash(session, idx);
		if (hash != 0) {
			var_snapshot.var_hash[idx] = hash;
			var_snapshot.var_value[idx] = strdup(pgsql_variables.client_get_value(session, idx));
		}
	}

	transaction_state.emplace_back(std::move(var_snapshot));
	savepoint.emplace_back(name);
	assert((transaction_state.size() - 1) == savepoint.size());
	return true;
}
/**
 * @brief Dumps the stored snapshots into a JSON object.
 *
 * Does nothing when no transaction snapshot exists. Otherwise writes the
 * transaction-start snapshot under "initial_state" and one object per
 * savepoint under "savepoints"[<name>], each mapping the variable name to its
 * snapshotted value. Only variables with a non-zero hash are emitted.
 *
 * @param j JSON object to fill.
 */
void PgSQL_ExplicitTxnStateMgr::fill_internal_session(nlohmann::json& j) {
	if (transaction_state.empty()) return;

	// Copies every set variable of one snapshot into the given JSON node
	auto dump_snapshot = [](nlohmann::json& node, PgSQL_Variable_Snapshot& snapshot) {
		for (int var_idx = 0; var_idx < PGSQL_NAME_LAST_HIGH_WM; ++var_idx) {
			if (snapshot.var_hash[var_idx] == 0) continue;
			node[pgsql_tracked_variables[var_idx].set_variable_name] = snapshot.var_value[var_idx];
		}
	};

	dump_snapshot(j["initial_state"], transaction_state[0]);

	// Savepoint n is paired with the snapshot stored at position n + 1
	for (size_t sp_idx = 0; sp_idx < savepoint.size(); ++sp_idx) {
		dump_snapshot(j["savepoints"][savepoint[sp_idx]], transaction_state[sp_idx + 1]);
	}
}
bool PgSQL_ExplicitTxnStateMgr::handle_transaction(std::string_view input) {
TxnCmd cmd = tx_parser.parse(input, (session->active_transactions > 0));
switch (cmd.type) {
case TxnCmd::BEGIN:
start_transaction();
break;
case TxnCmd::COMMIT:
commit();
break;
case TxnCmd::ROLLBACK:
rollback(false);
break;
case TxnCmd::ROLLBACK_AND_CHAIN:
rollback(true);
break;
case TxnCmd::SAVEPOINT:
return add_savepoint(cmd.savepoint);
case TxnCmd::RELEASE:
return release_savepoint(cmd.savepoint);
case TxnCmd::ROLLBACK_TO:
return rollback_to_savepoint(cmd.savepoint);
default:
// Unknown command
return false;
}
return true;
}
/**
 * @brief Classifies a statement as a transaction-control command.
 *
 * Only the first word is examined initially. Which keywords are recognized
 * depends on in_transaction_mode:
 *  - true:  BEGIN, START, SAVEPOINT, RELEASE, ROLLBACK
 *  - false: COMMIT, END, ABORT, ROLLBACK
 * BEGIN, COMMIT/END and ABORT are resolved from the first word alone. For the
 * remaining keywords the rest of the statement is split into tokens
 * (whitespace and ';' separate tokens, quotes delimit a single token) and
 * handed to the matching parse_* helper.
 *
 * @param input               Statement text.
 * @param in_transaction_mode Selects the keyword set listed above.
 * @return The parsed command; a default-constructed TxnCmd when the statement
 *         is empty, blank, has a quote in its first word, or does not start
 *         with a recognized keyword.
 */
TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mode) noexcept {
	TxnCmd result;
	const size_t len = input.size();

	// Locate the first word; empty or all-blank input is not a command
	size_t word_begin = 0;
	while (word_begin < len && fast_isspace(input[word_begin])) {
		++word_begin;
	}
	if (word_begin >= len) return result;

	size_t word_end = word_begin;
	for (; word_end < len; ++word_end) {
		const char ch = input[word_end];
		// A quote anywhere in the first word rules out a transaction command
		if (ch == '"' || ch == '\'') return result;
		if (fast_isspace(ch) || ch == ';') break;
	}
	const std::string_view keyword = input.substr(word_begin, word_end - word_begin);

	// Classify the keyword; single-word commands return right away
	TxnCmd::Type pending = TxnCmd::UNKNOWN;
	if (!in_transaction_mode) {
		if (iequals(keyword, "commit") || iequals(keyword, "end")) {
			result.type = TxnCmd::COMMIT;
			return result;
		}
		if (iequals(keyword, "abort")) {
			result.type = TxnCmd::ROLLBACK;
			return result;
		}
		if (iequals(keyword, "rollback")) {
			pending = TxnCmd::ROLLBACK;
		}
	} else {
		if (iequals(keyword, "begin")) {
			result.type = TxnCmd::BEGIN;
			return result;
		}
		if (iequals(keyword, "start")) {
			pending = TxnCmd::BEGIN;
		} else if (iequals(keyword, "savepoint")) {
			pending = TxnCmd::SAVEPOINT;
		} else if (iequals(keyword, "release")) {
			pending = TxnCmd::RELEASE;
		} else if (iequals(keyword, "rollback")) {
			pending = TxnCmd::ROLLBACK;
		}
	}

	// Not a transaction command: leave the token buffer untouched
	if (pending == TxnCmd::UNKNOWN) {
		return result;
	}

	// Tokenize everything after the first word
	tokens.clear();
	size_t tok_begin = word_end;
	while (tok_begin < len && fast_isspace(input[tok_begin])) {
		++tok_begin;
	}

	bool quoted = false;
	char quote = 0;
	// Runs one step past the end so the final token is flushed as well
	for (size_t i = tok_begin; i <= len; ++i) {
		const bool at_end = (i == len);
		const char ch = at_end ? 0 : input[i];

		if (quoted) {
			if (ch == quote || at_end) {
				// Emit the text between the quotes (unterminated quotes run to the end)
				tokens.emplace_back(input.substr(tok_begin + 1, i - tok_begin - 1));
				quoted = false;
				tok_begin = i + 1;
			}
			continue;
		}

		if (ch == '"' || ch == '\'') {
			quoted = true;
			quote = ch;
			tok_begin = i;
			continue;
		}

		if (fast_isspace(ch) || ch == ';' || at_end) {
			if (tok_begin < i) {
				tokens.emplace_back(input.substr(tok_begin, i - tok_begin));
			}
			tok_begin = i + 1;
		}
	}

	// Hand the tokens to the keyword-specific parser
	size_t cursor = 0;
	if (!in_transaction_mode) {
		if (pending == TxnCmd::ROLLBACK) {
			result = parse_rollback(cursor);
		}
		return result;
	}

	switch (pending) {
	case TxnCmd::BEGIN:
		result = parse_start(cursor);
		break;
	case TxnCmd::SAVEPOINT:
		result = parse_savepoint(cursor);
		break;
	case TxnCmd::RELEASE:
		result = parse_release(cursor);
		break;
	case TxnCmd::ROLLBACK:
		result = parse_rollback(cursor);
		break;
	default:
		break;
	}
	return result;
}
/**
 * @brief Parses the tokens following a ROLLBACK keyword.
 *
 * Optional WORK / TRANSACTION tokens are skipped. "TO [SAVEPOINT] <name>"
 * yields ROLLBACK_TO with the savepoint name (left empty if the name is
 * missing), "AND CHAIN" yields ROLLBACK_AND_CHAIN, anything else is a plain
 * ROLLBACK.
 *
 * @param pos In/out index into tokens; advanced past every consumed token.
 * @return The parsed rollback command.
 */
TxnCmd PgSQL_TxnCmdParser::parse_rollback(size_t& pos) noexcept {
	TxnCmd result{ TxnCmd::ROLLBACK };
	const size_t count = tokens.size();

	while (pos < count && contains({ "work", "transaction" }, tokens[pos])) {
		++pos;
	}
	if (pos >= count) return result;

	if (iequals(tokens[pos], "to")) {
		result.type = TxnCmd::ROLLBACK_TO;
		++pos;
		// The SAVEPOINT keyword is optional
		if (pos < count && iequals(tokens[pos], "savepoint")) {
			++pos;
		}
		if (pos < count) {
			result.savepoint = tokens[pos];
			++pos;
		}
		return result;
	}

	if (iequals(tokens[pos], "and")) {
		++pos;
		if (pos < count && iequals(tokens[pos], "chain")) {
			result.type = TxnCmd::ROLLBACK_AND_CHAIN;
			++pos;
		}
	}
	return result;
}
/**
 * @brief Parses the tokens following a SAVEPOINT keyword.
 *
 * @param pos In/out index into tokens; advanced past the name when present.
 * @return A SAVEPOINT command whose name is the token at pos, or empty when
 *         there is no such token.
 */
TxnCmd PgSQL_TxnCmdParser::parse_savepoint(size_t& pos) noexcept {
	TxnCmd result{ TxnCmd::SAVEPOINT };
	if (pos >= tokens.size()) {
		return result;
	}
	result.savepoint = tokens[pos];
	++pos;
	return result;
}
/**
 * @brief Parses the tokens following a RELEASE keyword.
 *
 * Accepts an optional SAVEPOINT keyword before the name.
 *
 * @param pos In/out index into tokens; advanced past every consumed token.
 * @return A RELEASE command whose name is the next token, or empty when
 *         there is none.
 */
TxnCmd PgSQL_TxnCmdParser::parse_release(size_t& pos) noexcept {
	TxnCmd result{ TxnCmd::RELEASE };
	const size_t count = tokens.size();

	if (pos < count && iequals(tokens[pos], "savepoint")) {
		++pos;
	}
	if (pos < count) {
		result.savepoint = tokens[pos];
		++pos;
	}
	return result;
}
/**
 * @brief Parses the tokens following a START keyword.
 *
 * START only counts as a transaction begin when it is followed by TRANSACTION.
 *
 * @param pos In/out index into tokens; advanced past TRANSACTION when matched.
 * @return A BEGIN command for "START TRANSACTION", otherwise UNKNOWN.
 */
TxnCmd PgSQL_TxnCmdParser::parse_start(size_t& pos) noexcept {
	const bool is_start_transaction = pos < tokens.size() && iequals(tokens[pos], "transaction");
	if (!is_start_transaction) {
		return TxnCmd{ TxnCmd::UNKNOWN };
	}
	++pos;
	return TxnCmd{ TxnCmd::BEGIN };
}