mirror of https://github.com/sysown/proxysql
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
248 lines
7.3 KiB
248 lines
7.3 KiB
#include "mysqlx_protocol.h"

#include <unistd.h>

#include <cerrno>
#include <cstring>
#include <memory>

#include <openssl/crypto.h>
#include <openssl/evp.h>

#include "mysqlx.pb.h"
#include "mysqlx_session.pb.h"
|
|
|
|
namespace {
|
|
|
|
// SHA1 helper using EVP API (OpenSSL 3.0+).
|
|
bool sha1_digest(const uint8_t* data, size_t len, uint8_t out[20]) {
|
|
EVP_MD_CTX* ctx = EVP_MD_CTX_new();
|
|
if (!ctx) return false;
|
|
unsigned int out_len = 0;
|
|
bool ok = EVP_DigestInit_ex(ctx, EVP_sha1(), nullptr) == 1
|
|
&& EVP_DigestUpdate(ctx, data, len) == 1
|
|
&& EVP_DigestFinal_ex(ctx, out, &out_len) == 1
|
|
&& out_len == 20;
|
|
EVP_MD_CTX_free(ctx);
|
|
return ok;
|
|
}
|
|
|
|
bool sha1_digest_multi(const uint8_t* d1, size_t l1,
|
|
const uint8_t* d2, size_t l2,
|
|
uint8_t out[20]) {
|
|
EVP_MD_CTX* ctx = EVP_MD_CTX_new();
|
|
if (!ctx) return false;
|
|
unsigned int out_len = 0;
|
|
bool ok = EVP_DigestInit_ex(ctx, EVP_sha1(), nullptr) == 1
|
|
&& EVP_DigestUpdate(ctx, d1, l1) == 1
|
|
&& EVP_DigestUpdate(ctx, d2, l2) == 1
|
|
&& EVP_DigestFinal_ex(ctx, out, &out_len) == 1
|
|
&& out_len == 20;
|
|
EVP_MD_CTX_free(ctx);
|
|
return ok;
|
|
}
|
|
|
|
static constexpr size_t SHA1_LEN = 20;
|
|
|
|
} // namespace
|
|
|
|
std::vector<uint8_t> mysqlx_encode_frame_header(const MysqlxFrameHeader& hdr) {
|
|
std::vector<uint8_t> buf(MYSQLX_FRAME_HEADER_SIZE);
|
|
uint32_t ps = hdr.payload_size;
|
|
buf[0] = static_cast<uint8_t>(ps & 0xFF);
|
|
buf[1] = static_cast<uint8_t>((ps >> 8) & 0xFF);
|
|
buf[2] = static_cast<uint8_t>((ps >> 16) & 0xFF);
|
|
buf[3] = static_cast<uint8_t>((ps >> 24) & 0xFF);
|
|
buf[4] = hdr.message_type;
|
|
return buf;
|
|
}
|
|
|
|
std::optional<MysqlxFrameHeader> mysqlx_decode_frame_header(const uint8_t* data, size_t len) {
|
|
if (len < MYSQLX_FRAME_HEADER_SIZE) {
|
|
return std::nullopt;
|
|
}
|
|
|
|
MysqlxFrameHeader hdr {};
|
|
hdr.payload_size = static_cast<uint32_t>(data[0])
|
|
| (static_cast<uint32_t>(data[1]) << 8)
|
|
| (static_cast<uint32_t>(data[2]) << 16)
|
|
| (static_cast<uint32_t>(data[3]) << 24);
|
|
hdr.message_type = data[4];
|
|
return hdr;
|
|
}
|
|
|
|
// Returns true if `method` names an authentication mechanism handled here.
// The comparison is exact and case-sensitive.
bool mysqlx_is_supported_auth_method(const std::string& method) {
    static const char* const kSupported[] = { "MYSQL41", "PLAIN" };
    for (const char* name : kSupported) {
        if (method == name) {
            return true;
        }
    }
    return false;
}
|
|
|
|
std::vector<uint8_t> mysqlx_build_frame(uint8_t message_type, const std::string& serialized_payload) {
|
|
// payload_size in the header includes the message_type byte.
|
|
uint32_t payload_size = static_cast<uint32_t>(serialized_payload.size()) + 1;
|
|
MysqlxFrameHeader hdr { payload_size, message_type };
|
|
std::vector<uint8_t> frame = mysqlx_encode_frame_header(hdr);
|
|
frame.insert(frame.end(), serialized_payload.begin(), serialized_payload.end());
|
|
return frame;
|
|
}
|
|
|
|
// Reads exactly `len` bytes from `fd` into `buf`, looping over short reads.
// A read interrupted by a signal (EINTR) is retried.
// Returns false on EOF before `len` bytes arrive, or on any other read error
// (on a non-blocking fd this includes EAGAIN/EWOULDBLOCK). In that case `buf`
// may hold a partial prefix. A zero `len` succeeds without reading.
bool mysqlx_read_exact(int fd, uint8_t* buf, size_t len) {
    size_t total = 0;
    while (total < len) {
        const ssize_t n = read(fd, buf + total, len - total);
        if (n > 0) {
            total += static_cast<size_t>(n);
            continue;
        }
        if (n < 0 && errno == EINTR) {
            continue;  // interrupted before any data was transferred; retry
        }
        return false;  // n == 0: peer closed; n < 0: real I/O error
    }
    return true;
}
|
|
|
|
bool mysqlx_read_frame(int fd, MysqlxFrameHeader& header, std::vector<uint8_t>& payload) {
|
|
uint8_t hdr_buf[MYSQLX_FRAME_HEADER_SIZE];
|
|
if (!mysqlx_read_exact(fd, hdr_buf, MYSQLX_FRAME_HEADER_SIZE)) {
|
|
return false;
|
|
}
|
|
|
|
auto opt = mysqlx_decode_frame_header(hdr_buf, MYSQLX_FRAME_HEADER_SIZE);
|
|
if (!opt.has_value()) {
|
|
return false;
|
|
}
|
|
|
|
header = opt.value();
|
|
|
|
if (header.payload_size > MYSQLX_MAX_PAYLOAD_SIZE) {
|
|
return false;
|
|
}
|
|
|
|
// payload_size includes the 1-byte message_type already consumed in header.
|
|
uint32_t body_size = header.payload_size > 0 ? header.payload_size - 1 : 0;
|
|
payload.resize(body_size);
|
|
if (body_size > 0) {
|
|
if (!mysqlx_read_exact(fd, payload.data(), body_size)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Writes all `len` bytes of `data` to `fd`, looping over short writes.
// A write interrupted by a signal (EINTR) is retried.
// Returns false on any other write error or if write() reports 0 bytes
// written. A zero `len` succeeds without writing.
bool mysqlx_write_all(int fd, const uint8_t* data, size_t len) {
    size_t total = 0;
    while (total < len) {
        const ssize_t n = write(fd, data + total, len - total);
        if (n > 0) {
            total += static_cast<size_t>(n);
            continue;
        }
        if (n < 0 && errno == EINTR) {
            continue;  // interrupted before any data was transferred; retry
        }
        return false;
    }
    return true;
}
|
|
|
|
bool mysqlx_send_error(int fd, uint16_t code, const std::string& msg, const std::string& sql_state) {
|
|
Mysqlx::Error error_msg;
|
|
error_msg.set_severity(Mysqlx::Error::ERROR);
|
|
error_msg.set_code(code);
|
|
error_msg.set_sql_state(sql_state);
|
|
error_msg.set_msg(msg);
|
|
|
|
std::string serialized;
|
|
error_msg.SerializeToString(&serialized);
|
|
|
|
auto frame = mysqlx_build_frame(
|
|
Mysqlx::ServerMessages_Type_ERROR,
|
|
serialized
|
|
);
|
|
|
|
return mysqlx_write_all(fd, frame.data(), frame.size());
|
|
}
|
|
|
|
bool mysqlx_send_ok(int fd, const std::string& msg) {
|
|
Mysqlx::Ok ok_msg;
|
|
if (!msg.empty()) {
|
|
ok_msg.set_msg(msg);
|
|
}
|
|
|
|
std::string serialized;
|
|
ok_msg.SerializeToString(&serialized);
|
|
|
|
auto frame = mysqlx_build_frame(
|
|
Mysqlx::ServerMessages_Type_OK,
|
|
serialized
|
|
);
|
|
|
|
return mysqlx_write_all(fd, frame.data(), frame.size());
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// MYSQL41 auth helpers
|
|
// ---------------------------------------------------------------------------
|
|
// MYSQL41 uses double-SHA1:
//   hash_stage1 = SHA1(password)
//   hash_stage2 = SHA1(hash_stage1)
//   scramble    = hash_stage1 XOR SHA1(challenge || hash_stage2)
// where "||" is byte concatenation and XOR is applied byte-by-byte over the
// 20-byte digests.
|
|
|
|
// Encodes `data` as upper-case hexadecimal, two characters per byte with the
// high nibble first. An empty input yields an empty string.
std::string mysqlx_hex_encode(const std::vector<uint8_t>& data) {
    static const char kDigits[] = "0123456789ABCDEF";
    std::string hex(data.size() * 2, '\0');
    for (size_t i = 0; i < data.size(); ++i) {
        hex[2 * i]     = kDigits[data[i] >> 4];
        hex[2 * i + 1] = kDigits[data[i] & 0x0F];
    }
    return hex;
}
|
|
|
|
// Decodes a hexadecimal string (upper- or lower-case digits) into bytes.
// Returns false if `hex` has odd length or contains a non-hex character.
// On failure `out` is left unmodified; on success it holds exactly the
// decoded bytes (an empty string decodes to an empty vector).
bool mysqlx_hex_decode(const std::string& hex, std::vector<uint8_t>& out) {
    // Value of one hex digit, or -1 if `c` is not a hex digit.
    auto nibble = [](char c) -> int {
        if (c >= '0' && c <= '9') return c - '0';
        if (c >= 'A' && c <= 'F') return c - 'A' + 10;
        if (c >= 'a' && c <= 'f') return c - 'a' + 10;
        return -1;
    };

    if (hex.size() % 2 != 0) {
        return false;
    }

    // Decode into a scratch buffer so a bad digit cannot leave `out` half-filled.
    std::vector<uint8_t> decoded;
    decoded.reserve(hex.size() / 2);
    for (size_t i = 0; i < hex.size(); i += 2) {
        const int hi = nibble(hex[i]);
        const int lo = nibble(hex[i + 1]);
        if (hi < 0 || lo < 0) {
            return false;
        }
        decoded.push_back(static_cast<uint8_t>((hi << 4) | lo));
    }
    out.swap(decoded);
    return true;
}
|
|
|
|
std::vector<uint8_t> mysqlx_mysql41_hash(const std::string& password) {
|
|
uint8_t stage1[SHA1_LEN];
|
|
sha1_digest(reinterpret_cast<const uint8_t*>(password.data()), password.size(), stage1);
|
|
|
|
std::vector<uint8_t> stage2(SHA1_LEN);
|
|
sha1_digest(stage1, SHA1_LEN, stage2.data());
|
|
return stage2;
|
|
}
|
|
|
|
// Computes the MYSQL41 scramble for `password` against `challenge`:
//   stage1 = SHA1(password), stage2 = SHA1(stage1),
//   result = stage1 XOR SHA1(challenge || stage2)   (20 bytes).
// Returns an empty vector if any SHA1 computation fails.
std::vector<uint8_t> mysqlx_mysql41_scramble(const std::vector<uint8_t>& challenge,
                                             const std::string& password) {
    uint8_t stage1[SHA1_LEN];
    uint8_t stage2[SHA1_LEN];
    uint8_t combined[SHA1_LEN];
    const bool ok =
        sha1_digest(reinterpret_cast<const uint8_t*>(password.data()), password.size(), stage1)
        && sha1_digest(stage1, SHA1_LEN, stage2)
        && sha1_digest_multi(challenge.data(), challenge.size(), stage2, SHA1_LEN, combined);
    if (!ok) {
        // Never derive a scramble from buffers a failed digest left uninitialized.
        return {};
    }

    std::vector<uint8_t> result(SHA1_LEN);
    for (size_t i = 0; i < SHA1_LEN; ++i) {
        result[i] = stage1[i] ^ combined[i];
    }
    return result;
}
|
|
|
|
bool mysqlx_mysql41_verify(const std::vector<uint8_t>& challenge,
|
|
const std::vector<uint8_t>& client_response,
|
|
const std::string& password) {
|
|
if (client_response.size() != SHA1_LEN) {
|
|
return false;
|
|
}
|
|
|
|
auto expected = mysqlx_mysql41_scramble(challenge, password);
|
|
return expected == client_response;
|
|
}
|