mirror of https://github.com/sysown/proxysql
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
220 lines
6.6 KiB
220 lines
6.6 KiB
#ifndef PG_LITE_CLIENT_HPP
|
|
#define PG_LITE_CLIENT_HPP
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
#include <stdexcept>
|
|
#include <cstdint>
|
|
#include <memory>
|
|
#include <variant>
|
|
#include <sys/socket.h>
|
|
#include <netinet/in.h>
|
|
#include <arpa/inet.h>
|
|
#include <unistd.h>
|
|
#include <cstring>
|
|
#include <functional>
|
|
|
|
/**
 * Exception type used by this header to report client/protocol failures.
 *
 * Derives from std::runtime_error, so the message passed at construction is
 * available through what() and the exception can be caught as
 * std::runtime_error or std::exception.
 */
class PgException : public std::runtime_error {
    using Base = std::runtime_error;

public:
    /// @param msg Human-readable description of the failure.
    explicit PgException(const std::string& msg) : Base(msg) {}
};
|
|
|
|
/**
 * In-memory result set: column names, per-column format codes and rows of
 * cell values.
 *
 * Only the declarations live in this header; the member functions are defined
 * elsewhere, so bounds handling and error behaviour of the accessors are not
 * visible here.
 */
class PgResult {
public:
    /// One cell: empty (std::monostate), a string, or a raw byte vector.
    /// NOTE(review): presumably monostate = SQL NULL, string = text-format
    /// value, byte vector = binary-format value -- confirm in the .cpp.
    using Value = std::variant<std::monostate, std::string, std::vector<uint8_t>>;

    PgResult() = default;

    // --- population -------------------------------------------------------
    void addRow(const std::vector<Value>& row);
    void setColumns(const std::vector<std::string>& cols);
    void setColumnFormats(const std::vector<int16_t>& fmts);

