Implementation of PROXY protocol V1

This commit introduces:
- class ProxyProtocolInfo() , that:
  - performs parsing
  - validates subnet
  - run automated tests in DEBUG build
- variable mysql-proxy_protocol_networks . Accepted values:
  - empty string: disables PROXY protocol
  - '*' : allows connections from any IP
  - comma separated list of subnets
- automated testing in DEBUG build during start
- export of PROXY protocol information in internal session, using PROXY_V1
- a TAP test to verify various connections
v2.x_proxy
Rene Cannao 2 years ago
parent 27e71d2972
commit d85430b709

@ -5,6 +5,7 @@
#include "cpp.h"
#include "MySQL_Protocol.h"
#include "proxy_protocol_info.h"
#ifndef uchar
typedef unsigned char uchar;
@ -140,6 +141,7 @@ class MySQL_Data_Stream
char *addr;
int port;
} proxy_addr;
ProxyProtocolInfo * PROXY_info;
unsigned int connect_tries;
int query_retries_on_failure;

@ -460,6 +460,7 @@ class MySQL_Threads_Handler
char *server_version;
char *keep_multiplexing_variables;
char *default_authentication_plugin;
char *proxy_protocol_networks;
//unsigned int default_charset; // removed in 2.0.13 . Obsoleted previously using MySQL_Variables instead
int handle_unknown_charset;
int default_authentication_plugin_int;

@ -0,0 +1,51 @@
#ifndef PROXY_PROTOCOL_INFO_H
#define PROXY_PROTOCOL_INFO_H
#include <string.h>
#include <netinet/in.h>
#include <string>
#include <arpa/inet.h>
class ProxyProtocolInfo {
public:
char source_address[INET6_ADDRSTRLEN+1];
char destination_address[INET6_ADDRSTRLEN+1];
char proxy_address[INET6_ADDRSTRLEN+1];
uint16_t source_port;
uint16_t destination_port;
uint16_t proxy_port;
// Constructor (initializes to zeros)
ProxyProtocolInfo() {
memset(this, 0, sizeof(ProxyProtocolInfo));
}
// Copy constructor
ProxyProtocolInfo(const ProxyProtocolInfo& other) {
memcpy(this, &other, sizeof(ProxyProtocolInfo));
}
// Function to parse the PROXY protocol header (declared)
bool parseProxyProtocolHeader(const char* packet, size_t packet_length);
bool is_in_network(const struct sockaddr* client_addr, const std::string& subnet_mask);
bool is_client_in_any_subnet(const struct sockaddr* client_addr, const char* subnet_list);
// Copy method
ProxyProtocolInfo& copy(const ProxyProtocolInfo& other) {
if (this != &other) {
memcpy(this, &other, sizeof(ProxyProtocolInfo));
}
return *this;
}
#ifdef DEBUG
sockaddr_in create_ipv4_addr(const std::string& ip);
sockaddr_in6 create_ipv6_addr(const std::string& ip);
void run_tests();
#endif // DEBUG
bool is_valid_subnet_list(const char* subnet_list);
bool is_valid_subnet(const char* subnet);
};
#endif // PROXY_PROTOCOL_INFO_H

@ -777,6 +777,7 @@ __thread char *mysql_thread___default_schema;
__thread char *mysql_thread___server_version;
__thread char *mysql_thread___keep_multiplexing_variables;
__thread char *mysql_thread___default_authentication_plugin;
__thread char *mysql_thread___proxy_protocol_networks;
__thread char *mysql_thread___init_connect;
__thread char *mysql_thread___ldap_user_variable;
__thread char *mysql_thread___default_session_track_gtids;
@ -949,6 +950,7 @@ extern __thread char *mysql_thread___default_schema;
extern __thread char *mysql_thread___server_version;
extern __thread char *mysql_thread___keep_multiplexing_variables;
extern __thread char *mysql_thread___default_authentication_plugin;
extern __thread char *mysql_thread___proxy_protocol_networks;
extern __thread char *mysql_thread___init_connect;
extern __thread char *mysql_thread___ldap_user_variable;
extern __thread char *mysql_thread___default_session_track_gtids;

