From edc4465d153e2fb2c1038f296b83028bb0fc8ada Mon Sep 17 00:00:00 2001 From: "Scott E. Graves" Date: Sat, 30 Aug 2025 19:13:54 -0500 Subject: [PATCH] Implement secure key via KDF for transparent data encryption/decryption #60 --- .cspell/words.txt | 2 + .../src/providers/s3/s3_provider.cpp | 27 +- support/include/utils/base64.hpp | 341 ++++++++++----- support/src/utils/encrypting_reader.cpp | 18 +- support/src/utils/encryption.cpp | 11 +- support/test/src/utils/base64_test.cpp | 389 ++++++++++++++++++ 6 files changed, 666 insertions(+), 122 deletions(-) create mode 100644 support/test/src/utils/base64_test.cpp diff --git a/.cspell/words.txt b/.cspell/words.txt index fdfaa4ee..ea9944f4 100644 --- a/.cspell/words.txt +++ b/.cspell/words.txt @@ -111,6 +111,8 @@ flac_version flag_nopath flarge fontconfig_version +foob +fooba freetype2_version fsetattr_x fusermount diff --git a/repertory/librepertory/src/providers/s3/s3_provider.cpp b/repertory/librepertory/src/providers/s3/s3_provider.cpp index 666a51dc..b3365573 100644 --- a/repertory/librepertory/src/providers/s3/s3_provider.cpp +++ b/repertory/librepertory/src/providers/s3/s3_provider.cpp @@ -32,6 +32,7 @@ #include "file_manager/i_file_manager.hpp" #include "types/repertory.hpp" #include "types/s3.hpp" +#include "utils/base64.hpp" #include "utils/collection.hpp" #include "utils/common.hpp" #include "utils/config.hpp" @@ -1055,13 +1056,13 @@ auto s3_provider::search_keys_for_master_kdf( continue; } - data_buffer buffer; - if (not utils::collection::from_hex_string(object_name, buffer)) { - continue; - } - - if (not utils::encryption::kdf_config::from_header(buffer, - master_kdf_cfg_)) { + try { + auto buffer = macaron::Base64::Decode(object_name); + if (not utils::encryption::kdf_config::from_header(buffer, + master_kdf_cfg_)) { + continue; + } + } catch (...) { continue; } @@ -1127,11 +1128,13 @@ auto s3_provider::set_meta_key(const std::string &api_path, api_meta_map &meta) result.insert(result.begin(), hdr.begin(), hdr.end()); } - meta[META_KEY] = utils::path::create_api_path( - utils::path::combine(utils::path::create_api_path(encrypted_parent_path), - { - utils::collection::to_hex_string(result), - })); + meta[META_KEY] = utils::path::create_api_path(utils::path::combine( + utils::path::create_api_path(encrypted_parent_path), + { + legacy_bucket_ + ? utils::collection::to_hex_string(result) + : macaron::Base64::EncodeUrlSafe(result.data(), result.size()), + })); return api_error::success; } diff --git a/support/include/utils/base64.hpp b/support/include/utils/base64.hpp index 2c63569a..02e07c30 100644 --- a/support/include/utils/base64.hpp +++ b/support/include/utils/base64.hpp @@ -1,10 +1,11 @@ // NOLINTBEGIN -#ifndef _MACARON_BASE64_H_ -#define _MACARON_BASE64_H_ +#ifndef MACARON_BASE64_H_ +#define MACARON_BASE64_H_ /** * The MIT License (MIT) * Copyright (c) 2016 tomykaira + * Copyright (c) 2025 scott.e.graves@protonmail.com * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the @@ -39,121 +40,272 @@ #endif #include +#include #include +#include #include namespace macaron::Base64 { -static std::string Encode(const unsigned char *data, std::size_t len) { - static constexpr std::array sEncodingTable{ - 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', - 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', - 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/', - }; - auto in_len{len}; - std::string ret; - if (in_len > 0) { - std::size_t out_len{4U * ((in_len + 2U) / 3U)}; - ret = std::string(out_len, '\0'); - std::size_t i; - auto *p = reinterpret_cast(ret.data()); +// --- Alphabets -------------------------------------------------------------- - for (i = 0U; i < in_len - 2U; i += 3U) { - *p++ = sEncodingTable[(data[i] >> 2U) & 0x3F]; - *p++ = sEncodingTable[((data[i] & 0x3) << 4U) | - ((int)(data[i + 1U] & 0xF0) >> 4U)]; - *p++ = sEncodingTable[((data[i + 1] & 0xF) << 2) | - ((int)(data[i + 2U] & 0xC0) >> 6U)]; - *p++ = sEncodingTable[data[i + 2U] & 0x3F]; +static constexpr std::array kStdAlphabet{ + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/', +}; + +static constexpr std::array kUrlAlphabet{ + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_', +}; + +// Decoding table that accepts BOTH standard and URL-safe alphabets. +static constexpr std::array kDecodingTable = [] { + std::array t{}; + t.fill(64U); + // 'A'-'Z' + for (unsigned char c = 'A'; c <= 'Z'; ++c) + t[c] = static_cast(c - 'A'); + // 'a'-'z' + for (unsigned char c = 'a'; c <= 'z'; ++c) + t[c] = static_cast(26 + c - 'a'); + // '0'-'9' + for (unsigned char c = '0'; c <= '9'; ++c) + t[c] = static_cast(52 + c - '0'); + // Standard extras + t[static_cast('+')] = 62U; + t[static_cast('/')] = 63U; + // URL-safe extras + t[static_cast('-')] = 62U; + t[static_cast('_')] = 63U; + return t; +}(); + +// --- Encoding --------------------------------------------------------------- + +/** + * Encode to Base64. + * @param data pointer to bytes + * @param len number of bytes + * @param url_safe if true, use URL-safe alphabet ("-","_") instead of ("+","/") + * @param pad if true, add '=' padding; if false, omit padding (RFC 4648 + * §5) + */ +static std::string Encode(const unsigned char *data, std::size_t len, + bool url_safe = false, bool pad = true) { + const auto &alpha = url_safe ? kUrlAlphabet : kStdAlphabet; + + std::string out; + if (len == 0U) { + return out; + } + + const std::size_t full_blocks = len / 3U; + const std::size_t rem = len % 3U; + + std::size_t out_len{}; + if (pad) { + out_len = 4U * ((len + 2U) / 3U); + } else { + // Unpadded length per RFC 4648 §5 + out_len = 4U * full_blocks + (rem == 0U ? 0U : (rem == 1U ? 2U : 3U)); + } + out.assign(out_len, '\0'); + + auto *p = reinterpret_cast(out.data()); + std::size_t i = 0; + + // Full 3-byte blocks -> 4 chars + for (; i + 2U < len; i += 3U) { + const unsigned char b0 = data[i + 0U]; + const unsigned char b1 = data[i + 1U]; + const unsigned char b2 = data[i + 2U]; + + *p++ = alpha[(b0 >> 2U) & 0x3F]; + *p++ = alpha[((b0 & 0x03U) << 4U) | ((b1 >> 4U) & 0x0FU)]; + *p++ = alpha[((b1 & 0x0FU) << 2U) | ((b2 >> 6U) & 0x03U)]; + *p++ = alpha[b2 & 0x3FU]; + } + + // Remainder + if (rem == 1U) { + const unsigned char b0 = data[i]; + *p++ = alpha[(b0 >> 2U) & 0x3F]; + *p++ = alpha[(b0 & 0x03U) << 4U]; + if (pad) { + *p++ = '='; + *p++ = '='; } - if (i < in_len) { - *p++ = sEncodingTable[(data[i] >> 2U) & 0x3F]; - if (i == (in_len - 1U)) { - *p++ = sEncodingTable[((data[i] & 0x3) << 4U)]; - *p++ = '='; - } else { - *p++ = sEncodingTable[((data[i] & 0x3) << 4U) | - ((int)(data[i + 1U] & 0xF0) >> 4U)]; - *p++ = sEncodingTable[((data[i + 1U] & 0xF) << 2U)]; - } + } else if (rem == 2U) { + const unsigned char b0 = data[i + 0U]; + const unsigned char b1 = data[i + 1U]; + *p++ = alpha[(b0 >> 2U) & 0x3F]; + *p++ = alpha[((b0 & 0x03U) << 4U) | ((b1 >> 4U) & 0x0FU)]; + *p++ = alpha[(b1 & 0x0FU) << 2U]; + if (pad) { *p++ = '='; } } - return ret; + return out; } -[[maybe_unused]] static std::string Encode(std::string_view data) { +[[maybe_unused]] static std::string +Encode(std::string_view data, bool url_safe = false, bool pad = true) { return Encode(reinterpret_cast(data.data()), - data.size()); + data.size(), url_safe, pad); } +[[maybe_unused]] static std::string +EncodeUrlSafe(const unsigned char *data, std::size_t len, bool pad = false) { + return Encode(data, len, /*url_safe=*/true, /*pad=*/pad); +} + +[[maybe_unused]] static std::string EncodeUrlSafe(std::string_view data, + bool pad = false) { + return Encode(reinterpret_cast(data.data()), + data.size(), /*url_safe=*/true, /*pad=*/pad); +} + +// --- Decoding --------------------------------------------------------------- + +/** + * Decode standard OR URL-safe Base64. + * Accepts inputs with or without '=' padding. + * Throws std::runtime_error on malformed input. + */ [[maybe_unused]] static std::vector Decode(std::string_view input) { - static constexpr std::array kDecodingTable{ - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63, 52, 53, 54, 55, 56, 57, - 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, 64, 0, 1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, - 25, 64, 64, 64, 64, 64, 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, + std::vector out; + if (input.empty()) { + return out; + } + + std::size_t inLen = input.size(); + std::size_t rem = inLen % 4U; + + // padded if multiple of 4 and last char is '=' + bool hasPadding = (rem == 0U) && (inLen >= 4U) && (input[inLen - 1U] == '='); + + // compute output length + std::size_t outLen{}; + if (hasPadding) { + outLen = (inLen / 4U) * 3U; + if (input[inLen - 1U] == '=') + outLen--; + if (input[inLen - 2U] == '=') + outLen--; + } else { + if (rem == 1U) { + throw std::runtime_error("Invalid Base64 length (mod 4 == 1)"); + } + outLen = (inLen / 4U) * 3U + (rem == 0U ? 0U : (rem == 2U ? 1U : 2U)); + } + + out.resize(outLen); + + auto readVal = [](unsigned char c) -> unsigned char { + unsigned char v = kDecodingTable[c]; + if (v == 64U) { + throw std::runtime_error("Invalid Base64 character"); + } + return v; }; - std::vector out; - if (not input.empty()) { - auto in_len{input.size()}; - if (in_len % 4U != 0U) { - throw std::runtime_error("Input data size is not a multiple of 4"); - } + std::size_t i = 0U; + std::size_t j = 0U; - std::size_t out_len{in_len / 4U * 3U}; - if (input[in_len - 1U] == '=') { - out_len--; - } - if (input[in_len - 2U] == '=') { - out_len--; - } + // process all full unpadded quartets + std::size_t lastFull = + hasPadding ? (inLen - 4U) : (rem == 0U ? inLen : (inLen - rem)); - out.resize(out_len); + while (i + 4U <= lastFull) { + unsigned char a = readVal(static_cast(input[i + 0U])); + unsigned char b = readVal(static_cast(input[i + 1U])); + unsigned char c = readVal(static_cast(input[i + 2U])); + unsigned char d = readVal(static_cast(input[i + 3U])); + i += 4U; - for (std::size_t i = 0U, j = 0U; i < in_len;) { - std::uint32_t a = - input.at(i) == '=' - ? 0U & i++ - : kDecodingTable[static_cast(input.at(i++))]; - std::uint32_t b = - input.at(i) == '=' - ? 0U & i++ - : kDecodingTable[static_cast(input.at(i++))]; - std::uint32_t c = - input.at(i) == '=' - ? 0U & i++ - : kDecodingTable[static_cast(input.at(i++))]; - std::uint32_t d = - input.at(i) == '=' - ? 0U & i++ - : kDecodingTable[static_cast(input.at(i++))]; + std::uint32_t triple = (static_cast(a) << 18U) | + (static_cast(b) << 12U) | + (static_cast(c) << 6U) | + (static_cast(d)); - std::uint32_t triple = - (a << 3U * 6U) + (b << 2U * 6U) + (c << 1U * 6U) + (d << 0U * 6U); + if (j < outLen) + out[j++] = static_cast((triple >> 16U) & 0xFFU); + if (j < outLen) + out[j++] = static_cast((triple >> 8U) & 0xFFU); + if (j < outLen) + out[j++] = static_cast(triple & 0xFFU); + } - if (j < out_len) - out[j++] = (triple >> 2U * 8U) & 0xFF; - if (j < out_len) - out[j++] = (triple >> 1U * 8U) & 0xFF; - if (j < out_len) - out[j++] = (triple >> 0U * 8U) & 0xFF; + // tail: padded quartet or unpadded remainder + if (i < inLen) { + std::size_t left = inLen - i; + + if (left == 4U) { + bool thirdIsPad = (input[i + 2U] == '='); + bool fourthIsPad = (input[i + 3U] == '='); + + // '=' is never allowed in positions 1 or 2 of any quartet + if (input[i + 0U] == '=' || input[i + 1U] == '=') { + throw std::runtime_error("Invalid Base64 padding placement"); + } + + unsigned char a = readVal(static_cast(input[i + 0U])); + unsigned char b = readVal(static_cast(input[i + 1U])); + unsigned char c = 0U; + unsigned char d = 0U; + + if (!thirdIsPad) { + c = readVal(static_cast(input[i + 2U])); + if (!fourthIsPad) { + d = readVal(static_cast(input[i + 3U])); + } + } else { + // if the 3rd is '=', the 4th must also be '=' + if (!fourthIsPad) { + throw std::runtime_error("Invalid Base64 padding placement"); + } + } + i += 4U; + + std::uint32_t triple = (static_cast(a) << 18U) | + (static_cast(b) << 12U) | + (static_cast(c) << 6U) | + (static_cast(d)); + + if (j < outLen) + out[j++] = static_cast((triple >> 16U) & 0xFFU); + if (!thirdIsPad && j < outLen) + out[j++] = static_cast((triple >> 8U) & 0xFFU); + if (!fourthIsPad && !thirdIsPad && j < outLen) + out[j++] = static_cast(triple & 0xFFU); + + } else if (left == 2U || left == 3U) { + unsigned char a = readVal(static_cast(input[i + 0U])); + unsigned char b = readVal(static_cast(input[i + 1U])); + unsigned char c = (left == 3U) + ? readVal(static_cast(input[i + 2U])) + : 0U; + i += left; + + std::uint32_t triple = (static_cast(a) << 18U) | + (static_cast(b) << 12U) | + (static_cast(c) << 6U); + + if (j < outLen) + out[j++] = static_cast((triple >> 16U) & 0xFFU); + if (left == 3U && j < outLen) + out[j++] = static_cast((triple >> 8U) & 0xFFU); + } else { + throw std::runtime_error("Invalid Base64 length (mod 4 == 1)"); } } @@ -169,6 +321,5 @@ Decode(std::string_view input) { #pragma clang diagnostic pop #endif -#endif /* _MACARON_BASE64_H_ */ - -// NOLINTEND +#endif /* MACARON_BASE64_H_ */ + // NOLINTEND diff --git a/support/src/utils/encrypting_reader.cpp b/support/src/utils/encrypting_reader.cpp index 4d043561..8b389ac9 100644 --- a/support/src/utils/encrypting_reader.cpp +++ b/support/src/utils/encrypting_reader.cpp @@ -23,6 +23,7 @@ #include "utils/encrypting_reader.hpp" +#include "utils/base64.hpp" #include "utils/collection.hpp" #include "utils/common.hpp" #include "utils/config.hpp" @@ -444,11 +445,7 @@ void encrypting_reader::common_initialize_kdf_path( const utils::hash::hash_256_t &master_key) { REPERTORY_USES_FUNCTION_NAME(); - data_buffer buffer; - if (not utils::collection::from_hex_string(encrypted_file_path_, buffer)) { - throw utils::error::create_exception( - function_name, {"failed to convert encrypted path from hex to bytes"}); - } + auto buffer = macaron::Base64::Decode(encrypted_file_path_); kdf_config path_cfg; if (not kdf_config::from_header(buffer, path_cfg)) { @@ -475,7 +472,10 @@ void encrypting_reader::create_encrypted_paths( kdf_headers_->second.end()); } - encrypted_file_name_ = utils::collection::to_hex_string(result); + encrypted_file_name_ = + kdf_headers_.has_value() + ? macaron::Base64::EncodeUrlSafe(result.data(), result.size()) + : utils::collection::to_hex_string(result); if (not relative_parent_path.has_value()) { return; @@ -492,7 +492,11 @@ void encrypting_reader::create_encrypted_paths( kdf_headers_->second.end()); } - encrypted_file_path_ += '/' + utils::collection::to_hex_string(result); + encrypted_file_path_ += + '/' + + (kdf_headers_.has_value() + ? macaron::Base64::EncodeUrlSafe(result.data(), result.size()) + : utils::collection::to_hex_string(result)); } encrypted_file_path_ += '/' + encrypted_file_name_; diff --git a/support/src/utils/encryption.cpp b/support/src/utils/encryption.cpp index 18087155..36e9e0d9 100644 --- a/support/src/utils/encryption.cpp +++ b/support/src/utils/encryption.cpp @@ -23,6 +23,7 @@ #include "utils/encryption.hpp" +#include "utils/base64.hpp" #include "utils/collection.hpp" #include "utils/encrypting_reader.hpp" #include "utils/hash.hpp" @@ -97,10 +98,7 @@ auto decrypt_file_name(std::string_view encryption_token, auto decrypt_file_name(std::string_view encryption_token, const kdf_config &cfg, std::string &file_name) -> bool { - data_buffer buffer; - if (not utils::collection::from_hex_string(file_name, buffer)) { - return false; - } + auto buffer = macaron::Base64::Decode(file_name); file_name.clear(); return utils::encryption::decrypt_data(encryption_token, cfg, buffer, @@ -109,10 +107,7 @@ auto decrypt_file_name(std::string_view encryption_token, const kdf_config &cfg, auto decrypt_file_name(const utils::hash::hash_256_t &master_key, std::string &file_name) -> bool { - data_buffer buffer; - if (not utils::collection::from_hex_string(file_name, buffer)) { - return false; - } + auto buffer = macaron::Base64::Decode(file_name); utils::encryption::kdf_config path_cfg; if (not utils::encryption::kdf_config::from_header(buffer, path_cfg)) { diff --git a/support/test/src/utils/base64_test.cpp b/support/test/src/utils/base64_test.cpp new file mode 100644 index 00000000..4a973e93 --- /dev/null +++ b/support/test/src/utils/base64_test.cpp @@ -0,0 +1,389 @@ +/* + Copyright <2018-2025> + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ +#include "test.hpp" + +using macaron::Base64::Decode; +using macaron::Base64::Encode; +using macaron::Base64::EncodeUrlSafe; + +namespace { +[[nodiscard]] auto decode_to_string(std::string_view str) -> std::string { + auto vec = Decode(str); + return {vec.begin(), vec.end()}; +} + +[[nodiscard]] auto standard_to_url_safe(std::string str, bool keep_padding) + -> std::string { + for (auto &cur_ch : str) { + if (cur_ch == '+') { + cur_ch = '-'; + } else if (cur_ch == '/') { + cur_ch = '_'; + } + } + if (not keep_padding) { + while (not str.empty() && str.back() == '=') { + str.pop_back(); + } + } + return str; +} +} // namespace + +TEST(utils_base64, rfc4648_known_vectors_standard_padded) { + struct vec_case { + std::string_view in; + std::string_view b64; + }; + const std::array vectors{{ + {"", ""}, + {"f", "Zg=="}, + {"fo", "Zm8="}, + {"foo", "Zm9v"}, + {"foob", "Zm9vYg=="}, + {"fooba", "Zm9vYmE="}, + {"foobar", "Zm9vYmFy"}, + }}; + + for (const auto &vec_entry : vectors) { + const auto enc_str = + Encode(reinterpret_cast(vec_entry.in.data()), + vec_entry.in.size(), /*url_safe=*/false, /*pad=*/true); + EXPECT_EQ(enc_str, vec_entry.b64); + + const auto dec_vec = Decode(vec_entry.b64); + EXPECT_EQ(std::string(dec_vec.begin(), dec_vec.end()), vec_entry.in); + } +} + +TEST(utils_base64, url_safe_padded_and_unpadded_match_transformed_standard) { + const std::string payload = + std::string("This+/needs/URL-safe mapping and padding checks.") + + std::string("\x00\x01\xFE\xFF", 4); + + const auto std_padded = + Encode(reinterpret_cast(payload.data()), + payload.size(), /*url_safe=*/false, /*pad=*/true); + const auto url_padded = + Encode(reinterpret_cast(payload.data()), + payload.size(), /*url_safe=*/true, /*pad=*/true); + const auto url_unpadded = + Encode(reinterpret_cast(payload.data()), + payload.size(), /*url_safe=*/true, /*pad=*/false); + + const auto url_from_std_padded = + standard_to_url_safe(std_padded, /*keep_padding=*/true); + const auto url_from_std_unpadded = + standard_to_url_safe(std_padded, /*keep_padding=*/false); + + EXPECT_EQ(url_padded, url_from_std_padded); + EXPECT_EQ(url_unpadded, url_from_std_unpadded); + + const auto dec_one = Decode(url_padded); + const auto dec_two = Decode(url_unpadded); + EXPECT_EQ(std::string(dec_one.begin(), dec_one.end()), payload); + EXPECT_EQ(std::string(dec_two.begin(), dec_two.end()), payload); +} + +TEST(utils_base64, empty_input) { + const std::string empty_str; + const auto enc_empty_std = + Encode(reinterpret_cast(empty_str.data()), + empty_str.size(), /*url_safe=*/false, /*pad=*/true); + const auto enc_empty_url = + Encode(reinterpret_cast(empty_str.data()), + empty_str.size(), /*url_safe=*/true, /*pad=*/false); + EXPECT_TRUE(enc_empty_std.empty()); + EXPECT_TRUE(enc_empty_url.empty()); + + const auto dec_empty = Decode(""); + EXPECT_TRUE(dec_empty.empty()); +} + +TEST(utils_base64, remainder_boundaries_round_trip) { + const std::string str_one = "A"; // rem 1 + const std::string str_two = "AB"; // rem 2 + const std::string str_thr = "ABC"; // rem 0 + const std::string str_fou = "ABCD"; // rem 1 after blocks + const std::string str_fiv = "ABCDE"; // rem 2 after blocks + + for (const auto *str_ptr : + {&str_one, &str_two, &str_thr, &str_fou, &str_fiv}) { + const auto enc_std = + Encode(reinterpret_cast(str_ptr->data()), + str_ptr->size(), false, true); + const auto dec_std = Decode(enc_std); + EXPECT_EQ(std::string(dec_std.begin(), dec_std.end()), *str_ptr); + + const auto enc_url_pad = + Encode(reinterpret_cast(str_ptr->data()), + str_ptr->size(), true, true); + const auto dec_url_pad = Decode(enc_url_pad); + EXPECT_EQ(std::string(dec_url_pad.begin(), dec_url_pad.end()), *str_ptr); + + const auto enc_url_nopad = + Encode(reinterpret_cast(str_ptr->data()), + str_ptr->size(), true, false); + const auto dec_url_nopad = Decode(enc_url_nopad); + EXPECT_EQ(std::string(dec_url_nopad.begin(), dec_url_nopad.end()), + *str_ptr); + } +} + +TEST(utils_base64, decode_accepts_standard_and_url_safe_forms) { + const std::string input_str = "Man is distinguished, not only by his reason."; + const auto std_padded = + Encode(reinterpret_cast(input_str.data()), + input_str.size(), false, true); + const auto url_padded = + Encode(reinterpret_cast(input_str.data()), + input_str.size(), true, true); + const auto url_unpadded = + Encode(reinterpret_cast(input_str.data()), + input_str.size(), true, false); + + const auto dec_std = Decode(std_padded); + const auto dec_url_pad = Decode(url_padded); + const auto dec_url_nopad = Decode(url_unpadded); + + EXPECT_EQ(std::string(dec_std.begin(), dec_std.end()), input_str); + EXPECT_EQ(std::string(dec_url_pad.begin(), dec_url_pad.end()), input_str); + EXPECT_EQ(std::string(dec_url_nopad.begin(), dec_url_nopad.end()), input_str); +} + +TEST(utils_base64, all_byte_values_round_trip) { + std::vector byte_vec(256); + for (size_t idx = 0; idx < byte_vec.size(); ++idx) { + byte_vec[idx] = static_cast(idx); + } + + const auto enc_std = Encode(byte_vec.data(), byte_vec.size(), false, true); + const auto dec_std = Decode(enc_std); + ASSERT_EQ(dec_std.size(), byte_vec.size()); + EXPECT_TRUE(std::equal(dec_std.begin(), dec_std.end(), byte_vec.begin())); + + const auto enc_url = Encode(byte_vec.data(), byte_vec.size(), true, false); + const auto dec_url = Decode(enc_url); + ASSERT_EQ(dec_url.size(), byte_vec.size()); + EXPECT_TRUE(std::equal(dec_url.begin(), dec_url.end(), byte_vec.begin())); +} + +TEST(utils_base64, wrapper_encode_url_safe_equals_flagged_encode) { + const std::string data_str = "wrap me!"; + const auto enc_wrap_a = + EncodeUrlSafe(reinterpret_cast(data_str.data()), + data_str.size(), /*pad=*/false); + const auto enc_wrap_b = + Encode(reinterpret_cast(data_str.data()), + data_str.size(), /*url_safe=*/true, /*pad=*/false); + EXPECT_EQ(enc_wrap_a, enc_wrap_b); + + const auto enc_wrap_a2 = EncodeUrlSafe(data_str, /*pad=*/true); + const auto enc_wrap_b2 = Encode(data_str, /*url_safe=*/true, /*pad=*/true); + EXPECT_EQ(enc_wrap_a2, enc_wrap_b2); +} + +TEST(utils_base64, unpadded_length_rules) { + const auto enc_one = Encode("f", /*url_safe=*/true, /*pad=*/false); + const auto enc_two = Encode("fo", /*url_safe=*/true, /*pad=*/false); + const auto enc_thr = Encode("foo", /*url_safe=*/true, /*pad=*/false); + EXPECT_EQ(enc_one.size(), 2U); + EXPECT_EQ(enc_two.size(), 3U); + EXPECT_EQ(enc_thr.size(), 4U); + + EXPECT_EQ(Decode(enc_one), std::vector({'f'})); + EXPECT_EQ(Decode(enc_two), std::vector({'f', 'o'})); + EXPECT_EQ(Decode(enc_thr), std::vector({'f', 'o', 'o'})); +} + +TEST(utils_base64, errors_length_mod4_eq_1) { + EXPECT_THROW(Decode("A"), std::runtime_error); + EXPECT_THROW(Decode("AAAAA"), std::runtime_error); +} + +TEST(utils_base64, errors_invalid_characters) { + EXPECT_THROW(Decode("Zm9v YmFy"), std::runtime_error); + EXPECT_THROW(Decode("Zm9v*YmFy"), std::runtime_error); + EXPECT_THROW(Decode("Z=g="), std::runtime_error); +} + +TEST(utils_base64, reject_whitespace_and_controls) { + // newline, tab, and space should be rejected (decoder does not skip + // whitespace) + EXPECT_THROW(Decode("Zg==\n"), std::runtime_error); + EXPECT_THROW(Decode("Zg==\t"), std::runtime_error); + EXPECT_THROW(Decode("Z g=="), std::runtime_error); +} + +TEST(utils_base64, reject_padding_in_nonfinal_quartet) { + // '=' cannot appear before the final quartet + EXPECT_THROW(Decode("AAA=AAAA"), std::runtime_error); + EXPECT_THROW(Decode("Zg==Zg=="), std::runtime_error); +} + +TEST(utils_base64, reject_padding_in_first_two_slots_of_final_quartet) { + // '=' only allowed in slots 3 and/or 4 of the final quartet + EXPECT_THROW(Decode("=AAA"), std::runtime_error); + EXPECT_THROW(Decode("A=AA"), std::runtime_error); + EXPECT_THROW( + Decode("Z=g="), + std::runtime_error); // already in your suite, kept for completeness +} + +TEST(utils_base64, reject_incorrect_padding_count_for_length) { + // "f" must be "Zg==" (two '='). One '=' is invalid. + EXPECT_THROW(Decode("Zg="), std::runtime_error); + + // "foo" must be unpadded ("Zm9v"). Extra '=' is invalid. + EXPECT_THROW(Decode("Zm9v="), std::runtime_error); + + // "fo" must have exactly one '=' -> "Zm8=" + // Correct cases: + EXPECT_NO_THROW(Decode("Zm8=")); + EXPECT_NO_THROW(Decode("Zm9v")); +} + +TEST(utils_base64, accept_unpadded_equivalents_when_legal) { + EXPECT_EQ(decode_to_string("Zg"), "f"); + EXPECT_EQ(decode_to_string("Zm8"), "fo"); + EXPECT_EQ(decode_to_string("Zm9v"), "foo"); + EXPECT_EQ(decode_to_string("Zm9vYmE"), "fooba"); +} + +TEST(utils_base64, mixed_alphabet_is_accepted) { + const std::string input_str = "any+/mix_/of+chars/"; + const auto std_padded = + Encode(reinterpret_cast(input_str.data()), + input_str.size(), /*url_safe=*/false, /*pad=*/true); + + std::string mixed = std_padded; + for (char &cur_ch : mixed) { + if (cur_ch == '+') { + cur_ch = '-'; + } else if (cur_ch == '/') { + cur_ch = '_'; + } + } + + EXPECT_EQ(decode_to_string(mixed), input_str); +} + +TEST(utils_base64, invalid_non_ascii_octets_in_input) { + // Extended bytes like 0xFF are not valid Base64 characters + std::string bad = "Zg=="; + bad[1] = static_cast(0xFF); + EXPECT_THROW(Decode(bad), std::runtime_error); +} + +TEST(utils_base64, large_buffer_round_trip_and_sizes) { + // Deterministic pseudo-random buffer + const std::size_t byte_len = 1 << 20; // 1 MiB + std::vector data_vec(byte_len); + unsigned int val = 0x12345678U; + for (unsigned char &idx : data_vec) { + val ^= val << 13; + val ^= val >> 17; + val ^= val << 5; // xorshift32 + idx = static_cast(val & 0xFFU); + } + + // Padded encode length should be 4 * ceil(N/3) + const auto enc_pad = Encode(data_vec.data(), data_vec.size(), + /*url_safe=*/false, /*pad=*/true); + const std::size_t expected_padded = 4U * ((byte_len + 2U) / 3U); + EXPECT_EQ(enc_pad.size(), expected_padded); + + // Unpadded encode length rule (RFC 4648 §5) + const auto enc_nopad = Encode(data_vec.data(), data_vec.size(), + /*url_safe=*/true, /*pad=*/false); + const std::size_t rem = byte_len % 3U; + const std::size_t expected_unpadded = + 4U * (byte_len / 3U) + (rem == 0U ? 0U : (rem == 1U ? 2U : 3U)); + EXPECT_EQ(enc_nopad.size(), expected_unpadded); + + // Round-trips + const auto dec_pad = Decode(enc_pad); + const auto dec_nopad = Decode(enc_nopad); + ASSERT_EQ(dec_pad.size(), data_vec.size()); + ASSERT_EQ(dec_nopad.size(), data_vec.size()); + EXPECT_TRUE(std::equal(dec_pad.begin(), dec_pad.end(), data_vec.begin())); + EXPECT_TRUE(std::equal(dec_nopad.begin(), dec_nopad.end(), data_vec.begin())); +} + +TEST(utils_base64, url_safe_round_trip_various_lengths) { + for (std::size_t len : {0U, 1U, 2U, 3U, 4U, 5U, 6U, 7U, 32U, 33U, 64U, 65U}) { + std::vector buf(len); + for (std::size_t i = 0; i < len; ++i) { + buf[i] = static_cast(i * 13U + 7U); + } + + const auto enc_unpadded = + Encode(buf.data(), buf.size(), /*url_safe=*/true, /*pad=*/false); + const auto enc_padded = + Encode(buf.data(), buf.size(), /*url_safe=*/true, /*pad=*/true); + + const auto dec_unpadded = Decode(enc_unpadded); + const auto dec_padded = Decode(enc_padded); + + ASSERT_EQ(dec_unpadded.size(), buf.size()); + ASSERT_EQ(dec_padded.size(), buf.size()); + EXPECT_TRUE( + std::equal(dec_unpadded.begin(), dec_unpadded.end(), buf.begin())); + EXPECT_TRUE(std::equal(dec_padded.begin(), dec_padded.end(), buf.begin())); + } +} + +TEST(utils_base64, reject_trailing_garbage_after_padding) { + // Anything after final '=' padding is invalid + EXPECT_THROW(Decode("Zg==A"), std::runtime_error); + EXPECT_THROW(Decode("Zm8=A"), std::runtime_error); +} + +TEST(utils_base64, reject_three_padding_chars_total) { + // Any string with total length %4==1 is invalid (e.g., "Zg===") + EXPECT_THROW(Decode("Zg==="), std::runtime_error); +} + +TEST(utils_base64, standard_vs_url_safe_encoding_equivalence) { + const std::string msg = "base64 / url-safe + cross-check"; + + const auto std_enc = + Encode(reinterpret_cast(msg.data()), msg.size(), + /*url_safe=*/false, /*pad=*/true); + const auto url_enc = + Encode(reinterpret_cast(msg.data()), msg.size(), + /*url_safe=*/true, /*pad=*/true); + + std::string transformed = std_enc; + for (char &cur_ch : transformed) { + if (cur_ch == '+') { + cur_ch = '-'; + } else if (cur_ch == '/') { + cur_ch = '_'; + } + } + + EXPECT_EQ(url_enc, transformed); + + // decode once, then construct + EXPECT_EQ(decode_to_string(url_enc), msg); +}