// Mirror of https://github.com/sysown/proxysql
// (Repository-page residue: "You cannot select more than 25 topics. Topics must start with a
// letter or number, can include dashes ('-') and can be up to 35 characters long."
// File size: 175 lines, 7.0 KiB.)
#include "PgSQL_Backend_Protocol.h"
|
|
#include <cstring>
|
|
#include <cstdlib>
|
|
#include <cstdio>
|
|
#include <new>
|
|
#include <string>
|
|
#include <openssl/md5.h> // project-existing one-shot MD5(); also pulls MD5_DIGEST_LENGTH
|
|
#include "scram.h" // vendored libscram (same include used by PgSQL_Data_Stream.h)
|
|
|
|
// Stores v at p[0..3] in big-endian (network) byte order, most significant byte first.
static void put_be32(unsigned char* p, uint32_t v) {
	for (int idx = 0, shift = 24; shift >= 0; ++idx, shift -= 8) {
		p[idx] = static_cast<unsigned char>((v >> shift) & 0xffu);
	}
}
|
|
|
|
// Writes the 8-byte SSLRequest packet into out. It consists of two big-endian 32-bit words:
// the packet length (8) and the request code 80877103 (0x04d2162f).
void pg_build_ssl_request(unsigned char out[8]) {
	static const unsigned char kSslRequest[8] = {
		0x00, 0x00, 0x00, 0x08, // length = 8
		0x04, 0xd2, 0x16, 0x2f  // request code 80877103
	};
	memcpy(out, kSslRequest, sizeof(kSslRequest));
}
|
|
|
|
// Builds a protocol-3.0 StartupMessage carrying the "user" and "database" parameters.
//
// Layout: length(4, big-endian, counts itself) + protocol(4) + "user\0" + user\0
//         + "database\0" + database\0 + terminating \0.
//
// out      destination buffer with room for out_cap bytes
// out_len  receives the number of bytes written; set to 0 on any failure
// returns  true on success. Returns false, without writing to out, if any pointer argument
//          is null or the message would not fit in out_cap.
bool pg_build_startup(unsigned char* out, size_t* out_len, size_t out_cap,
		const char* user, const char* database) {
	if (out_len == nullptr) return false;
	*out_len = 0;
	// Null strings would be undefined behaviour for strlen/memcpy; reject them explicitly.
	if (out == nullptr || user == nullptr || database == nullptr) return false;

	static const char kUserKey[] = "user";         // sizeof includes the NUL
	static const char kDatabaseKey[] = "database"; // sizeof includes the NUL
	const size_t user_len = strlen(user) + 1;      // including NUL
	const size_t db_len = strlen(database) + 1;    // including NUL

	// Compute the full size first so the bound check rejects before any write,
	// guaranteeing no partial/oversized output is left in the caller buffer.
	const size_t need = 8 + sizeof(kUserKey) + user_len + sizeof(kDatabaseKey) + db_len + 1;
	if (need > out_cap) return false;

	size_t off = 8; // reserve length(4) + protocol(4)
	auto add = [&](const char* s, size_t l) { memcpy(out + off, s, l); off += l; };
	add(kUserKey, sizeof(kUserKey));
	add(user, user_len);
	add(kDatabaseKey, sizeof(kDatabaseKey));
	add(database, db_len);
	out[off++] = 0; // terminating empty key

	put_be32(out, static_cast<uint32_t>(off)); // total length (includes the length field itself)
	put_be32(out + 4, 196608u);                // protocol 3.0 = 0x00030000
	*out_len = off;
	return true;
}
|
|
|
|
// Lowercase-hex of an MD5 digest over [in, in+inlen). out_hex receives 32 hex
|
|
// chars plus a terminating NUL (33 bytes total).
|
|
static void md5_hex(const unsigned char* in, size_t inlen, char out_hex[33]) {
|
|
unsigned char digest[MD5_DIGEST_LENGTH];
|
|
MD5(in, inlen, digest);
|
|
static const char hexd[] = "0123456789abcdef";
|
|
for (int i = 0; i < MD5_DIGEST_LENGTH; i++) {
|
|
out_hex[i * 2] = hexd[(digest[i] >> 4) & 0xf];
|
|
out_hex[i * 2 + 1] = hexd[digest[i] & 0xf];
|
|
}
|
|
out_hex[MD5_DIGEST_LENGTH * 2] = '\0';
|
|
}
|
|
|
|
void pg_build_md5(char out[36], const char* user, const char* password, const unsigned char salt[4]) {
|
|
// inner = hex(md5(password + user)). Hash the concatenation without an
|
|
// intermediate NUL-terminated copy by passing each part length explicitly.
|
|
size_t plen = strlen(password);
|
|
size_t ulen = strlen(user);
|
|
{
|
|
unsigned char digest[MD5_DIGEST_LENGTH];
|
|
MD5_CTX ctx;
|
|
MD5_Init(&ctx);
|
|
MD5_Update(&ctx, password, plen);
|
|
MD5_Update(&ctx, user, ulen);
|
|
MD5_Final(digest, &ctx);
|
|
static const char hexd[] = "0123456789abcdef";
|
|
char inner_hex[33];
|
|
for (int i = 0; i < MD5_DIGEST_LENGTH; i++) {
|
|
inner_hex[i * 2] = hexd[(digest[i] >> 4) & 0xf];
|
|
inner_hex[i * 2 + 1] = hexd[digest[i] & 0xf];
|
|
}
|
|
// outer input = 32 inner hex chars + 4 raw salt bytes (NOT NUL-terminated).
|
|
unsigned char outer_in[MD5_DIGEST_LENGTH * 2 + 4];
|
|
memcpy(outer_in, inner_hex, MD5_DIGEST_LENGTH * 2);
|
|
memcpy(outer_in + MD5_DIGEST_LENGTH * 2, salt, 4);
|
|
|
|
char outer_hex[33];
|
|
md5_hex(outer_in, sizeof(outer_in), outer_hex);
|
|
|
|
memcpy(out, "md5", 3);
|
|
memcpy(out + 3, outer_hex, 33); // 32 hex chars + NUL -> out[3..35]
|
|
}
|
|
}
|
|
|
|
// --- SCRAM-SHA-256 client exchange (thin wrappers over libscram) ---
|
|
|
|
// Owns the libscram ScramState plus a cached PgCredentials and the message strings
|
|
// libscram hands back as malloc'd C-strings. Holding the latest message of each kind
|
|
// keeps the returned pointers valid for the caller (the libscram functions otherwise
|
|
// leak the strings to their caller) and lets the destructor free them.
|
|
struct PgSQL_Scram_State {
|
|
ScramState* st = nullptr;
|
|
PgCredentials creds{}; // value-initialized -> all fields zeroed, has_scram_keys=false
|
|
char* client_first = nullptr;
|
|
char* client_final = nullptr;
|
|
};
|
|
|
|
// Allocates a new SCRAM client state with a freshly initialized libscram ScramState.
// Returns nullptr if either the wrapper allocation or scram_state_init() fails. On the
// second failure the wrapper is deleted again, so nothing is left allocated.
PgSQL_Scram_State* pg_scram_new() {
	PgSQL_Scram_State* state = new (std::nothrow) PgSQL_Scram_State();
	if (state != nullptr) {
		state->st = scram_state_init();
		if (state->st != nullptr) {
			return state;
		}
		delete state;
	}
	return nullptr;
}
|
|
|
|
// Releases the libscram state and both cached message strings, then the wrapper
// itself. Passing nullptr is a no-op.
void pg_scram_free(PgSQL_Scram_State* s) {
	if (s != nullptr) {
		if (s->st != nullptr) {
			free_scram_state(s->st); // frees ScramState's owned buffers + the struct
		}
		free(s->client_first); // free(nullptr) is a no-op
		free(s->client_final);
		delete s;
	}
}
|
|
|
|
// Builds the SCRAM-SHA-256 client-first-message using gs2 flag 'n' (no channel binding).
//
// Returns nullptr if s or its ScramState is missing, if channel binding is requested,
// or if libscram fails to build the message. Otherwise returns a string owned by s.
// That string stays valid until a later successful call replaces it, or until
// pg_scram_free(s).
const char* pg_scram_client_first(PgSQL_Scram_State* s, bool channel_binding) {
	// Channel binding ('p'/'y' gs2 flag) is a separate task, so only the plain
	// "n,," header is supported here.
	const bool usable = (s != nullptr) && (s->st != nullptr) && !channel_binding;
	if (!usable) return nullptr;

	scram_reset_error();
	// libscram emits "n,,n=,r=<nonce>". It also stashes client_nonce and
	// client_first_message_bare ("n=,r=<nonce>") in the ScramState for the later
	// proof computation.
	char* fresh = build_client_first_message(s->st);
	if (fresh == nullptr) return nullptr;

	// Replace the previously cached message (if any) with the new one.
	free(s->client_first);
	s->client_first = fresh;
	return fresh;
}
|
|
|
|
// Parses the server-first-message and builds the SCRAM-SHA-256 client-final-message
// (with proof) from the plaintext password.
//
// Returns nullptr if:
//  - any argument or s->st is null,
//  - the password does not fit in the credentials buffer,
//  - the server message fails to parse, or
//  - libscram fails to build the reply.
// Otherwise returns a string owned by s. That string stays valid until a later
// successful call replaces it, or until pg_scram_free(s).
const char* pg_scram_client_final(PgSQL_Scram_State* s, const char* password,
		const char* server_first, size_t server_first_len) {
	if (s == nullptr || s->st == nullptr || password == nullptr || server_first == nullptr)
		return nullptr;
	scram_reset_error();

	// Reject a password that cannot be stored whole (plus its NUL) in creds.passwd.
	// Truncating it, as an unchecked snprintf would, gives a wrong ClientProof and an
	// authentication failure that is hard to diagnose. The check runs before parsing,
	// so no ScramState mutation happens for a request that is going to fail anyway.
	const size_t password_len = strlen(password);
	if (password_len >= sizeof(s->creds.passwd))
		return nullptr;

	// read_server_first_message() mutates its input (read_attr_value writes NULs and
	// advances), so feed it a private, NUL-terminated, mutable copy.
	std::string sf(server_first, server_first_len);

	char* server_nonce = nullptr;
	char* salt = nullptr;
	int saltlen = 0;
	int iterations = 0;
	if (!read_server_first_message(s->st, &sf[0], &server_nonce, &salt, &saltlen, &iterations)) {
		free(salt); // nullptr unless the parser handed back a buffer
		return nullptr;
	}

	// The password is the SCRAM plaintext secret; libscram derives the keys ad hoc.
	// has_scram_keys stays false (value-initialized), so the plaintext path is used.
	memcpy(s->creds.passwd, password, password_len + 1);

	char* msg = build_client_final_message(s->st, &s->creds, server_nonce, salt, saltlen, iterations);
	// Ownership differs between the two parsed values:
	//  - server_nonce points into the local `sf` copy, so it is not freed here;
	//  - salt was malloc'd by read_server_first_message, so it must be freed.
	free(salt);
	if (msg == nullptr) return nullptr;

	// Replace the previously cached message (if any) with the new one.
	free(s->client_final);
	s->client_final = msg;
	return s->client_final;
}
|
|
|
|
// Parses the server-final-message and checks its ServerSignature against the value
// libscram derives from this exchange's state and cached credentials.
// Returns true only when the message parses and the signature verifies.
bool pg_scram_verify_server_final(PgSQL_Scram_State* s, const char* server_final, size_t len) {
	if (s == nullptr || s->st == nullptr || server_final == nullptr) return false;
	scram_reset_error();

	// read_server_final_message() writes into its input, so parse a private
	// NUL-terminated copy instead of the caller's buffer.
	std::string scratch(server_final, len);
	char server_signature[32]; // SCRAM_KEY_LEN
	const bool parsed = read_server_final_message(&scratch[0], server_signature);
	return parsed && verify_server_signature(s->st, &s->creds, server_signature);
}
|