From 9c8e96bc5e904afdddfed389a45f5a41532cf5bb Mon Sep 17 00:00:00 2001 From: "Scott E. Graves" Date: Sun, 21 Sep 2025 11:03:43 -0500 Subject: [PATCH] added handshake for dos protection --- .../include/comm/packet/common.hpp | 25 ++--- .../librepertory/src/comm/packet/common.cpp | 15 +++ .../src/comm/packet/packet_server.cpp | 96 +++++++++++++------ 3 files changed, 91 insertions(+), 45 deletions(-) diff --git a/repertory/librepertory/include/comm/packet/common.hpp b/repertory/librepertory/include/comm/packet/common.hpp index 5638ab58..744758da 100644 --- a/repertory/librepertory/include/comm/packet/common.hpp +++ b/repertory/librepertory/include/comm/packet/common.hpp @@ -25,33 +25,24 @@ #include "utils/common.hpp" namespace repertory::comm { -constexpr const std::uint8_t max_read_attempts{5U}; -constexpr const std::uint16_t packet_nonce_size{256U}; +inline constexpr const std::uint8_t max_read_attempts{5U}; +inline constexpr const std::uint16_t packet_nonce_size{256U}; +inline constexpr const std::uint16_t server_handshake_timeout_ms{3000U}; struct non_blocking_guard final { - boost::asio::ip::tcp::socket &sock; - bool non_blocking{}; - non_blocking_guard(const non_blocking_guard &) = delete; non_blocking_guard(non_blocking_guard &&) = delete; auto operator=(const non_blocking_guard &) -> non_blocking_guard & = delete; auto operator=(non_blocking_guard &&) -> non_blocking_guard & = delete; - explicit non_blocking_guard(boost::asio::ip::tcp::socket &sock_) - : sock(sock_), non_blocking(sock_.non_blocking()) { - boost::system::error_code err; - [[maybe_unused]] auto ret = sock_.non_blocking(true, err); - } + explicit non_blocking_guard(boost::asio::ip::tcp::socket &sock_); - ~non_blocking_guard() { - if (not sock.is_open()) { - return; - } + ~non_blocking_guard(); - boost::system::error_code err; - [[maybe_unused]] auto ret = sock.non_blocking(non_blocking, err); - } +private: + bool non_blocking; + boost::asio::ip::tcp::socket &sock; }; [[nodiscard]] auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock) diff --git a/repertory/librepertory/src/comm/packet/common.cpp b/repertory/librepertory/src/comm/packet/common.cpp index 07f5549b..a1e4481c 100644 --- a/repertory/librepertory/src/comm/packet/common.cpp +++ b/repertory/librepertory/src/comm/packet/common.cpp @@ -25,6 +25,21 @@ #include "events/types/packet_client_timeout.hpp" namespace repertory::comm { +non_blocking_guard::non_blocking_guard(boost::asio::ip::tcp::socket &sock_) + : sock(sock_), non_blocking(sock_.non_blocking()) { + boost::system::error_code err; + [[maybe_unused]] auto ret = sock_.non_blocking(true, err); +} + +non_blocking_guard::~non_blocking_guard() { + if (not sock.is_open()) { + return; + } + + boost::system::error_code err; + [[maybe_unused]] auto ret = sock.non_blocking(non_blocking, err); +} + auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock) -> bool { if (not sock.is_open()) { return false; diff --git a/repertory/librepertory/src/comm/packet/packet_server.cpp b/repertory/librepertory/src/comm/packet/packet_server.cpp index 8ecb0a53..5e97d695 100644 --- a/repertory/librepertory/src/comm/packet/packet_server.cpp +++ b/repertory/librepertory/src/comm/packet/packet_server.cpp @@ -31,11 +31,13 @@ #include "platform/platform.hpp" #include "types/repertory.hpp" #include "utils/error_utils.hpp" +#include "utils/timeout.hpp" using namespace repertory::comm; using std::thread; namespace repertory { + packet_server::packet_server(std::uint16_t port, std::string token, std::uint8_t pool_size, closed_callback closed, message_handler_callback message_handler) @@ -93,28 +95,73 @@ auto packet_server::handshake(std::shared_ptr conn) const -> bool { request.to_buffer(buffer); auto to_read{buffer.size() + utils::encryption::encryption_header_size}; - write_all_with_deadline(io_context_, conn->socket, - boost::asio::buffer(buffer), - std::chrono::milliseconds(3000U)); - - conn->buffer.resize(to_read); - read_exact_with_deadline(io_context_, conn->socket, - boost::asio::buffer(conn->buffer), - std::chrono::milliseconds(3000U)); - packet response(conn->buffer); - if (response.decrypt(encryption_token_) == 0) { - std::string nonce; - if (response.decode(nonce) == 0) { - if (nonce == conn->nonce) { - conn->generate_nonce(); - return true; - } + const auto timeout_handler = [&conn]() { + try { + boost::system::error_code err{}; + [[maybe_unused]] auto ret = conn->socket.cancel(err); + } catch (const std::exception &e) { + repertory::utils::error::raise_error(function_name, e, + "exception occurred"); } - throw std::runtime_error("invalid nonce"); + try { + conn->socket.close(); + } catch (const std::exception &e) { + repertory::utils::error::raise_error(function_name, e, + "exception occurred"); + } + }; + + timeout write_timeout(timeout_handler, std::chrono::milliseconds( + server_handshake_timeout_ms)); + + auto written = boost::asio::write( + conn->socket, boost::asio::buffer(boost::asio::buffer(buffer))); + write_timeout.disable(); + + if (written == buffer.size()) { + conn->buffer.resize(to_read); + + timeout read_timeout(timeout_handler, std::chrono::milliseconds( + server_handshake_timeout_ms)); + + std::uint32_t total_read{}; + while ((total_read < to_read) && conn->socket.is_open()) { + auto bytes_read = boost::asio::read( + conn->socket, + boost::asio::buffer(&conn->buffer[total_read], + conn->buffer.size() - total_read)); + if (bytes_read <= 0) { + throw std::runtime_error("0 bytes read"); + } + + total_read += static_cast(bytes_read); + } + read_timeout.disable(); + + if (total_read == to_read) { + packet response(conn->buffer); + if (response.decrypt(encryption_token_) == 0) { + std::string nonce; + if (response.decode(nonce) == 0) { + if (nonce == conn->nonce) { + conn->generate_nonce(); + return true; + } + + throw std::runtime_error("invalid nonce"); + } + + throw std::runtime_error("invalid nonce"); + } + + throw std::runtime_error("decryption failed"); + } + + throw std::runtime_error("invalid handshake"); } - throw std::runtime_error("decryption failed"); + throw std::runtime_error("failed to send handshake"); } catch (const std::exception &e) { repertory::utils::error::raise_error(function_name, e, "handlshake failed"); } @@ -152,16 +199,10 @@ void packet_server::initialize(const uint16_t &port, uint8_t pool_size) { } void packet_server::listen_for_connection(tcp::acceptor &acceptor) { - REPERTORY_USES_FUNCTION_NAME(); - auto conn = std::make_shared(io_context_, acceptor); acceptor.async_accept(conn->socket, [this, conn](auto &&err) { - try { - on_accept(conn, std::forward(err)); - } catch (const std::exception &e) { - utils::error::raise_error(function_name, e, "exception occurred"); - } + on_accept(conn, std::forward(err)); }); } @@ -320,10 +361,9 @@ void packet_server::send_response(std::shared_ptr conn, if (err) { remove_client(*conn); utils::error::raise_error(function_name, err.message()); - return; + } else { + read_header(conn); } - - read_header(conn); }); } } // namespace repertory