diff --git a/include/PgSQL_ExplicitTxnStateMgr.h b/include/PgSQL_ExplicitTxnStateMgr.h index ea50cb28c..3fbf66da5 100644 --- a/include/PgSQL_ExplicitTxnStateMgr.h +++ b/include/PgSQL_ExplicitTxnStateMgr.h @@ -41,7 +41,8 @@ struct TxnCmd { ROLLBACK, SAVEPOINT, RELEASE, - ROLLBACK_TO + ROLLBACK_TO, + ROLLBACK_AND_CHAIN } type = Type::UNKNOWN; std::string savepoint; //< The name of the savepoint, if applicable. }; @@ -63,6 +64,7 @@ private: TxnCmd parse_rollback(size_t& pos) noexcept; TxnCmd parse_savepoint(size_t& pos) noexcept; TxnCmd parse_release(size_t& pos) noexcept; + TxnCmd parse_start(size_t& pos) noexcept; // Helpers static std::string to_lower(std::string_view s) noexcept { @@ -101,7 +103,7 @@ private: void start_transaction(); void commit(); - void rollback(); + void rollback(bool rollback_and_chain); bool add_savepoint(std::string_view name); bool rollback_to_savepoint(std::string_view name); bool release_savepoint(std::string_view name); diff --git a/lib/PgSQL_ExplicitTxnStateMgr.cpp b/lib/PgSQL_ExplicitTxnStateMgr.cpp index 268a11101..3fd4d1bf0 100644 --- a/lib/PgSQL_ExplicitTxnStateMgr.cpp +++ b/lib/PgSQL_ExplicitTxnStateMgr.cpp @@ -86,7 +86,7 @@ void PgSQL_ExplicitTxnStateMgr::commit() { verify_server_variables(session); } -void PgSQL_ExplicitTxnStateMgr::rollback() { +void PgSQL_ExplicitTxnStateMgr::rollback(bool rollback_and_chain) { if (transaction_state.empty()) { proxy_warning("Received ROLLBACK command. There is no transaction in progress\n"); @@ -124,11 +124,15 @@ void PgSQL_ExplicitTxnStateMgr::rollback() { verify_server_variables(session); } - // Clear savepoints and reset the initial snapshot - for (auto& tran_state : transaction_state) { - reset_variable_snapshot(tran_state); + // Keep the transaction state intact when executing ROLLBACK AND CHAIN + if (rollback_and_chain == false) { + // Clear savepoints and reset the initial snapshot + for (auto& tran_state : transaction_state) { + reset_variable_snapshot(tran_state); + } + transaction_state.clear(); } - transaction_state.clear(); + savepoint.clear(); } @@ -296,7 +300,10 @@ bool PgSQL_ExplicitTxnStateMgr::handle_transaction(std::string_view input) { commit(); break; case TxnCmd::ROLLBACK: - rollback(); + rollback(false); + break; + case TxnCmd::ROLLBACK_AND_CHAIN: + rollback(true); break; case TxnCmd::SAVEPOINT: return add_savepoint(cmd.savepoint); @@ -350,12 +357,14 @@ TxnCmd PgSQL_TxnCmdParser::parse(std::string_view input, bool in_transaction_mod if (in_transaction_mode == true) { if (first == "begin") cmd.type = TxnCmd::BEGIN; + else if (first == "start") cmd = parse_start(pos); else if (first == "savepoint") cmd = parse_savepoint(pos); else if (first == "release") cmd = parse_release(pos); else if (first == "rollback") cmd = parse_rollback(pos); } else { - if (first == "commit") cmd.type = TxnCmd::COMMIT; - else if (first == "rollback" || (first == "abort")) cmd = parse_rollback(pos); + if (first == "commit" || first == "end") cmd.type = TxnCmd::COMMIT; + else if (first == "abort") cmd.type = TxnCmd::ROLLBACK; + else if (first == "rollback") cmd = parse_rollback(pos); } return cmd; } @@ -368,6 +377,11 @@ TxnCmd PgSQL_TxnCmdParser::parse_rollback(size_t& pos) noexcept { cmd.type = TxnCmd::ROLLBACK_TO; if (++pos < tokens.size() && to_lower(tokens[pos]) == "savepoint") pos++; if (pos < tokens.size()) cmd.savepoint = tokens[pos++]; + } else if (pos < tokens.size() && to_lower(tokens[pos]) == "and") { + if (++pos < tokens.size() && to_lower(tokens[pos]) == "chain") { + cmd.type = TxnCmd::ROLLBACK_AND_CHAIN; + pos++; + } } return cmd; } @@ -384,3 +398,12 @@ TxnCmd PgSQL_TxnCmdParser::parse_release(size_t& pos) noexcept { if (pos < tokens.size()) cmd.savepoint = tokens[pos++]; return cmd; } + +TxnCmd PgSQL_TxnCmdParser::parse_start(size_t& pos) noexcept { + TxnCmd cmd{ TxnCmd::UNKNOWN }; + if (pos < tokens.size() && to_lower(tokens[pos]) == "transaction") { + cmd.type = TxnCmd::BEGIN; + pos++; + } + return cmd; +} diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index 4392a193a..078c79f1d 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -517,22 +517,17 @@ bool PgSQL_Session::handler_CommitRollback(PtrSize_t* pkt) { if (pkt->size <= 5) { return false; } char c = ((char*)pkt->ptr)[5]; bool ret = false; - if (c == 'c' || c == 'C') { - if (pkt->size >= sizeof("commit") + 5) { - if (strncasecmp((char*)"commit", (char*)pkt->ptr + 5, 6) == 0) { - __sync_fetch_and_add(&PgHGM->status.commit_cnt, 1); - ret = true; - } + if (c == 'c' || c == 'C' || c == 'e' || c == 'E') { + if ((pkt->size >= 5 + 6 && strncasecmp("commit", (char*)pkt->ptr + 5, 6) == 0) || + (pkt->size >= 5 + 3 && strncasecmp("end", (char*)pkt->ptr + 5, 3) == 0)) { + __sync_fetch_and_add(&PgHGM->status.commit_cnt, 1); + ret = true; } - } - else { - if (c == 'r' || c == 'R') { - if (pkt->size >= sizeof("rollback") + 5) { - if (strncasecmp((char*)"rollback", (char*)pkt->ptr + 5, 8) == 0) { - __sync_fetch_and_add(&PgHGM->status.rollback_cnt, 1); - ret = true; - } - } + } else if (c == 'r' || c == 'R' || c == 'a' || c == 'A') { + if ((pkt->size >= 5 + 8 && strncasecmp("rollback", (char*)pkt->ptr + 5, 8) == 0) || + (pkt->size >= 5 + 5 && strncasecmp("abort", (char*)pkt->ptr + 5, 5) == 0)) { + __sync_fetch_and_add(&PgHGM->status.rollback_cnt, 1); + ret = true; } } @@ -568,7 +563,7 @@ bool PgSQL_Session::handler_CommitRollback(PtrSize_t* pkt) { status = WAITING_CLIENT_DATA; } l_free(pkt->size, pkt->ptr); - if (c == 'c' || c == 'C') { + if (c == 'c' || c == 'C' || c == 'e' || c == 'E') { __sync_fetch_and_add(&PgHGM->status.commit_cnt_filtered, 1); } else { __sync_fetch_and_add(&PgHGM->status.rollback_cnt_filtered, 1);