added handshake for dos protection

This commit is contained in:
2025-09-21 11:03:43 -05:00
parent ee03167e43
commit 9c8e96bc5e
3 changed files with 91 additions and 45 deletions

View File

@@ -25,33 +25,24 @@
#include "utils/common.hpp"
namespace repertory::comm {
constexpr const std::uint8_t max_read_attempts{5U};
constexpr const std::uint16_t packet_nonce_size{256U};
inline constexpr const std::uint8_t max_read_attempts{5U};
inline constexpr const std::uint16_t packet_nonce_size{256U};
inline constexpr const std::uint16_t server_handshake_timeout_ms{3000U};
struct non_blocking_guard final {
boost::asio::ip::tcp::socket &sock;
bool non_blocking{};
non_blocking_guard(const non_blocking_guard &) = delete;
non_blocking_guard(non_blocking_guard &&) = delete;
auto operator=(const non_blocking_guard &) -> non_blocking_guard & = delete;
auto operator=(non_blocking_guard &&) -> non_blocking_guard & = delete;
explicit non_blocking_guard(boost::asio::ip::tcp::socket &sock_)
: sock(sock_), non_blocking(sock_.non_blocking()) {
boost::system::error_code err;
[[maybe_unused]] auto ret = sock_.non_blocking(true, err);
}
explicit non_blocking_guard(boost::asio::ip::tcp::socket &sock_);
~non_blocking_guard() {
if (not sock.is_open()) {
return;
}
~non_blocking_guard();
boost::system::error_code err;
[[maybe_unused]] auto ret = sock.non_blocking(non_blocking, err);
}
private:
bool non_blocking;
boost::asio::ip::tcp::socket &sock;
};
[[nodiscard]] auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock)

View File

@@ -25,6 +25,21 @@
#include "events/types/packet_client_timeout.hpp"
namespace repertory::comm {
non_blocking_guard::non_blocking_guard(boost::asio::ip::tcp::socket &sock_)
: sock(sock_), non_blocking(sock_.non_blocking()) {
boost::system::error_code err;
[[maybe_unused]] auto ret = sock_.non_blocking(true, err);
}
non_blocking_guard::~non_blocking_guard() {
if (not sock.is_open()) {
return;
}
boost::system::error_code err;
[[maybe_unused]] auto ret = sock.non_blocking(non_blocking, err);
}
auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock) -> bool {
if (not sock.is_open()) {
return false;

View File

@@ -31,11 +31,13 @@
#include "platform/platform.hpp"
#include "types/repertory.hpp"
#include "utils/error_utils.hpp"
#include "utils/timeout.hpp"
using namespace repertory::comm;
using std::thread;
namespace repertory {
packet_server::packet_server(std::uint16_t port, std::string token,
std::uint8_t pool_size, closed_callback closed,
message_handler_callback message_handler)
@@ -93,14 +95,51 @@ auto packet_server::handshake(std::shared_ptr<connection> conn) const -> bool {
request.to_buffer(buffer);
auto to_read{buffer.size() + utils::encryption::encryption_header_size};
write_all_with_deadline(io_context_, conn->socket,
boost::asio::buffer(buffer),
std::chrono::milliseconds(3000U));
const auto timeout_handler = [&conn]() {
try {
boost::system::error_code err{};
[[maybe_unused]] auto ret = conn->socket.cancel(err);
} catch (const std::exception &e) {
repertory::utils::error::raise_error(function_name, e,
"exception occurred");
}
try {
conn->socket.close();
} catch (const std::exception &e) {
repertory::utils::error::raise_error(function_name, e,
"exception occurred");
}
};
timeout write_timeout(timeout_handler, std::chrono::milliseconds(
server_handshake_timeout_ms));
auto written = boost::asio::write(
conn->socket, boost::asio::buffer(boost::asio::buffer(buffer)));
write_timeout.disable();
if (written == buffer.size()) {
conn->buffer.resize(to_read);
read_exact_with_deadline(io_context_, conn->socket,
boost::asio::buffer(conn->buffer),
std::chrono::milliseconds(3000U));
timeout read_timeout(timeout_handler, std::chrono::milliseconds(
server_handshake_timeout_ms));
std::uint32_t total_read{};
while ((total_read < to_read) && conn->socket.is_open()) {
auto bytes_read = boost::asio::read(
conn->socket,
boost::asio::buffer(&conn->buffer[total_read],
conn->buffer.size() - total_read));
if (bytes_read <= 0) {
throw std::runtime_error("0 bytes read");
}
total_read += static_cast<std::uint32_t>(bytes_read);
}
read_timeout.disable();
if (total_read == to_read) {
packet response(conn->buffer);
if (response.decrypt(encryption_token_) == 0) {
std::string nonce;
@@ -109,12 +148,20 @@ auto packet_server::handshake(std::shared_ptr<connection> conn) const -> bool {
conn->generate_nonce();
return true;
}
throw std::runtime_error("invalid nonce");
}
throw std::runtime_error("invalid nonce");
}
throw std::runtime_error("decryption failed");
}
throw std::runtime_error("invalid handshake");
}
throw std::runtime_error("failed to send handshake");
} catch (const std::exception &e) {
repertory::utils::error::raise_error(function_name, e, "handlshake failed");
}
@@ -152,16 +199,10 @@ void packet_server::initialize(const uint16_t &port, uint8_t pool_size) {
}
void packet_server::listen_for_connection(tcp::acceptor &acceptor) {
REPERTORY_USES_FUNCTION_NAME();
auto conn =
std::make_shared<packet_server::connection>(io_context_, acceptor);
acceptor.async_accept(conn->socket, [this, conn](auto &&err) {
try {
on_accept(conn, std::forward<decltype(err)>(err));
} catch (const std::exception &e) {
utils::error::raise_error(function_name, e, "exception occurred");
}
});
}
@@ -320,10 +361,9 @@ void packet_server::send_response(std::shared_ptr<connection> conn,
if (err) {
remove_client(*conn);
utils::error::raise_error(function_name, err.message());
return;
}
} else {
read_header(conn);
}
});
}
} // namespace repertory