    // --- inspection -------------------------------------------------------
    int rowCount() const;
    int columnCount() const;
    Value getValue(int row, int col) const;   // returned by value (copy of the cell)
    std::string columnName(int col) const;
    int16_t columnFormat(int col) const;
    bool isNull(int row, int col) const;

private:
    std::vector<std::string> columns_;        // column names
    std::vector<int16_t> columnFormats_;      // one format code per column
    std::vector<std::vector<Value>> rows_;    // rows of cells
};
|
|
|
|
class BufferReader {
|
|
public:
|
|
BufferReader(const uint8_t* data, size_t size)
|
|
: data_(data), size_(size), pos_(0) {
|
|
}
|
|
|
|
BufferReader(const std::vector<uint8_t>& buffer)
|
|
: data_(buffer.data()), size_(buffer.size()), pos_(0) {
|
|
}
|
|
|
|
uint8_t readByte() {
|
|
if (pos_ >= size_) throw PgException("Buffer underrun");
|
|
return data_[pos_++];
|
|
}
|
|
|
|
int16_t readInt16() {
|
|
if (pos_ + 2 > size_) throw PgException("Buffer underrun");
|
|
int16_t value;
|
|
memcpy(&value, data_ + pos_, 2);
|
|
pos_ += 2;
|
|
return ntohs(value);
|
|
}
|
|
|
|
int32_t readInt32() {
|
|
if (pos_ + 4 > size_) throw PgException("Buffer underrun");
|
|
int32_t value;
|
|
memcpy(&value, data_ + pos_, 4);
|
|
pos_ += 4;
|
|
return ntohl(value);
|
|
}
|
|
|
|
std::string readString() {
|
|
const char* start = reinterpret_cast<const char*>(data_ + pos_);
|
|
size_t len = 0;
|
|
while (pos_ + len < size_ && data_[pos_ + len] != '\0') {
|
|
len++;
|
|
}
|
|
if (pos_ + len >= size_) throw PgException("String missing null terminator");
|
|
std::string str(start, len);
|
|
pos_ += len + 1; // skip null terminator
|
|
return str;
|
|
}
|
|
|
|
std::vector<uint8_t> readBytes(size_t count) {
|
|
if (pos_ + count > size_) throw PgException("Buffer underrun");
|
|
std::vector<uint8_t> bytes(data_ + pos_, data_ + pos_ + count);
|
|
pos_ += count;
|
|
return bytes;
|
|
}
|
|
|
|
size_t remaining() const { return size_ - pos_; }
|
|
|
|
private:
|
|
const uint8_t* data_;
|
|
size_t size_;
|
|
size_t pos_;
|
|
};
|
|
|
|
|
|
class PgConnection {
|
|
public:
|
|
// Protocol constants
|
|
static const int PROTOCOL_VERSION = 196608; // 3.0
|
|
static const char SSL_REQUEST = 'N';
|
|
static const char AUTH_TYPE = 'R';
|
|
static const char PARAMETER_STATUS = 'S';
|
|
static const char READY_FOR_QUERY = 'Z';
|
|
static const char ROW_DESCRIPTION = 'T';
|
|
static const char PARAMETER_DESCRIPTION = 't';
|
|
static const char DATA_ROW = 'D';
|
|
static const char COMMAND_COMPLETE = 'C';
|
|
static const char PARSE_COMPLETE = '1';
|
|
static const char BIND_COMPLETE = '2';
|
|
static const char NO_DATA = 'n';
|
|
static const char ERROR_RESPONSE = 'E';
|
|
static const char CLOSE_COMPLETE = '3';
|
|
static const char EMPTY_QUERY_RESPONSE = 'I';
|
|
static const char PORTAL_SUSPENDED = 's';
|
|
static const char NOTICE_RESPONSE = 'N';
|
|
|
|
struct Param {
|
|
std::variant<std::monostate, int32_t, std::string, std::vector<uint8_t>> value;
|
|
int16_t format; // 0 = text, 1 = binary
|
|
};
|
|
|
|
PgConnection(int timeout_ms = 0);
|
|
~PgConnection();
|
|
|
|
void connect(
|
|
const std::string& host,
|
|
int port,
|
|
const std::string& dbname,
|
|
const std::string& user,
|
|
const std::string& password
|
|
);
|
|
|
|
void disconnect();
|
|
bool isConnected() const;
|
|
inline int getSocket() const { return sock_; }
|
|
|
|
void execute(const std::string& query);
|
|
void executeParams(
|
|
const std::string& stmtName,
|
|
const std::string& query,
|
|
const std::vector<Param>& params,
|
|
const std::vector<int16_t>& resultFormats = {}
|
|
);
|
|
|
|
// Prepared statement interface
|
|
void prepareStatement(const std::string& stmtName, const std::string& query, bool send_sync,const std::vector<uint32_t>& paramType = {});
|
|
void describeStatement(const std::string& stmtName, bool send_sync);
|
|
void describePortal(const std::string& stmtName, bool send_sync);
|
|
void bindStatement(
|
|
const std::string& stmtName,
|
|
const std::string& portalName,
|
|
const std::vector<Param>& params,
|
|
const std::vector<int16_t>& resultFormats = {},
|
|
bool sync = false
|
|
);
|
|
void executePortal(
|
|
const std::string& portalName,
|
|
int maxRows = 0, // 0 = all rows
|
|
bool send_sync = true
|
|
);
|
|
void executeStatement(
|
|
int maxRows = 0,
|
|
bool send_sync = true
|
|
);
|
|
void closePortal(const std::string& portalName, bool send_sync);
|
|
void closeStatement(const std::string& stmtName, bool send_sync);
|
|
void sendSync();
|
|
void waitForMessage(char expectedType, const std::string& errorContext, bool wait_for_ready);
|
|
void consumeInputUntilReady();
|
|
void readMessage(char& type, std::vector<uint8_t>& buffer);
|
|
void sendMessage(char type, const std::vector<uint8_t>& data);
|
|
void sendQuery(const std::string& query);
|
|
void waitForReady();
|
|
std::shared_ptr<PgResult> readResult(); // NOT TESTED YET
|
|
|
|
private:
|
|
int sock_ = -1;
|
|
int timeout_ms_ = 0;
|
|
std::string user_;
|
|
std::string dbname_;
|
|
|
|
void sendStartupPacket();
|
|
void handleAuthentication(const std::string& password);
|
|
void sendPassword(const std::string& password);
|
|
|
|
|
|
void sendParse(const std::string& query, const std::string& stmtName, const std::vector<uint32_t>& paramType);
|
|
void sendDescribeStatement(const std::string& stmtName);
|
|
void sendDescribePortal(const std::string& portalName);
|
|
void sendBind(
|
|
const std::vector<Param>& params,
|
|
const std::string& stmtName,
|
|
const std::vector<int16_t>& resultFormats
|
|
);
|
|
void sendExecute(const std::string& portalName, int maxRows);
|
|
|
|
void sendClose(const std::string& name, char type, bool send_sync);
|
|
|
|
std::vector<uint8_t> readBytes(size_t count);
|
|
void writeBytes(const uint8_t* data, size_t count);
|
|
|
|
std::string readString();
|
|
void writeString(const std::string& str);
|
|
|
|
int32_t readInt32();
|
|
void writeInt32(int32_t value);
|
|
int16_t readInt16();
|
|
void writeInt16(int16_t value);
|
|
};
|
|
|
|
#endif // PG_LITE_CLIENT_HPP
|