diff --git a/repertory/librepertory/include/comm/packet/packet_client.hpp b/repertory/librepertory/include/comm/packet/packet_client.hpp index 9daca189..3021bc30 100644 --- a/repertory/librepertory/include/comm/packet/packet_client.hpp +++ b/repertory/librepertory/include/comm/packet/packet_client.hpp @@ -46,13 +46,12 @@ public: private: remote::remote_config cfg_; mutable boost::asio::io_context io_context_; - atomic unique_id_; - boost::asio::executor_work_guard - work_guard_; + utils::atomic unique_id_; private: std::atomic allow_connections_{true}; - atomic::results_type> + utils::atomic< + boost::asio::ip::basic_resolver::results_type> resolve_results_; std::mutex clients_mutex_; std::vector> clients_; diff --git a/repertory/librepertory/src/comm/packet/packet_client.cpp b/repertory/librepertory/src/comm/packet/packet_client.cpp index 1fc7009b..aecd9a3c 100644 --- a/repertory/librepertory/src/comm/packet/packet_client.cpp +++ b/repertory/librepertory/src/comm/packet/packet_client.cpp @@ -31,10 +31,7 @@ using namespace repertory::comm; namespace repertory { packet_client::packet_client(remote::remote_config cfg) - : cfg_(std::move(cfg)), - io_context_(), - unique_id_(utils::create_uuid_string()), - work_guard_(boost::asio::make_work_guard(io_context_)) { + : cfg_(std::move(cfg)), unique_id_(utils::create_uuid_string()) { for (std::uint8_t idx = 0U; idx < cfg.max_connections; ++idx) { service_threads_.emplace_back([this]() { io_context_.run(); }); } @@ -50,7 +47,6 @@ packet_client::~packet_client() { } catch (...) { } - work_guard_.reset(); io_context_.stop(); for (auto &thread : service_threads_) { @@ -62,9 +58,10 @@ packet_client::~packet_client() { void packet_client::close(client &cli) noexcept { boost::system::error_code err1; - cli.socket.shutdown(boost::asio::socket_base::shutdown_both, err1); + auto res = cli.socket.shutdown(boost::asio::socket_base::shutdown_both, err1); + boost::system::error_code err2; - cli.socket.close(err2); + res = cli.socket.close(err2); } void packet_client::close_all() { @@ -73,7 +70,7 @@ void packet_client::close_all() { close(*cli); } clients_.clear(); - resolve_results_ = {}; + resolve_results_.store({}); unique_id_ = utils::create_uuid_string(); } @@ -100,11 +97,6 @@ auto packet_client::check_version(std::uint32_t client_version, return api_error::comm_error; } - if ((min_version & 0xFFU) != 0U) { - min_version = 0U; - return api_error::incompatible_version; - } - if (client_version < min_version) { return api_error::incompatible_version; } @@ -121,8 +113,7 @@ auto packet_client::connect(client &cli) -> bool { try { resolve(); - boost::asio::ip::basic_resolver::results_type cached = - resolve_results_; + auto cached = resolve_results_.load(); connect_with_deadline(io_context_, cli.socket, cached, std::chrono::milliseconds(cfg_.conn_timeout_ms)); @@ -144,7 +135,7 @@ auto packet_client::connect(client &cli) -> bool { return true; } catch (...) { close(cli); - resolve_results_ = {}; + resolve_results_.store({}); throw; } } @@ -162,7 +153,7 @@ auto packet_client::get_client() -> std::shared_ptr { clients_lock.unlock(); auto cli = std::make_shared(io_context_); - if (!connect(*cli)) { + if (not connect(*cli)) { return nullptr; } return cli; diff --git a/repertory/librepertory/src/comm/packet/packet_server.cpp b/repertory/librepertory/src/comm/packet/packet_server.cpp index b1ab3d39..5ff79102 100644 --- a/repertory/librepertory/src/comm/packet/packet_server.cpp +++ b/repertory/librepertory/src/comm/packet/packet_server.cpp @@ -146,17 +146,26 @@ auto packet_server::handshake(std::shared_ptr conn) const -> bool { if (response.decrypt(encryption_token_) == 0) { std::uint32_t client_version{}; if (response.decode(client_version) == 0) { - std::string nonce; - if (response.decode(nonce) == 0) { - if (nonce == conn->nonce) { - conn->generate_nonce(); - return true; + std::uint32_t client_version_check{}; + if (response.decode(client_version_check) == 0) { + if (~client_version != client_version_check) { + throw std::runtime_error("client version check failed"); } - throw std::runtime_error("nonce mismatch"); + std::string nonce; + if (response.decode(nonce) == 0) { + if (nonce == conn->nonce) { + conn->generate_nonce(); + return true; + } + + throw std::runtime_error("nonce mismatch"); + } + + throw std::runtime_error("invalid nonce"); } - throw std::runtime_error("invalid nonce"); + throw std::runtime_error("invalid client version"); } throw std::runtime_error("invalid client version");