diff --git a/include/CppCore.Interface.C/cppcore.h b/include/CppCore.Interface.C/cppcore.h index cf92ddfc..9b7e9adc 100644 --- a/include/CppCore.Interface.C/cppcore.h +++ b/include/CppCore.Interface.C/cppcore.h @@ -152,8 +152,8 @@ extern "C" { // base64 - CPPCORE_EXPORT unsigned int cppcore_base64_symbollength(unsigned int bytes); - CPPCORE_EXPORT unsigned int cppcore_base64_bytelength(char* s, unsigned int len); + CPPCORE_EXPORT unsigned int cppcore_base64_symbollength(unsigned int bytes, unsigned int url); + CPPCORE_EXPORT unsigned int cppcore_base64_bytelength(char* s, unsigned int len, unsigned int url); CPPCORE_EXPORT void cppcore_base64_encode(void* in, unsigned int len, char* out, unsigned int url, unsigned int writeterm); CPPCORE_EXPORT unsigned int cppcore_base64_decode(char* in, unsigned int len, void* out, unsigned int url); diff --git a/include/CppCore.Test/Encoding.h b/include/CppCore.Test/Encoding.h index 977d3ad1..e3b9a350 100644 --- a/include/CppCore.Test/Encoding.h +++ b/include/CppCore.Test/Encoding.h @@ -665,6 +665,7 @@ namespace CppCore { namespace Test public: INLINE static bool bytelength() { + // STD if (CppCore::Base64::bytelength("", 0) != 0) return false; // invalid length if (CppCore::Base64::bytelength("=", 1) != 0) return false; // invalid length if (CppCore::Base64::bytelength("==", 2) != 0) return false; // invalid length @@ -674,45 +675,111 @@ namespace CppCore { namespace Test if (CppCore::Base64::bytelength("aa==", 4) != 1) return false; // ok if (CppCore::Base64::bytelength("aaa=", 4) != 2) return false; // ok if (CppCore::Base64::bytelength("aaaa", 4) != 3) return false; // ok - if (CppCore::Base64::bytelength("aa=a", 4) != 3) return false; // invalid symbol - if (CppCore::Base64::bytelength("(aaa", 4) != 3) return false; // invalid symbol + if (CppCore::Base64::bytelength("aa=a", 4) != 3) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("(aaa", 4) != 3) return false; // ok (invalid symbol) if (CppCore::Base64::bytelength("aaaaa", 5) != 0) return false; // invalid length if (CppCore::Base64::bytelength("aaaaaa", 6) != 0) return false; // invalid length if (CppCore::Base64::bytelength("aaaaaa=", 7) != 0) return false; // invalid length if (CppCore::Base64::bytelength("aaaaaa==", 8) != 4) return false; // ok if (CppCore::Base64::bytelength("aaaaaaa=", 8) != 5) return false; // ok if (CppCore::Base64::bytelength("aaaaaaaa", 8) != 6) return false; // ok + // URL + if (CppCore::Base64::bytelength("", 0, true) != 0) return false; // invalid length + if (CppCore::Base64::bytelength("=", 1, true) != 0) return false; // invalid length + if (CppCore::Base64::bytelength("==", 2, true) != 1) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("a", 1, true) != 0) return false; // invalid length + if (CppCore::Base64::bytelength("aa", 2, true) != 1) return false; // ok + if (CppCore::Base64::bytelength("aa=", 3, true) != 2) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("aa==", 4, true) != 3) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("aaa=", 4, true) != 3) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("aaaa", 4, true) != 3) return false; // ok + if (CppCore::Base64::bytelength("aa=a", 4, true) != 3) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("(aaa", 4, true) != 3) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("aaaaa", 5, true) != 0) return false; // invalid length + if (CppCore::Base64::bytelength("aaaaaa", 6, true) != 4) return false; // ok + if (CppCore::Base64::bytelength("aaaaaa=", 7, true) != 5) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("aaaaaa==", 8, true) != 6) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("aaaaaaa=", 8, true) != 6) return false; // ok (invalid symbol) + if (CppCore::Base64::bytelength("aaaaaaaa", 8, true) != 6) return false; // ok return true; } INLINE static bool symbollength() { - if (CppCore::Base64::symbollength(0) != 0) return false; - if (CppCore::Base64::symbollength(1) != 4) return false; - if (CppCore::Base64::symbollength(2) != 4) return false; - if (CppCore::Base64::symbollength(3) != 4) return false; - if (CppCore::Base64::symbollength(4) != 8) return false; - if (CppCore::Base64::symbollength(5) != 8) return false; - if (CppCore::Base64::symbollength(6) != 8) return false; + // STD + if (CppCore::Base64::symbollength(0, false) != 0) return false; + if (CppCore::Base64::symbollength(1, false) != 4) return false; + if (CppCore::Base64::symbollength(2, false) != 4) return false; + if (CppCore::Base64::symbollength(3, false) != 4) return false; + if (CppCore::Base64::symbollength(4, false) != 8) return false; + if (CppCore::Base64::symbollength(5, false) != 8) return false; + if (CppCore::Base64::symbollength(6, false) != 8) return false; + // URL + if (CppCore::Base64::symbollength(0, true) != 0) return false; + if (CppCore::Base64::symbollength(1, true) != 2) return false; + if (CppCore::Base64::symbollength(2, true) != 3) return false; + if (CppCore::Base64::symbollength(3, true) != 4) return false; + if (CppCore::Base64::symbollength(4, true) != 6) return false; + if (CppCore::Base64::symbollength(5, true) != 7) return false; + if (CppCore::Base64::symbollength(6, true) != 8) return false; return true; } - INLINE static bool encode() + INLINE static bool encode_std() { std::string s; - CppCore::Base64::encode("1", s); if (s != "MQ==") return false; - CppCore::Base64::encode("12", s); if (s != "MTI=") return false; - CppCore::Base64::encode("123", s); if (s != "MTIz") return false; - CppCore::Base64::encode("1234", s); if (s != "MTIzNA==") return false; - CppCore::Base64::encode("12345", s); if (s != "MTIzNDU=") return false; - CppCore::Base64::encode("123456", s); if (s != "MTIzNDU2") return false; - CppCore::Base64::encode("1234567", s); if (s != "MTIzNDU2Nw==") return false; - CppCore::Base64::encode("12345678", s); if (s != "MTIzNDU2Nzg=") return false; - CppCore::Base64::encode("123456789", s); if (s != "MTIzNDU2Nzg5") return false; - CppCore::Base64::encode("", s); if (s != "") return false; + CppCore::Base64::encode("1", s); if (s != "MQ==") return false; + CppCore::Base64::encode("12", s); if (s != "MTI=") return false; + CppCore::Base64::encode("123", s); if (s != "MTIz") return false; + CppCore::Base64::encode("1234", s); if (s != "MTIzNA==") return false; + CppCore::Base64::encode("12345", s); if (s != "MTIzNDU=") return false; + CppCore::Base64::encode("123456", s); if (s != "MTIzNDU2") return false; + CppCore::Base64::encode("1234567", s); if (s != "MTIzNDU2Nw==") return false; + CppCore::Base64::encode("12345678", s); if (s != "MTIzNDU2Nzg=") return false; + CppCore::Base64::encode("123456789", s); if (s != "MTIzNDU2Nzg5") return false; + CppCore::Base64::encode("1234567890abcdef", s); if (s != "MTIzNDU2Nzg5MGFiY2RlZg==") return false; + CppCore::Base64::encode("", s); if (s != "") return false; uint8_t d1[] = { 0x01, 0x02, 0x03 }; CppCore::Base64::encode(d1, s); if (s != "AQID") return false; - uint8_t d2[] = { 0xFF }; CppCore::Base64::encode(d2, s); if (s != "/w==") return false; + uint8_t d2[] = { 0xFF, 0xEE }; CppCore::Base64::encode(d2, s); if (s != "/+4=") return false; + uint8_t d3[] = { 0xFF }; CppCore::Base64::encode(d3, s); if (s != "/w==") return false; + uint8_t d4[] = { + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff + }; + CppCore::Base64::encode(d4, s); + if (s != "/+7//+7//+7//+7//+7////u///u///u///u///u////7v//7v//7v//7v//7v///+7//+7//+7//+7//+7//w==") + return false; return true; } - INLINE static bool decode() + INLINE static bool encode_url() + { + std::string s; + CppCore::Base64::encode("1", s, true); if (s != "MQ") return false; + CppCore::Base64::encode("12", s, true); if (s != "MTI") return false; + CppCore::Base64::encode("123", s, true); if (s != "MTIz") return false; + CppCore::Base64::encode("1234", s, true); if (s != "MTIzNA") return false; + CppCore::Base64::encode("12345", s, true); if (s != "MTIzNDU") return false; + CppCore::Base64::encode("123456", s, true); if (s != "MTIzNDU2") return false; + CppCore::Base64::encode("1234567", s, true); if (s != "MTIzNDU2Nw") return false; + CppCore::Base64::encode("12345678", s, true); if (s != "MTIzNDU2Nzg") return false; + CppCore::Base64::encode("123456789", s, true); if (s != "MTIzNDU2Nzg5") return false; + CppCore::Base64::encode("1234567890abcdef", s, true); if (s != "MTIzNDU2Nzg5MGFiY2RlZg") return false; + CppCore::Base64::encode("", s, true); if (s != "") return false; + uint8_t d1[] = { 0x01, 0x02, 0x03 }; CppCore::Base64::encode(d1, s, true); if (s != "AQID") return false; + uint8_t d2[] = { 0xFF, 0xEE }; CppCore::Base64::encode(d2, s, true); if (s != "_-4") return false; + uint8_t d3[] = { 0xFF }; CppCore::Base64::encode(d3, s, true); if (s != "_w") return false; + uint8_t d4[] = { + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, + 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff, 0xee, 0xff, 0xff + }; + CppCore::Base64::encode(d4, s, true); + if (s != "_-7__-7__-7__-7__-7____u___u___u___u___u____7v__7v__7v__7v__7v___-7__-7__-7__-7__-7__w") + return false; + return true; + } + INLINE static bool decode_std() { std::string s; if (!CppCore::Base64::decode("MQ==", s) || s != "1") return false; @@ -732,6 +799,60 @@ namespace CppCore { namespace Test if (!CppCore::Base64::decode("/w==", d1) || d1[0] != 0xFF || d1[1] != 0x00 || d1[2] != 0x00) return false; if (!CppCore::Base64::decode("/w==", d2) || d2[0] != 0xFF) return false; if ( CppCore::Base64::decode("AQID", d2)) return false; // too large + uint8_t d3_gen[48]; + uint8_t d3_exp[48] = { + 0x00, 0x10, 0x83, 0x10, 0x51, 0x87, 0x20, 0x92, 0x8b, 0x30, 0xd3, 0x8f, + 0x41, 0x14, 0x93, 0x51, 0x55, 0x97, 0x61, 0x96, 0x9b, 0x71, 0xd7, 0x9f, + 0x82, 0x18, 0xa3, 0x92, 0x59, 0xa7, 0xa2, 0x9a, 0xab, 0xb2, 0xdb, 0xaf, + 0xc3, 0x1c, 0xb3, 0xd3, 0x5d, 0xb7, 0xe3, 0x9e, 0xbb, 0xf3, 0xdf, 0xbf + }; + if (!CppCore::Base64::decode("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", d3_gen)) + return false; + if (::memcmp(d3_gen, d3_exp, 48) != 0) + return false; + if (CppCore::Base64::decode("#BCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", d3_gen)) + return false; // invalid symbol '#' at index 0 + if (CppCore::Base64::decode("ABCDEFGHIJKLMNOPQRSTUVWXY!abcdefghijklmnopqrstuvwxyz0123456789+/", d3_gen)) + return false; // invalid symbol '!' at index 25 + return true; + } + INLINE static bool decode_url() + { + std::string s; + if (!CppCore::Base64::decode("MQ", s, true) || s != "1") return false; + if (!CppCore::Base64::decode("MTI", s, true) || s != "12") return false; + if (!CppCore::Base64::decode("MTIz", s, true) || s != "123") return false; + if (!CppCore::Base64::decode("MTIzNA", s, true) || s != "1234") return false; + if ( CppCore::Base64::decode("", s, true)) return false; // empty string + if ( CppCore::Base64::decode("Z", s, true)) return false; // invalid length + if ( CppCore::Base64::decode("M=", s, true)) return false; // invalid symbol + if ( CppCore::Base64::decode("M==", s, true)) return false; // invalid symbol + if ( CppCore::Base64::decode("(aaa", s, true)) return false; // invalid symbol + if ( CppCore::Base64::decode("aa=a", s, true)) return false; // invalid symbol + if ( CppCore::Base64::decode("MQ==", s, true)) return false; // invalid symbol + if ( CppCore::Base64::decode("MTI=", s, true)) return false; // invalid symbol + if ( CppCore::Base64::decode("MTIaa", s, true)) return false; // invalid length + uint8_t d1[3]; + uint8_t d2[1]; + if (!CppCore::Base64::decode("AQID", d1, true) || d1[0] != 0x01 || d1[1] != 0x02 || d1[2] != 0x03) return false; + if (!CppCore::Base64::decode("_w", d1, true) || d1[0] != 0xFF || d1[1] != 0x00 || d1[2] != 0x00) return false; + if (!CppCore::Base64::decode("_w", d2, true) || d2[0] != 0xFF) return false; + if ( CppCore::Base64::decode("AQID", d2, true)) return false; // too large + uint8_t d3_gen[48]; + uint8_t d3_exp[48] = { + 0x00, 0x10, 0x83, 0x10, 0x51, 0x87, 0x20, 0x92, 0x8b, 0x30, 0xd3, 0x8f, + 0x41, 0x14, 0x93, 0x51, 0x55, 0x97, 0x61, 0x96, 0x9b, 0x71, 0xd7, 0x9f, + 0x82, 0x18, 0xa3, 0x92, 0x59, 0xa7, 0xa2, 0x9a, 0xab, 0xb2, 0xdb, 0xaf, + 0xc3, 0x1c, 0xb3, 0xd3, 0x5d, 0xb7, 0xe3, 0x9e, 0xbb, 0xf3, 0xdf, 0xbf + }; + if (!CppCore::Base64::decode("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", d3_gen, true)) + return false; + if (::memcmp(d3_gen, d3_exp, 48) != 0) + return false; + if (CppCore::Base64::decode("#BCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", d3_gen, true)) + return false; // invalid symbol '#' at index 0 + if (CppCore::Base64::decode("ABCDEFGHIJKLMNOPQRSTUVWXY!abcdefghijklmnopqrstuvwxyz0123456789-_", d3_gen, true)) + return false; // invalid symbol '!' at index 25 return true; } }; @@ -1237,8 +1358,10 @@ namespace CppCore { namespace Test { namespace VS TEST_METHOD(HEX_PARSE64) { Assert::AreEqual(true, CppCore::Test::Encoding::Hex::parse64()); } TEST_METHOD(BASE64_BYTELENGTH) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::bytelength()); } TEST_METHOD(BASE64_SYMBOLLENGTH) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::symbollength()); } - TEST_METHOD(BASE64_ENCODE) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::encode()); } - TEST_METHOD(BASE64_DECODE) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::decode()); } + TEST_METHOD(BASE64_ENCODE_STD) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::encode_std()); } + TEST_METHOD(BASE64_ENCODE_URL) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::encode_url()); } + TEST_METHOD(BASE64_DECODE_STD) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::decode_std()); } + TEST_METHOD(BASE64_DECODE_URL) { Assert::AreEqual(true, CppCore::Test::Encoding::Base64::decode_url()); } TEST_METHOD(DEC_TOSTRING8U) { Assert::AreEqual(true, CppCore::Test::Encoding::Decimal::tostring8u()); } TEST_METHOD(DEC_TOSTRING8S) { Assert::AreEqual(true, CppCore::Test::Encoding::Decimal::tostring8s()); } TEST_METHOD(DEC_TOSTRING16U) { Assert::AreEqual(true, CppCore::Test::Encoding::Decimal::tostring16u()); } diff --git a/include/CppCore/Encoding.h b/include/CppCore/Encoding.h index 3f668b3a..0f4b0bb4 100644 --- a/include/CppCore/Encoding.h +++ b/include/CppCore/Encoding.h @@ -775,29 +775,37 @@ namespace CppCore /// /// Returns the number of bytes needed to store symbols. - /// Returns 0 if len isn't a multiple of 4. + /// Returns 0 if len%4!=0 (url=false) or len%4=1 (url=true). /// - INLINE static size_t bytelength(const char* s, size_t len) - { - if (((len == 0) | (len & 0x03)) != 0) CPPCORE_UNLIKELY - return 0; - size_t n = (len / 4U) * 3U; - if (s[len-1] == '=') { - n--; - if (s[len-2] == '=') - n--; + INLINE static size_t bytelength(const char* s, size_t len, bool url = false) + { + size_t full = (len >> 2U) * 3U; + size_t tail = (len & 0x03U); + if (url) + { + return + (tail == 0U) ? (full) : + (tail & 0x02U ? (full + (tail-1U)) : 0U); + } + else + { + if ((len == 0U) | (tail != 0U)) + return 0U; + size_t t = len; + if (s[len-1] == '=') len--; + if (s[len-1] == '=') len--; + return full - (t-len); } - return n; } /// /// Returns the number of symbols needed to store bytes. /// - INLINE static size_t symbollength(size_t bytes) + INLINE static size_t symbollength(size_t bytes, bool url = false) { - if constexpr(sizeof(size_t) == 4U) return (CppCore::rup32((uint32_t)bytes, 3U) / 3U) * 4U; - else if constexpr(sizeof(size_t) == 8U) return (CppCore::rup64((uint64_t)bytes, 3U) / 3U) * 4U; - else assert(false); + size_t full = (bytes / 3U) * 4U; + size_t tail = (bytes % 3U); + return full + (tail ? (url ? tail + 1 : 4U) : 0U); } /// @@ -807,10 +815,55 @@ namespace CppCore /// INLINE static void encode(const void* in, size_t len, char* out, bool url = false, bool writeterm = true) { + static_assert(CPPCORE_ENDIANESS_LITTLE); const char* tbl = url ? Base64::BINTOB64_URL : Base64::BINTOB64_STD; const uint8_t* p = (const uint8_t*)in; + #if defined(CPPCORE_CPUFEAT_SSSE3) + if (len >= 64U) + { + // 16 symbols from 12 bytes + // adapted from: https://github.com/WojciechMula/base64simd + // only if reasonable large due to overhead for loading masks to sse registers + const __m128i SHUF1 = _mm_set_epi8( + 10, 11, 9, 10, 7, 8, 6, 7, 4, 5, 3, 4, 1, 2, 0, 1 + ); + const __m128i SHUF2_URL = _mm_setr_epi8( + 'a' - 26, '0' - 52, '0' - 52, '0' - 52, '0' - 52, '0' - 52, + '0' - 52, '0' - 52, '0' - 52, '0' - 52, '0' - 52, '-' - 62, + '_' - 63, 'A', 0, 0); + const __m128i SHUF2_STD = _mm_setr_epi8( + 'a' - 26, '0' - 52, '0' - 52, '0' - 52, '0' - 52, '0' - 52, + '0' - 52, '0' - 52, '0' - 52, '0' - 52, '0' - 52, '+' - 62, + '/' - 63, 'A', 0, 0); + const __m128i SHUF2 = url ? SHUF2_URL : SHUF2_STD; + const __m128i M1 = _mm_set1_epi32(0x0fc0fc00); + const __m128i M2 = _mm_set1_epi32(0x04000040); + const __m128i M3 = _mm_set1_epi32(0x003f03f0); + const __m128i M4 = _mm_set1_epi32(0x01000010); + const __m128i M5 = _mm_set1_epi8(51); + const __m128i M6 = _mm_set1_epi8(26); + const __m128i M7 = _mm_set1_epi8(13); + do + { + __m128i t, r; + t = _mm_loadu_si128((const __m128i*)p); + t = _mm_shuffle_epi8(t, SHUF1); + t = _mm_or_si128( + _mm_mulhi_epu16(_mm_and_si128(t, M1), M2), + _mm_mullo_epi16(_mm_and_si128(t, M3), M4)); + r = _mm_subs_epu8(t, M5); + r = _mm_or_si128(r, _mm_and_si128(_mm_cmpgt_epi8(M6, t), M7)); + r = _mm_shuffle_epi8(SHUF2, r); + r = _mm_add_epi8(r, t); + _mm_storeu_si128((__m128i*)out, r); + p += 12; + len -= 12; + out += 16; + } while (len >= 16U); + } + #endif while (len >= 6U) { // 8 symbols from 6 bytes @@ -894,7 +947,8 @@ namespace CppCore *out++ = tbl[s1]; *out++ = tbl[s2]; *out++ = tbl[s3]; - *out++ = '='; + if (!url) + *out++ = '='; } else if (len) { @@ -903,8 +957,10 @@ namespace CppCore uint8_t s2 = (p[0] << 4) & 0x30; *out++ = tbl[s1]; *out++ = tbl[s2]; - *out++ = '='; - *out++ = '='; + if (!url) { + *out++ = '='; + *out++ = '='; + } } if (writeterm) *out = 0x00; @@ -927,7 +983,7 @@ namespace CppCore INLINE static void encode(const void* in, size_t len, std::string& out, bool url = false, bool writeterm = true) { - out.resize(Base64::symbollength(len)); + out.resize(Base64::symbollength(len, url)); Base64::encode(in, len, out.data(), url, writeterm); } @@ -974,7 +1030,7 @@ namespace CppCore tail = read % 3U; diff = read - tail; CppCore::Base64::encode(bin, diff, bout, url, false); - out.write(bout, CppCore::Base64::symbollength(diff)); + out.write(bout, CppCore::Base64::symbollength(diff, url)); if (tail == 1) { bin[0] = bin[read-1]; @@ -987,7 +1043,7 @@ namespace CppCore } if (tail) { CppCore::Base64::encode(bin, tail, bout, url, false); - out.write(bout, CppCore::Base64::symbollength(tail)); + out.write(bout, CppCore::Base64::symbollength(tail, url)); } if (writeterm) out << '\0'; @@ -1028,16 +1084,94 @@ namespace CppCore /// INLINE static bool decode(const char* in, size_t len, void* out, bool url = false) { - if (((len == 0) | (len & 0x03)) != 0) CPPCORE_UNLIKELY + if (len == 0) CPPCORE_UNLIKELY return false; - if (in[len-1] == '=') len--; - if (in[len-1] == '=') len--; + size_t tail = (len & 0x03); + if (url) + { + if (tail == 1U) + return false; + } + else + { + if (tail) return false; + if (in[len-1] == '=') len--; + if (in[len-1] == '=') len--; + } uint8_t* p = (uint8_t*)out; uint32_t r = 0; uint32_t v; const uint8_t* tbl = url ? Base64::B64TOBIN_URL : Base64::B64TOBIN_STD; + #if defined(CPPCORE_CPUFEAT_SSSE3) + if (len >= 64) + { + // 12 bytes from 16 symbols + // adapted from: https://github.com/WojciechMula/base64simd + // only if reasonable large due to overhead for loading masks to sse registers + const __m128i CMP_GT_A = _mm_set1_epi8('A'-1); + const __m128i CMP_LT_Z = _mm_set1_epi8('Z'+1); + const __m128i CMP_GT_a = _mm_set1_epi8('a'-1); + const __m128i CMP_LT_z = _mm_set1_epi8('z'+1); + const __m128i CMP_GT_0 = _mm_set1_epi8('0'-1); + const __m128i CMP_LT_9 = _mm_set1_epi8('9'+1); + const __m128i CMP_E_C1 = url ? _mm_set1_epi8('-') : _mm_set1_epi8('+'); + const __m128i CMP_E_C2 = url ? _mm_set1_epi8('_') : _mm_set1_epi8('/'); + const __m128i M1 = _mm_set1_epi8(-65); + const __m128i M2 = _mm_set1_epi8(-71); + const __m128i M3 = _mm_set1_epi8(4); + const __m128i M4 = url ? _mm_set1_epi8(17) : _mm_set1_epi8(19); + const __m128i M5 = url ? _mm_set1_epi8(-32) : _mm_set1_epi8(16); + const __m128i M6 = _mm_set1_epi32(0x01400140); + const __m128i M7 = _mm_set1_epi32(0x00011000); + const __m128i SHUF = _mm_setr_epi8( + 2, 1, 0, 6, 5, 4, + 10, 9, 8, 14, 13, 12, + -1, -1, -1, -1); + __m128i e = _mm_setzero_si128(); + do + { + __m128i v = _mm_loadu_si128((__m128i*)in); + const __m128i ge_A = _mm_cmpgt_epi8(v, CMP_GT_A); + const __m128i le_Z = _mm_cmplt_epi8(v, CMP_LT_Z); + const __m128i r_AZ = _mm_and_si128(M1, _mm_and_si128(ge_A, le_Z)); + const __m128i ge_a = _mm_cmpgt_epi8(v, CMP_GT_a); + const __m128i le_z = _mm_cmplt_epi8(v, CMP_LT_z); + const __m128i r_az = _mm_and_si128(M2, _mm_and_si128(ge_a, le_z)); + const __m128i ge_0 = _mm_cmpgt_epi8(v, CMP_GT_0); + const __m128i le_9 = _mm_cmplt_epi8(v, CMP_LT_9); + const __m128i r_09 = _mm_and_si128(M3, _mm_and_si128(ge_0, le_9)); + const __m128i e_c1 = _mm_cmpeq_epi8(v, CMP_E_C1); + const __m128i r_c1 = _mm_and_si128(M4, e_c1); + const __m128i e_c2 = _mm_cmpeq_epi8(v, CMP_E_C2); + const __m128i r_c2 = _mm_and_si128(M5, e_c2); + const __m128i shift = + _mm_or_si128(r_AZ, + _mm_or_si128(r_az, + _mm_or_si128(r_09, + _mm_or_si128(r_c1, r_c2)))); + e = _mm_or_si128(e, _mm_cmpeq_epi8(shift, _mm_setzero_si128())); + v = _mm_add_epi8(v, shift); + v = _mm_maddubs_epi16(v, M6); + v = _mm_madd_epi16(v, M7); + v = _mm_shuffle_epi8(v, SHUF); + _mm_storeu_si64((__m128i*)(p), v); + _mm_storeu_si32((__m128i*)(p+8U), _mm_castps_si128(_mm_movehl_ps( + _mm_castsi128_ps(v), _mm_castsi128_ps(v)))); + p += 12; + len -= 16; + in += 16; + } while (len >= 16); + #if defined(CPPCORE_CPUFEAT_SSE41) + if (_mm_testz_si128(e, e) == 0) + return false; + #else + if (_mm_movemask_epi8(e) != 0) + return false; + #endif + } + #endif while (len >= 4) { // 3 bytes from 4 symbols @@ -1092,7 +1226,7 @@ namespace CppCore INLINE static bool decode(const char* in, size_t len, std::string& out, bool url = false) { - out.resize(Base64::bytelength(in, len)); + out.resize(Base64::bytelength(in, len, url)); return Base64::decode(in, len, out.data(), url); } @@ -1114,7 +1248,7 @@ namespace CppCore template INLINE static bool decode(const char* in, size_t len, T& out, bool url = false, bool clear = true) { - const auto BLEN = Base64::bytelength(in, len); + const auto BLEN = Base64::bytelength(in, len, url); if (BLEN > sizeof(T)) return false; if (clear && BLEN < sizeof(T)) @@ -1152,7 +1286,7 @@ namespace CppCore size_t read = in.gcount(); if (!CppCore::Base64::decode(bin, read, (void*)bout, url)) return false; - out.write(bout, CppCore::Base64::bytelength(bin, read)); + out.write(bout, CppCore::Base64::bytelength(bin, read, url)); } return true; } diff --git a/src/CppCore.Interface.C/cppcore.cpp b/src/CppCore.Interface.C/cppcore.cpp index 0e1d79dc..451f36f3 100644 --- a/src/CppCore.Interface.C/cppcore.cpp +++ b/src/CppCore.Interface.C/cppcore.cpp @@ -134,11 +134,11 @@ CPPCORE_BASE16_IMPLEMENTATION(8192, CppCore::Block8192) // BASE64 /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -unsigned int cppcore_base64_symbollength(unsigned int bytes) { - return CppCore::Base64::symbollength(bytes); +unsigned int cppcore_base64_symbollength(unsigned int bytes, unsigned int url) { + return CppCore::Base64::symbollength(bytes, (bool)url); } -unsigned int cppcore_base64_bytelength(char* s, unsigned int len) { - return CppCore::Base64::bytelength(s, len); +unsigned int cppcore_base64_bytelength(char* s, unsigned int len, unsigned int url) { + return CppCore::Base64::bytelength(s, len, (bool)url); } void cppcore_base64_encode(void* in, unsigned int len, char* out, unsigned int url, unsigned int writeterm) { CppCore::Base64::encode(in, (size_t)len, out, (bool)url, (bool)writeterm); diff --git a/src/CppCore.Test/Test.cpp b/src/CppCore.Test/Test.cpp index 3ceb0b57..8ea3e87c 100644 --- a/src/CppCore.Test/Test.cpp +++ b/src/CppCore.Test/Test.cpp @@ -770,8 +770,10 @@ int main() std::cout << "-------------------------------" << std::endl; TEST(CppCore::Test::Encoding::Base64::bytelength, "bytelength: ", std::endl); TEST(CppCore::Test::Encoding::Base64::symbollength, "symbollength: ", std::endl); - TEST(CppCore::Test::Encoding::Base64::encode, "encode: ", std::endl); - TEST(CppCore::Test::Encoding::Base64::decode, "decode: ", std::endl); + TEST(CppCore::Test::Encoding::Base64::encode_std, "encode_std: ", std::endl); + TEST(CppCore::Test::Encoding::Base64::encode_url, "encode_url: ", std::endl); + TEST(CppCore::Test::Encoding::Base64::decode_std, "decode_std: ", std::endl); + TEST(CppCore::Test::Encoding::Base64::decode_url, "decode_url: ", std::endl); #ifndef CPPCORE_NO_SOCKET std::cout << "-------------------------------" << std::endl;