@ -130,6 +130,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo
QP_rule_text.oo QP_query_digest_stats.oo \
GTID_Server_Data.oo MyHGC.oo MySrvConnList.oo MySrvList.oo MySrvC.oo \
MySQL_encode.oo MySQL_ResultSet.oo \
proxy_protocol_info.oo \
proxysql_find_charset.oo ProxySQL_Poll.oo
OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX))
HEADERS := ../include/*.h ../include/*.hpp

@ -499,6 +499,7 @@ static char * mysql_thread_variables_names[]= {
(char *)"data_packets_history_size",
(char *)"handle_warnings",
(char *)"evaluate_replication_lag_on_servers_load",
(char *)"proxy_protocol_networks",
NULL
};
@ -1119,6 +1120,7 @@ MySQL_Threads_Handler::MySQL_Threads_Handler() {
variables.ssl_p2s_crl=NULL;
variables.ssl_p2s_crlpath=NULL;
variables.keep_multiplexing_variables=strdup((char *)"tx_isolation,transaction_isolation,version");
variables.proxy_protocol_networks = strdup((char *)"");
variables.default_authentication_plugin=strdup((char *)"mysql_native_password");
variables.default_authentication_plugin_int = 0; // mysql_native_password
#ifdef DEBUG
@ -1350,6 +1352,7 @@ char * MySQL_Threads_Handler::get_variable_string(char *name) {
if (!strcmp(name,"interfaces")) return strdup(variables.interfaces);
if (!strcmp(name,"keep_multiplexing_variables")) return strdup(variables.keep_multiplexing_variables);
if (!strcmp(name,"default_authentication_plugin")) return strdup(variables.default_authentication_plugin);
if (!strcmp(name,"proxy_protocol_networks")) return strdup(variables.proxy_protocol_networks);
// LCOV_EXCL_START
proxy_error("Not existing variable: %s\n", name); assert(0);
return NULL;
@ -1505,6 +1508,7 @@ char * MySQL_Threads_Handler::get_variable(char *name) { // this is the public f
if (!strcasecmp(name,"default_schema")) return strdup(variables.default_schema);
if (!strcasecmp(name,"keep_multiplexing_variables")) return strdup(variables.keep_multiplexing_variables);
if (!strcasecmp(name,"default_authentication_plugin")) return strdup(variables.default_authentication_plugin);
if (!strcasecmp(name,"proxy_protocol_networks")) return strdup(variables.proxy_protocol_networks);
if (!strcasecmp(name,"interfaces")) return strdup(variables.interfaces);
if (!strcasecmp(name,"server_capabilities")) {
// FIXME : make it human readable
@ -1878,6 +1882,28 @@ bool MySQL_Threads_Handler::set_variable(char *name, const char *value) { // thi
return false;
}
}
if (!strcasecmp(name,"proxy_protocol_networks")) {
bool ret = false;
if (vallen == 0) {
// accept empty string
ret = true;
} else if ( (vallen == 1) && strcmp(value,"*")==0) {
// accept `*`
ret = true;
} else {
ProxyProtocolInfo ppi;
if (ppi.is_valid_subnet_list(value) == true) {
ret = true;
}
}
if (ret == true) {
free(variables.proxy_protocol_networks);
variables.proxy_protocol_networks=strdup(value);
return true;
} else {
return true;
}
}
// SSL proxy to server variables
if (!strcasecmp(name,"ssl_p2s_ca")) {
if (variables.ssl_p2s_ca) free(variables.ssl_p2s_ca);
@ -2703,6 +2729,7 @@ MySQL_Threads_Handler::~MySQL_Threads_Handler() {
if (variables.server_version) free(variables.server_version);
if (variables.keep_multiplexing_variables) free(variables.keep_multiplexing_variables);
if (variables.default_authentication_plugin) free(variables.default_authentication_plugin);
if (variables.proxy_protocol_networks) free(variables.proxy_protocol_networks);
if (variables.firewall_whitelist_errormsg) free(variables.firewall_whitelist_errormsg);
if (variables.init_connect) free(variables.init_connect);
if (variables.ldap_user_variable) free(variables.ldap_user_variable);
@ -2834,6 +2861,7 @@ MySQL_Thread::~MySQL_Thread() {
if (mysql_thread___server_version) { free(mysql_thread___server_version); mysql_thread___server_version=NULL; }
if (mysql_thread___keep_multiplexing_variables) { free(mysql_thread___keep_multiplexing_variables); mysql_thread___keep_multiplexing_variables=NULL; }
if (mysql_thread___default_authentication_plugin) { free(mysql_thread___default_authentication_plugin); mysql_thread___default_authentication_plugin=NULL; }
if (mysql_thread___proxy_protocol_networks) { free(mysql_thread___proxy_protocol_networks); mysql_thread___proxy_protocol_networks=NULL; }
if (mysql_thread___firewall_whitelist_errormsg) { free(mysql_thread___firewall_whitelist_errormsg); mysql_thread___firewall_whitelist_errormsg=NULL; }
if (mysql_thread___init_connect) { free(mysql_thread___init_connect); mysql_thread___init_connect=NULL; }
if (mysql_thread___ldap_user_variable) { free(mysql_thread___ldap_user_variable); mysql_thread___ldap_user_variable=NULL; }
@ -4377,6 +4405,7 @@ void MySQL_Thread::refresh_variables() {
GloMyLogger->audit_set_base_filename(); // both filename and filesize are set here
REFRESH_VARIABLE_CHAR(default_schema);
REFRESH_VARIABLE_CHAR(keep_multiplexing_variables);
REFRESH_VARIABLE_CHAR(proxy_protocol_networks);
REFRESH_VARIABLE_CHAR(default_authentication_plugin);
mysql_thread___default_authentication_plugin_int = GloMTH->variables.default_authentication_plugin_int;
mysql_thread___server_capabilities=GloMTH->get_variable_uint16((char *)"server_capabilities");

@ -307,6 +307,8 @@ MySQL_Data_Stream::MySQL_Data_Stream() {
proxy_addr.addr=NULL;
proxy_addr.port=0;
PROXY_info = NULL;
sess=NULL;
mysql_real_query.pkt.ptr=NULL;
mysql_real_query.pkt.size=0;
@ -380,6 +382,10 @@ MySQL_Data_Stream::~MySQL_Data_Stream() {
free(proxy_addr.addr);
proxy_addr.addr=NULL;
}
if (PROXY_info) {
delete PROXY_info;
PROXY_info = NULL;
}
free_mysql_real_query();
@ -1081,6 +1087,90 @@ int MySQL_Data_Stream::buffer2array() {
} else {
if ((queueIN.pkt.size==0) && queue_data(queueIN)>=sizeof(mysql_hdr)) {
// check if this is a PROXY protocol packet
if (
pkts_recv==0 && // checks if no packets have been received yet
queueIN.tail == 0 && // checks if the input queue (`queueIN`) was never rotated . This check is redundant
queueIN.head > 7 && // ensures that there are at least 8 bytes in the input buffer (`queueIN.buffer`)
// This is because the PROXY protocol signature (`PROXY`) is 5 bytes long, and we need at least 3 more bytes to check for the `\r\n` delimiter.
strncmp((char *)queueIN.buffer,"PROXY ",6) == 0 // checks if the first 6 bytes of the buffer match the "PROXY " string, indicating a potential PROXY protocol packet
) {
bool found_delimiter = false;
size_t b = 0;
const char *ptr = (char *)queueIN.buffer;
// This loop iterates through the buffer, starting from the 8th byte (index 7) until the end of the buffer (index `queueIN.head - 1`).
// The loop continues as long as the delimiter hasn't been found (`found_delimiter == false`)
// the loop looks for \r\n , the delimiter of the PROXY packet
for (size_t i = 7; found_delimiter == false && i < queueIN.head - 1; i++) {
if (
ptr[i] == '\r'
&&
ptr[i+1] == '\n'
) {
found_delimiter = true;
b = i+2;
}
}
if (found_delimiter) {
/*
// we could return a packet, but it is actually better to handle it here
queueIN.pkt.size = b;
queueIN.pkt.ptr=l_alloc(queueIN.pkt.size);
memcpy(queueIN.pkt.ptr, queueIN.buffer, b);
PSarrayIN->add(queueIN.pkt.ptr,queueIN.pkt.size);
add_to_data_packet_history(data_packets_history_IN,queueIN.pkt.ptr,queueIN.pkt.size);
*/
// we move forward the internal pointer.
// note that parseProxyProtocolHeader() will read from the beginning of the buffer
queue_r(queueIN, b);
bool accept_proxy = false; // by default, we do not accept a PROXY header
const char * proxy_protocol_networks = mysql_thread___proxy_protocol_networks;
ProxyProtocolInfo ppi;
if (strcmp(proxy_protocol_networks,"*") == 0) { // all networks are accepted
accept_proxy = true;
} else {
if (client_addr) {
if (ppi.is_client_in_any_subnet(client_addr, proxy_protocol_networks) == true) {
accept_proxy = true;
}
}
}
if (accept_proxy == true) {
if (ppi.parseProxyProtocolHeader((const char *)queueIN.buffer, b)) {
PROXY_info = new ProxyProtocolInfo(ppi);
// we take a copy of old address/port
if (addr.addr) {
strncpy(PROXY_info->proxy_address, addr.addr, INET6_ADDRSTRLEN);
free(addr.addr);
}
PROXY_info->proxy_port = addr.port;
// we override old address/port
addr.addr = strdup(PROXY_info->source_address);
addr.port = PROXY_info->source_port;
} else {
// TODO: error handling
// maybe just generate a warning
}
} else { // the PROXY header was not accepted
// TODO: error handling
// maybe just generate a warning
}
pkts_recv++;
queueIN.pkt.size=0;
queueIN.pkt.ptr=NULL;
return b;
} else {
// set the connection unhealthy , this will cause the session to be destroyed
if (sess) {
sess->set_unhealthy();
}
}
return 0; // we always return
}
proxy_debug(PROXY_DEBUG_PKT_ARRAY, 5, "Session=%p . Reading the header of a new packet\n", sess);
memcpy(&queueIN.hdr,queue_r_ptr(queueIN),sizeof(mysql_hdr));
pkt_sid=queueIN.hdr.pkt_id;
@ -1576,6 +1666,14 @@ void MySQL_Data_Stream::get_client_myds_info_json(json& j) {
jc1["client_addr"]["port"] = addr.port;
jc1["proxy_addr"]["address"] = ( proxy_addr.addr ? proxy_addr.addr : "" );
jc1["proxy_addr"]["port"] = proxy_addr.port;
if (PROXY_info != NULL) {
jc1["PROXY_V1"]["source_address"] = PROXY_info->source_address;
jc1["PROXY_V1"]["destination_address"] = PROXY_info->destination_address;
jc1["PROXY_V1"]["proxy_address"] = PROXY_info->proxy_address;
jc1["PROXY_V1"]["source_port"] = PROXY_info->source_port;
jc1["PROXY_V1"]["destination_port"] = PROXY_info->destination_port;
jc1["PROXY_V1"]["proxy_port"] = PROXY_info->proxy_port;
}
jc1["encrypted"] = encrypted;
if (encrypted) {
const SSL_CIPHER *cipher = SSL_get_current_cipher(ssl);

@ -0,0 +1,382 @@
#include "proxy_protocol_info.h"
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <iostream>
static bool DEBUG_ProxyProtocolInfo = false;
// Function to parse the PROXY protocol header
bool ProxyProtocolInfo::parseProxyProtocolHeader(const char* packet, size_t packet_length) {
// Check for minimum header length (including CRLF)
if (packet_length < 15) {
return false; // Not a valid PROXY protocol header
}
// Create a temporary buffer on the stack
char temp_buffer[packet_length + 1];
// Copy the packet data
memcpy(temp_buffer, packet, packet_length);
temp_buffer[packet_length] = '\0'; // Null-terminate the buffer
// Verify the PROXY protocol signature
if (memcmp(temp_buffer, "PROXY", 5) != 0) {
return false; // Not a valid PROXY protocol header
}
// Check for the space after "PROXY"
if (temp_buffer[5] != ' ') {
return false; // Invalid header format
}
// Check for the protocol type
if (memcmp(temp_buffer + 6, "TCP4", 4) == 0 ||
memcmp(temp_buffer + 6, "TCP6", 4) == 0 ||
memcmp(temp_buffer + 6, "UNKNOWN", 7) == 0) {
// Parse the header using sscanf
int result = sscanf(temp_buffer, "PROXY %*s %s %s %hu %hu\r\n",
source_address, destination_address,
&source_port, &destination_port);
// Check if sscanf successfully parsed all fields
if (result == 4) {
return true; // Successful parsing
} else {
// Handle partial parsing or invalid format
return false; // Indicate an error
}
}
return false; // Invalid header format
}
/**
* Checks if a client address is within a specified subnet.
*
* @param client_addr Pointer to the client's sockaddr structure (either sockaddr_in or sockaddr_in6).
* @param subnet_mask The subnet in CIDR notation (e.g., "192.168.1.0/24" for IPv4 or "2001:db8::/32" for IPv6).
* @return True if the client address is within the specified subnet, otherwise false.
*/
bool ProxyProtocolInfo::is_in_network(const struct sockaddr* client_addr, const std::string& subnet_mask) {
// Determine address family (IPv4 or IPv6)
int family = client_addr->sa_family;
// Parse the subnet and mask
union {
struct in_addr v4;
struct in6_addr v6;
} subnet_addr;
uint8_t mask = 0;
char addr_str[INET6_ADDRSTRLEN];
if (family == AF_INET) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Parsing IPv4 subnet mask" << std::endl;
// Parse the IPv4 subnet mask using sscanf
if (sscanf(subnet_mask.c_str(), "%[^/]/%hhu", addr_str, &mask) != 2) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Invalid subnet/mask format" << std::endl;
return false; // Invalid subnet/mask format
}
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Subnet: " << addr_str << ", Mask: " << (int)mask << std::endl;
// Convert the parsed subnet address to binary format
if (inet_pton(AF_INET, addr_str, &subnet_addr.v4) != 1) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Invalid IPv4 address" << std::endl;
return false; // Invalid IPv4 address
}
} else if (family == AF_INET6) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Parsing IPv6 subnet mask" << std::endl;
// Parse the IPv6 subnet mask using sscanf
if (sscanf(subnet_mask.c_str(), "%[^/]/%hhu", addr_str, &mask) != 2) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Invalid subnet/mask format" << std::endl;
return false; // Invalid subnet/mask format
}
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Subnet: " << addr_str << ", Mask: " << (int)mask << std::endl;
// Convert the parsed subnet address to binary format
if (inet_pton(AF_INET6, addr_str, &subnet_addr.v6) != 1) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Invalid IPv6 address" << std::endl;
return false; // Invalid IPv6 address
}
} else {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Unsupported address family" << std::endl;
return false; // Unsupported address family
}
uint8_t network_addr[16] = {0};
if (family == AF_INET) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Calculating network address for IPv4" << std::endl;
// Calculate the network address for IPv4
uint32_t subnet = ntohl(subnet_addr.v4.s_addr) & (0xFFFFFFFF << (32 - mask));
subnet = htonl(subnet);
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Subnet address (masked): " << inet_ntoa(*(struct in_addr*)&subnet) << std::endl;
// Copy the masked subnet address into the network_addr array
memcpy(network_addr, &subnet, sizeof(subnet));
} else if (family == AF_INET6) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Calculating network address for IPv6" << std::endl;
// Calculate the network address for IPv6
uint8_t* addr = subnet_addr.v6.s6_addr;
int bits_left = mask;
for (int i = 0; i < 16; ++i) {
if (bits_left >= 8) {
network_addr[i] = addr[i];
bits_left -= 8;
} else if (bits_left > 0) {
network_addr[i] = addr[i] & (0xFF << (8 - bits_left));
bits_left = 0;
} else {
network_addr[i] = 0;
}
}
if (DEBUG_ProxyProtocolInfo==true) {
char network_addr_str[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, network_addr, network_addr_str, INET6_ADDRSTRLEN);
std::cout << "Subnet address (masked): " << network_addr_str << std::endl;
}
}
uint8_t client_addr_int[16] = {0};
if (family == AF_INET) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Extracting client address for IPv4" << std::endl;
// Extract the client address for IPv4
uint32_t client = ntohl(((struct sockaddr_in*)client_addr)->sin_addr.s_addr);
client = htonl(client);
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Client address: " << inet_ntoa(*(struct in_addr*)&client) << std::endl;
// Copy the client address into the client_addr_int array
memcpy(client_addr_int, &client, sizeof(client));
} else if (family == AF_INET6) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Extracting client address for IPv6" << std::endl;
// Copy the client address into the client_addr_int array
memcpy(client_addr_int, ((struct sockaddr_in6*)client_addr)->sin6_addr.s6_addr, 16);
if (DEBUG_ProxyProtocolInfo==true) {
char client_addr_str[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, client_addr_int, client_addr_str, INET6_ADDRSTRLEN);
std::cout << "Client address: " << client_addr_str << std::endl;
}
}
// Calculate the number of bytes to compare based on the mask
int bytes_to_compare = mask / 8;
int remaining_bits = mask % 8;
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Comparing full bytes covered by the mask" << std::endl;
// Compare the full bytes covered by the mask
if (memcmp(network_addr, client_addr_int, bytes_to_compare) != 0) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Address does not match in full byte comparison" << std::endl;
return false;
}
if (remaining_bits > 0) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Comparing remaining bits" << std::endl;
// Compare the remaining bits covered by the mask
uint8_t mask_byte = 0xFF << (8 - remaining_bits);
if ((network_addr[bytes_to_compare] & mask_byte) != (client_addr_int[bytes_to_compare] & mask_byte)) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Address does not match in remaining bits comparison" << std::endl;
return false; // Addresses don't match in remaining bits comparison
}
}
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Client address is within the subnet" << std::endl;
return true; // Client address is within the subnet
}
bool ProxyProtocolInfo::is_client_in_any_subnet(const struct sockaddr* client_addr, const char* subnet_list) {
// Create a copy of the subnet list to avoid modifying the original string
char* subnet_list_copy = new char[strlen(subnet_list) + 1];
strcpy(subnet_list_copy, subnet_list);
char* token = strtok(subnet_list_copy, ","); // Get the first subnet
while (token != NULL) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Checking subnet: " << token << std::endl;
if (is_in_network(client_addr, token)) {
if (DEBUG_ProxyProtocolInfo==true)
std::cout << "Client is in subnet: " << token << std::endl;
delete[] subnet_list_copy; // Deallocate the copy
return true; // Client is in at least one subnet
}
token = strtok(NULL, ","); // Get the next subnet
}
delete[] subnet_list_copy; // Deallocate the copy
return false; // Client is not in any of the subnets
}
#ifdef DEBUG
// Helper function to create an IPv4 sockaddr structure
sockaddr_in ProxyProtocolInfo::create_ipv4_addr(const std::string& ip) {
sockaddr_in addr;
addr.sin_family = AF_INET;
inet_pton(AF_INET, ip.c_str(), &addr.sin_addr);
return addr;
}
// Helper function to create an IPv6 sockaddr structure
sockaddr_in6 ProxyProtocolInfo::create_ipv6_addr(const std::string& ip) {
sockaddr_in6 addr;
addr.sin6_family = AF_INET6;
inet_pton(AF_INET6, ip.c_str(), &addr.sin6_addr);
return addr;
}
// Test cases for the is_in_network function
void ProxyProtocolInfo::run_tests() {
// IPv4 Tests
{
sockaddr_in client_addr = create_ipv4_addr("192.168.1.10");
assert(is_in_network((sockaddr*)&client_addr, "192.168.1.0/24") == true);
assert(is_in_network((sockaddr*)&client_addr, "192.168.1.0/25") == true);
assert(is_in_network((sockaddr*)&client_addr, "192.168.1.0/26") == true);
assert(is_in_network((sockaddr*)&client_addr, "192.168.2.0/24") == false);
assert(is_in_network((sockaddr*)&client_addr, "192.168.0.0/16") == true);
assert(is_in_network((sockaddr*)&client_addr, "192.168.1.10/32") == true);
assert(is_in_network((sockaddr*)&client_addr, "192.168.1.11/32") == false);
}
// IPv6 Tests
{
sockaddr_in6 client_addr = create_ipv6_addr("2001:db8::1");
assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/32") == true);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/48") == true);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/64") == true);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:1::/64") == false);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8::1/128") == true);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8::2/128") == false);
}
{
sockaddr_in6 client_addr = create_ipv6_addr("2001:db8:0:1::1");
assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:1::/64") == true);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/32") == true);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:2::/64") == false);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8::1/128") == false);
assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:1::1/128") == true);
}
{
struct sockaddr_in client_addr = create_ipv4_addr("172.16.14.1");
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.16.0.0/16,192.168.1.0/24") == true);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.17.0.0/16,192.168.1.0/24") == false);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,172.16.0.0/16,192.168.1.0/24") == true);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,172.17.0.0/16,192.168.1.0/24") == false);
}
{
sockaddr_in6 client_addr = create_ipv6_addr("2001:db8:0:1::1");
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,2001:db8:0:2::/64") == true);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:2::/64,2001:db8:0:1::/64") == true);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,172.16.0.0/16") == true);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.16.0.0/16,2001:db8:0:1::/64") == true);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:2::/64,172.16.0.0/16") == false);
assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.16.0.0/16,2001:db8:0:2::/64") == false);
}
{
const char* subnet_list1 = "192.168.1.0/24,10.0.0.0/8,2001:0:200::/32";
const char* subnet_list2 = "192.168.1.0/24,10.0.0.0/not_a_mask,2001:0:200::/32";
const char* subnet_list3 = "192.168.1.0/24,invalid_ipv4,2001:0:200::/32";
const char* subnet_list4 = "";
assert(is_valid_subnet_list(subnet_list1) == true);
assert(is_valid_subnet_list(subnet_list2) == false);
assert(is_valid_subnet_list(subnet_list3) == false);
assert(is_valid_subnet_list(subnet_list4) == false);
}
}
#endif // DEBUG
bool ProxyProtocolInfo::is_valid_subnet_list(const char* subnet_list) {
// Check if the string is empty
if (subnet_list == nullptr || *subnet_list == '\0') {
return false; // Empty string is not a valid subnet list
}
// Create a copy of the string to avoid modifying the original
char* subnet_list_copy = new char[strlen(subnet_list) + 1];
strcpy(subnet_list_copy, subnet_list);
// Tokenize the string using ',' as the delimiter
char* token = strtok(subnet_list_copy, ",");
while (token != NULL) {
// Check if the token is a valid subnet
if (!is_valid_subnet(token)) {
delete[] subnet_list_copy; // Deallocate the copy
return false; // Invalid subnet found
}
token = strtok(NULL, ","); // Get the next token
}
delete[] subnet_list_copy; // Deallocate the copy
return true; // All subnets are valid
}
// Helper function to verify a single subnet
bool ProxyProtocolInfo::is_valid_subnet(const char* subnet) {
// Check if the subnet is empty
if (subnet == NULL || *subnet == '\0') {
return false; // Empty subnet is not valid
}
// Check if the subnet contains a '/' character (CIDR notation)
if (strchr(subnet, '/') == NULL) {
return false; // Missing '/' character in subnet
}
// Check if the subnet is a valid IPv4 or IPv6 address
int family = AF_INET; // Default to IPv4
if (strchr(subnet, ':') != NULL) {
family = AF_INET6; // IPv6 if a colon is found
}
char addr_str[INET6_ADDRSTRLEN];
uint8_t mask = 0;
if (family == AF_INET) {
// Parse IPv4 subnet using sscanf
if (sscanf(subnet, "%[^/]/%hhu", addr_str, &mask) != 2) {
return false; // Invalid IPv4 subnet format
}
} else if (family == AF_INET6) {
// Parse IPv6 subnet using sscanf
if (sscanf(subnet, "%[^/]/%hhu", addr_str, &mask) != 2) {
return false; // Invalid IPv6 subnet format
}
} else {
return false; // Unsupported address family
}
// Validate the mask value
if (mask < 0 || mask > 128) {
return false; // Invalid mask value
}
// Check if the address is valid using inet_pton
union {
struct in_addr v4;
struct in6_addr v6;
} addr; // Create a union to hold both IPv4 and IPv6 addresses
if (inet_pton(family, addr_str, &addr) != 1) {
return false; // Invalid IP address
}
return true; // Valid subnet
}

