diff --git a/repertory/librepertory/src/comm/packet/common.cpp b/repertory/librepertory/src/comm/packet/common.cpp index 1a4263b7..5e964404 100644 --- a/repertory/librepertory/src/comm/packet/common.cpp +++ b/repertory/librepertory/src/comm/packet/common.cpp @@ -66,15 +66,7 @@ auto is_socket_still_alive(boost::asio::ip::tcp::socket &sock) -> bool { return false; } - if (not err && available == 0) { - return false; - } - - if (not err && available > 0) { - return true; - } - - return false; + return not err; } template diff --git a/repertory/repertory_test/src/packet_client_test.cpp b/repertory/repertory_test/src/packet_client_test.cpp new file mode 100644 index 00000000..3d9016d5 --- /dev/null +++ b/repertory/repertory_test/src/packet_client_test.cpp @@ -0,0 +1,294 @@ +/* + Copyright <2018-2025> + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to do so, subject to the + following conditions: The above copyright notice and this permission notice + shall be included in all copies or substantial portions of the Software. + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ +#include "test_common.hpp" + +#include "comm/packet/common.hpp" +#include "comm/packet/packet.hpp" +#include "comm/packet/packet_client.hpp" +#include "types/remote.hpp" +#include "utils/utils.hpp" +#include "version.hpp" + +using namespace repertory; +using namespace repertory::comm; +using boost::asio::ip::tcp; + +namespace { +void write_all(tcp::socket &sock, const void *data, std::size_t size) { + boost::asio::write(sock, boost::asio::buffer(data, size)); +} + +void read_exact(tcp::socket &sock, void *data, std::size_t size) { + boost::asio::read(sock, boost::asio::buffer(data, size)); +} + +struct test_server final { + std::string encryption_token; + std::atomic port{0}; + std::thread server_thread; + bool send_initial_nonce{false}; + bool respond_to_send{false}; + std::uint32_t response_service_flags{0}; + + explicit test_server(std::string token, bool send_nonce, + bool do_send_response, std::uint32_t svc_flags = 0U) + : encryption_token(std::move(token)), + send_initial_nonce(send_nonce), + respond_to_send(do_send_response), + response_service_flags(svc_flags) {} + + void start() { + std::promise ready; + server_thread = std::thread([this, &ready]() { + try { + boost::asio::io_context io_ctx; + tcp::acceptor acceptor(io_ctx, tcp::endpoint(tcp::v4(), 0)); + port.store(acceptor.local_endpoint().port(), std::memory_order_relaxed); + ready.set_value(); + + tcp::socket sock(io_ctx); + acceptor.accept(sock); + + packet handshake_pkt; + auto min_version = utils::get_version_number(project_get_version()); + handshake_pkt.encode(static_cast(min_version)); + handshake_pkt.encode(static_cast(~min_version)); + handshake_pkt.encode(utils::generate_random_string(packet_nonce_size)); + + data_buffer out; + handshake_pkt.to_buffer(out); + write_all(sock, out.data(), out.size()); + + std::vector echo(out.size()); + if (not echo.empty()) { + read_exact(sock, echo.data(), echo.size()); + } + + std::string last_nonce{}; + const auto generate_response = [&]() -> auto { + last_nonce = utils::generate_random_string(packet_nonce_size); + packet resp; + resp.encode(server_nonce); + resp.encode(response_service_flags); + resp.encode(packet::error_type{}); + resp.encrypt(encryption_token); + return resp; + }; + + if (send_initial_nonce) { + auto resp = generate_response(); + write_all(sock, &resp[0], resp.get_size()); + } + + if (respond_to_send) { + std::uint32_t req_net_len{}; + if (read_for_size(sock, req_net_len)) { + boost::endian::big_to_native_inplace(req_net_len); + + EXPECT_GT(req_net_len, 0); + if (req_net_len > 0) { + data_buffer buffer(req_net_len); + read_exact(sock, buffer.data(), buffer.size()); + + packet response(buffer); + EXPECT_EQ(0, response.decrypt(token)); + + std::string nonce; + EXPECT_EQ(0, response.decode(nonce)); + EXPECT_STREQ(last_nonce.c_str(), nonce.c_str()); + + std::string version; + EXPECT_EQ(0, response.decode(version)); + + std::uint32_t service_flags{}; + EXPECT_EQ(0, response.decode(service_flags)); + + std::string client_id; + EXPECT_EQ(0, response.decode(client_id)); + + std::string thread_id; + EXPECT_EQ(0, response.decode(thread_id)); + + std::string method; + EXPECT_EQ(0, response.decode(method)); + EXPECT_STREQ("ping", method.c_str()); + } + } + + auto resp = generate_response(); + write_all(sock, &resp[0], resp.get_size()); + } + + sock.close(); + } catch (...) { + } + }); + ready.get_future().wait(); + } + + static bool read_for_size(tcp::socket &sock, std::uint32_t &net_len) { + boost::system::error_code err; + auto count = boost::asio::read( + sock, boost::asio::buffer(&net_len, sizeof(net_len)), err); + return not err && count == sizeof(net_len); + } + + void stop() { + if (server_thread.joinable()) { + server_thread.join(); + } + } +}; + +remote::remote_config make_cfg(std::uint16_t port, const std::string &token) { + return remote::remote_config{ + .host_name_or_ip = "127.0.0.1", + .api_port = p, + .max_connections = 2U, + .conn_timeout_ms = 1500U, + .recv_timeout_ms = 1500U, + .send_timeout_ms = 1500U, + .encryption_token = token, + }; +} + +TEST(packet_client_test, can_check_version) { + std::string token = "cow_moose_doge_chicken"; + + test_server srv(token, false, false); + srv.start(); + + packet_client client(make_cfg(srv.port.load(), token)); + + std::uint32_t min_version{}; + auto res = client.check_version( + utils::get_version_number(project_get_version()), min_version); + + EXPECT_EQ(res, api_error::success); + EXPECT_NE(min_version, 0U); + + srv.stop(); +} + +TEST(packet_client_test, can_send_request_and_receive_response) { + std::string token = "cow_moose_doge_chicken"; + std::uint32_t svc_flags_server = 0xA5A5A5A5U; + + test_server srv(token, true, true, svc_flags_server); + srv.start(); + + packet_client client(make_cfg(srv.port.load(), token)); + + std::uint32_t service_flags{}; + packet req; + packet resp; + + auto ret = client.send("ping", req, resp, service_flags); + + EXPECT_EQ(ret, 0); + EXPECT_EQ(service_flags, svc_flags_server); + + srv.stop(); +} + +TEST(packet_client_test, pooled_connection_reused_on_second_send) { + std::string token{"test_token"}; + std::uint16_t port{}; + ASSERT_TRUE(utils::get_next_available_port(50000U, port)); + + std::atomic close_count{0U}; + + packet_server server{ + port, token, 2U, + [&close_count](const std::string & /*client_id*/) { ++close_count; }, + [](std::uint32_t /*service_flags_in*/, const std::string & /*client_id*/, + std::uint64_t /*thread_id*/, const std::string &method, + packet * /*request*/, packet & /*response*/, + packet_server::message_complete_callback done) { + if (method == "ping") { + done(packet::error_type{0}); + } else { + done(packet::error_type{-1}); + } + }}; + + packet_client client(::make_cfg(port, token)); + + std::uint32_t service_flags{}; + packet req_one; + packet resp_one; + auto ret_one = client.send("ping", req_one, resp_one, service_flags); + EXPECT_EQ(ret_one, 0); + + packet req_two; + packet resp_two; + auto ret_two = client.send("ping", req_two, resp_two, service_flags); + EXPECT_EQ(ret_two, 0); + + EXPECT_EQ(close_count.load(), 0U); +} + +TEST(packet_client_test, reconnects_when_server_closes_socket) { + std::string token{"test_token"}; + std::uint16_t port{}; + ASSERT_TRUE(utils::get_next_available_port(50000U, port)); + + std::atomic close_count{0U}; + std::shared_ptr last_conn; + + packet_server server{ + port, token, 2U, + [&close_count](const std::string & /*client_id*/) { ++close_count; }, + [&last_conn](std::uint32_t /*service_flags_in*/, + const std::string & /*client_id*/, + std::uint64_t /*thread_id*/, const std::string &method, + packet * /*request*/, packet & /*response*/, + packet_server::message_complete_callback done) { + if (method == "ping") { + done(packet::error_type{0}); + } else { + done(packet::error_type{-1}); + } + }}; + + packet_client client(::make_cfg(port, token)); + + std::uint32_t service_flags{}; + packet req_one; + packet resp_one; + auto ret_one = client.send("ping", req_one, resp_one, service_flags); + EXPECT_EQ(ret_one, 0); + + { + std::lock_guard guard(server.conn_mutex_); + if (not server.connections_.empty()) { + auto conn = *server.connections_.begin(); + boost::system::error_code ec; + conn->socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + conn->socket().close(ec); + } + } + + packet req_two; + packet resp_two; + auto ret_two = client.send("ping", req_two, resp_two, service_flags); + EXPECT_EQ(ret_two, 0); + + EXPECT_EQ(close_count.load(), 1U); +} +} // namespace