added handshake for dos protection

This commit is contained in:
2025-09-21 11:03:43 -05:00
parent ee03167e43
commit 9c8e96bc5e
3 changed files with 91 additions and 45 deletions

View File

@@ -25,33 +25,24 @@
#include "utils/common.hpp" #include "utils/common.hpp"
namespace repertory::comm { namespace repertory::comm {
constexpr const std::uint8_t max_read_attempts{5U}; inline constexpr const std::uint8_t max_read_attempts{5U};
constexpr const std::uint16_t packet_nonce_size{256U}; inline constexpr const std::uint16_t packet_nonce_size{256U};
inline constexpr const std::uint16_t server_handshake_timeout_ms{3000U};
struct non_blocking_guard final { struct non_blocking_guard final {
boost::asio::ip::tcp::socket &sock;
bool non_blocking{};
non_blocking_guard(const non_blocking_guard &) = delete; non_blocking_guard(const non_blocking_guard &) = delete;
non_blocking_guard(non_blocking_guard &&) = delete; non_blocking_guard(non_blocking_guard &&) = delete;
auto operator=(const non_blocking_guard &) -> non_blocking_guard & = delete; auto operator=(const non_blocking_guard &) -> non_blocking_guard & = delete;
auto operator=(non_blocking_guard &&) -> non_blocking_guard & = delete; auto operator=(non_blocking_guard &&) -> non_blocking_guard & = delete;
explicit non_blocking_guard(boost::asio::ip::tcp::socket &sock_) explicit non_blocking_guard(boost::asio::ip::tcp::socket &sock_);
: sock(sock_), non_blocking(sock_.non_blocking()) {
boost::system::error_code err;
[[maybe_unused]] auto ret = sock_.non_blocking(true, err);
}
~non_blocking_guard() { ~non_blocking_guard();
if (not sock.is_open()) {
return;
}
boost::system::error_code err; private:
[[maybe_unused]] auto ret = sock.non_blocking(non_blocking, err); bool non_blocking;
} boost::asio::ip::tcp::socket &sock;
}; };
[[nodiscard]] auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock) [[nodiscard]] auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock)

View File