@ -42,6 +42,11 @@
#include <uuid/uuid.h>
#ifdef DEBUG
#include "proxy_protocol_info.h"
#endif // DEBUG
/*
extern "C" MySQL_LDAP_Authentication * create_MySQL_LDAP_Authentication_func() {
return NULL;
@ -1967,6 +1972,17 @@ int main(int argc, const char * argv[]) {
if (rc) { exit(EXIT_FAILURE); }
}
#ifdef DEBUG
{
// This run some ProxyProtocolInfo tests.
// It will assert() if any test fails
ProxyProtocolInfo ppi;
ppi.run_tests();
}
#endif // DEBUG
{
MYSQL *my = mysql_init(NULL);
mysql_close(my);

@ -0,0 +1,147 @@
/**
* @file test_PROXY_Protocol-t.cpp
* @brief This test tries the PROXY protocol
* @details The test performs authentication using the PROXY protocol , then
* verifies PROXYSQL INTERNAL SESSION
* @date 2024-08-07
*/
#include <vector>
#include <string>
#include <stdio.h>
#include "mysql.h"
#include "tap.h"
#include "command_line.h"
#include "utils.h"
#include "json.hpp"
#include <utility> // For std::pair
using std::string;
using namespace nlohmann;
void parse_result_json_column(MYSQL_RES *result, json& j) {
if(!result) return;
MYSQL_ROW row;
while ((row = mysql_fetch_row(result))) {
j = json::parse(row[0]);
}
}
int connect_and_run_query(CommandLine& cl, int tests, const char *hdr) {
int ret = 0; // number of success
MYSQL* proxysql_mysql = mysql_init(NULL);
mysql_optionsv(proxysql_mysql, MARIADB_OPT_PROXY_HEADER, hdr, strlen(hdr));
if (!mysql_real_connect(proxysql_mysql, cl.host, cl.username, cl.password, NULL, cl.port, NULL, 0)) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(proxysql_mysql));
return ret;
} else {
ok(true, "Successfully connected");
ret++;
}
MYSQL_QUERY(proxysql_mysql, "PROXYSQL INTERNAL SESSION");
json j_status {};
MYSQL_RES* int_session_res = mysql_store_result(proxysql_mysql);
parse_result_json_column(int_session_res, j_status);
mysql_free_result(int_session_res);
bool proxy_info_found = false;
//diag("%s",j_status.dump(1).c_str());
json jv1 {};
if (j_status.find("client") != j_status.end()) {
json& j = *j_status.find("client");
if (j.find("PROXY_V1") != j.end()) {
proxy_info_found = true;
jv1 = *j.find("PROXY_V1");
}
}
if (tests == 2) { // we must found PROXY_V1
ok(proxy_info_found == true, "PROXY_V1 %sfound", proxy_info_found ? "" : "not ");
if (proxy_info_found == true) {
ret++;
diag("%s",jv1.dump().c_str());
}
} else if (tests == 1) { // PROXY_V1 should not be present
ok(proxy_info_found == false, "PROXY_V1 %sfound", proxy_info_found ? "" : "not ");
if (proxy_info_found == true) {
diag("%s",jv1.dump().c_str());
} else {
ret++;
}
} else {
exit(exit_status());
}
mysql_close(proxysql_mysql);
return ret;
}
int main(int argc, char** argv) {
CommandLine cl;
std::vector<std::pair<int, std::string>> Headers;
Headers.push_back(std::make_pair(2, "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n"));
Headers.push_back(std::make_pair(1, "PROXY TCP4 192.168.0.1 192.168.0.11 56324\r\n"));
Headers.push_back(std::make_pair(0, "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443"));
Headers.push_back(std::make_pair(0, "PROXY"));
Headers.push_back(std::make_pair(2, "PROXY TCP6 fe80::d6ae:52ff:fecf:9876 fe80::d6ae:52aa:fecf:1234 56324 443\r\n"));
Headers.push_back(std::make_pair(1, "PROXY TCP6 fe80::d6ae:52ff:fecf:9876 fe80::d6ae:52aa:fecf:1234 56324\r\n"));
Headers.push_back(std::make_pair(0, "PROXY TCP6 fe80::d6ae:52ff:fecf:9876 fe80::d6ae:52aa:fecf:1234 56324 443"));
int p = 0;
// we will run the tests twice, with:
// - with mysql-proxy_protocol_networks=''
p += Headers.size();
for (const auto& pair : Headers) {
p += ( pair.first ? 2 : 0); // PROXY_V1 should not be present
}
// - with mysql-proxy_protocol_networks='*'
p += Headers.size();
for (const auto& pair : Headers) {
p += ( pair.first ? 2 : 0); // perform either 2 checks, or 0
}
plan(p);
MYSQL* proxysql_admin = mysql_init(NULL);
// Initialize connections
if (!proxysql_admin) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(proxysql_admin));
return -1;
}
if (!mysql_real_connect(proxysql_admin, cl.host, cl.admin_username, cl.admin_password, NULL, cl.admin_port, NULL, 0)) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(proxysql_admin));
return -1;
}
diag("Setting mysql-proxy_protocol_networks=''");
MYSQL_QUERY(proxysql_admin, "SET mysql-proxy_protocol_networks=''");
MYSQL_QUERY(proxysql_admin, "LOAD MYSQL VARIABLES TO RUNTIME");
for (const auto& pair : Headers) {
const std::string& hdr = pair.second;
diag("Testing connection with header: %s", hdr.c_str());
int arg1 = pair.first ? 1 : 0; // if pair.first is not 0 , we will pass 1 because PROXY_V1 should not be present
int ret = connect_and_run_query(cl, arg1, hdr.c_str());
int expected = pair.first ? 2 : 0;
ok(ret == expected , "Expected successes: %d , returned successes: %d", expected, ret);
}
diag("Setting mysql-proxy_protocol_networks='*'");
MYSQL_QUERY(proxysql_admin, "SET mysql-proxy_protocol_networks='*'");
MYSQL_QUERY(proxysql_admin, "LOAD MYSQL VARIABLES TO RUNTIME");
for (const auto& pair : Headers) {
const std::string& hdr = pair.second;
diag("Testing connection with header: %s", hdr.c_str());
int ret = connect_and_run_query(cl, pair.first, hdr.c_str());
int expected = pair.first ? 2 : 0;
ok(ret == expected , "Expected successes: %d , returned successes: %d", expected, ret);
}
return exit_status();
}
Loading…
Cancel
Save