@@ -25,6 +25,21 @@
#include "events/types/packet_client_timeout.hpp" #include "events/types/packet_client_timeout.hpp"
namespace repertory::comm { namespace repertory::comm {
non_blocking_guard::non_blocking_guard(boost::asio::ip::tcp::socket &sock_)
: sock(sock_), non_blocking(sock_.non_blocking()) {
boost::system::error_code err;
[[maybe_unused]] auto ret = sock_.non_blocking(true, err);
}
non_blocking_guard::~non_blocking_guard() {
if (not sock.is_open()) {
return;
}
boost::system::error_code err;
[[maybe_unused]] auto ret = sock.non_blocking(non_blocking, err);
}
auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock) -> bool { auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock) -> bool {
if (not sock.is_open()) { if (not sock.is_open()) {
return false; return false;

View File

@@ -31,11 +31,13 @@
#include "platform/platform.hpp" #include "platform/platform.hpp"
#include "types/repertory.hpp" #include "types/repertory.hpp"
#include "utils/error_utils.hpp" #include "utils/error_utils.hpp"
#include "utils/timeout.hpp"
using namespace repertory::comm; using namespace repertory::comm;
using std::thread; using std::thread;
namespace repertory { namespace repertory {
packet_server::packet_server(std::uint16_t port, std::string token, packet_server::packet_server(std::uint16_t port, std::string token,
std::uint8_t pool_size, closed_callback closed, std::uint8_t pool_size, closed_callback closed,
message_handler_callback message_handler) message_handler_callback message_handler)
@@ -93,28 +95,73 @@ auto packet_server::handshake(std::shared_ptr<connection> conn) const -> bool {
request.to_buffer(buffer); request.to_buffer(buffer);
auto to_read{buffer.size() + utils::encryption::encryption_header_size}; auto to_read{buffer.size() + utils::encryption::encryption_header_size};
write_all_with_deadline(io_context_, conn->socket, const auto timeout_handler = [&conn]() {
boost::asio::buffer(buffer), try {
std::chrono::milliseconds(3000U)); boost::system::error_code err{};
[[maybe_unused]] auto ret = conn->socket.cancel(err);
conn->buffer.resize(to_read); } catch (const std::exception &e) {
read_exact_with_deadline(io_context_, conn->socket, repertory::utils::error::raise_error(function_name, e,
boost::asio::buffer(conn->buffer), "exception occurred");
std::chrono::milliseconds(3000U));
packet response(conn->buffer);
if (response.decrypt(encryption_token_) == 0) {
std::string nonce;
if (response.decode(nonce) == 0) {
if (nonce == conn->nonce) {
conn->generate_nonce();
return true;
}
} }
throw std::runtime_error("invalid nonce"); try {
conn->socket.close();
} catch (const std::exception &e) {
repertory::utils::error::raise_error(function_name, e,
"exception occurred");
}
};
timeout write_timeout(timeout_handler, std::chrono::milliseconds(
server_handshake_timeout_ms));
auto written = boost::asio::write(
conn->socket, boost::asio::buffer(boost::asio::buffer(buffer)));
write_timeout.disable();
if (written == buffer.size()) {
conn->buffer.resize(to_read);
timeout read_timeout(timeout_handler, std::chrono::milliseconds(
server_handshake_timeout_ms));
std::uint32_t total_read{};
while ((total_read < to_read) && conn->socket.is_open()) {
auto bytes_read = boost::asio::read(
conn->socket,
boost::asio::buffer(&conn->buffer[total_read],
conn->buffer.size() - total_read));
if (bytes_read <= 0) {
throw std::runtime_error("0 bytes read");
}
total_read += static_cast<std::uint32_t>(bytes_read);
}
read_timeout.disable();
if (total_read == to_read) {
packet response(conn->buffer);
if (response.decrypt(encryption_token_) == 0) {
std::string nonce;
if (response.decode(nonce) == 0) {
if (nonce == conn->nonce) {
conn->generate_nonce();
return true;
}
throw std::runtime_error("invalid nonce");
}
throw std::runtime_error("invalid nonce");
}
throw std::runtime_error("decryption failed");
}
throw std::runtime_error("invalid handshake");
} }
throw std::runtime_error("decryption failed"); throw std::runtime_error("failed to send handshake");
} catch (const std::exception &e) { } catch (const std::exception &e) {
repertory::utils::error::raise_error(function_name, e, "handlshake failed"); repertory::utils::error::raise_error(function_name, e, "handlshake failed");
} }
@@ -152,16 +199,10 @@ void packet_server::initialize(const uint16_t &port, uint8_t pool_size) {
} }
void packet_server::listen_for_connection(tcp::acceptor &acceptor) { void packet_server::listen_for_connection(tcp::acceptor &acceptor) {
REPERTORY_USES_FUNCTION_NAME();
auto conn = auto conn =
std::make_shared<packet_server::connection>(io_context_, acceptor); std::make_shared<packet_server::connection>(io_context_, acceptor);
acceptor.async_accept(conn->socket, [this, conn](auto &&err) { acceptor.async_accept(conn->socket, [this, conn](auto &&err) {
try { on_accept(conn, std::forward<decltype(err)>(err));
on_accept(conn, std::forward<decltype(err)>(err));
} catch (const std::exception &e) {
utils::error::raise_error(function_name, e, "exception occurred");
}
}); });
} }
@@ -320,10 +361,9 @@ void packet_server::send_response(std::shared_ptr<connection> conn,
if (err) { if (err) {
remove_client(*conn); remove_client(*conn);
utils::error::raise_error(function_name, err.message()); utils::error::raise_error(function_name, err.message());
return; } else {
read_header(conn);
} }
read_header(conn);
}); });
} }
} // namespace repertory } // namespace repertory