diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 836fecec960..4b77f0fc296 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -65,6 +65,7 @@ set(SRC_FILES engine.cc date_utils.cc encrypt_utils_common.cc + encrypt_utils_iv.cc encrypt_utils_ecb.cc encrypt_utils_cbc.cc encrypt_utils_gcm.cc @@ -269,6 +270,7 @@ add_gandiva_test(internals-test encrypt_utils_ecb_test.cc encrypt_utils_cbc_test.cc encrypt_utils_gcm_test.cc + encrypt_utils_iv_test.cc encrypt_utils_common_test.cc expr_decomposer_test.cc exported_funcs_registry_test.cc diff --git a/cpp/src/gandiva/encrypt_mode_dispatcher.cc b/cpp/src/gandiva/encrypt_mode_dispatcher.cc index fad1c54ba9f..7a0dcd7cde1 100644 --- a/cpp/src/gandiva/encrypt_mode_dispatcher.cc +++ b/cpp/src/gandiva/encrypt_mode_dispatcher.cc @@ -27,7 +27,6 @@ namespace gandiva { -// Supported encryption modes static const std::vector SUPPORTED_MODES = { AES_ECB_MODE, AES_ECB_PKCS7_MODE, AES_ECB_NONE_MODE, AES_CBC_MODE, AES_CBC_PKCS7_MODE, AES_CBC_NONE_MODE, @@ -42,10 +41,19 @@ enum class EncryptionMode { CBC_PKCS7, CBC_NONE, GCM, + NULL_VALUE, UNKNOWN }; -EncryptionMode ParseEncryptionMode(std::string_view mode_str) { +EncryptionMode ParseEncryptionMode(const char* mode, int32_t mode_len, bool mode_validity) { + if (!mode_validity) { + return EncryptionMode::NULL_VALUE; + } + + // Convert mode string to uppercase for case-insensitive comparison + std::string mode_str = + arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); + if (mode_str == AES_ECB_MODE) return EncryptionMode::ECB; if (mode_str == AES_ECB_PKCS7_MODE) return EncryptionMode::ECB_PKCS7; if (mode_str == AES_ECB_NONE_MODE) return EncryptionMode::ECB_NONE; @@ -53,84 +61,88 @@ EncryptionMode ParseEncryptionMode(std::string_view mode_str) { if (mode_str == AES_CBC_PKCS7_MODE) return EncryptionMode::CBC_PKCS7; if (mode_str == AES_CBC_NONE_MODE) return EncryptionMode::CBC_NONE; if (mode_str == AES_GCM_MODE) return EncryptionMode::GCM; + return EncryptionMode::UNKNOWN; } +std::string BuildUnsupportedModeError(const char* operation, const char* mode, int32_t mode_len) { + std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", "); + std::ostringstream oss; + oss << "Unsupported " << operation << " mode: " << std::string_view(mode, mode_len) + << ". Supported modes: " << modes_str; + return oss.str(); +} + int32_t EncryptModeDispatcher::encrypt( - const char* plaintext, int32_t plaintext_len, const char* key, - int32_t key_len, const char* mode, int32_t mode_len, const char* iv, - int32_t iv_len, const char* fifth_argument, int32_t fifth_argument_len, - unsigned char* cipher) { - std::string mode_str = - arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); + const char* plaintext, int32_t plaintext_len, + const char* key, int32_t key_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv, int32_t iv_len, bool iv_validity, + const char* fifth_argument, int32_t fifth_argument_len, + bool fifth_argument_validity, unsigned char* cipher) { + if (!key_validity) { + throw std::runtime_error("Encryption key cannot be NULL"); + } - switch (ParseEncryptionMode(mode_str)) { + switch (ParseEncryptionMode(mode, mode_len, mode_validity)) { case EncryptionMode::ECB: case EncryptionMode::ECB_PKCS7: - // Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7 padding return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, true, cipher); case EncryptionMode::ECB_NONE: - // ECB without padding return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, false, cipher); case EncryptionMode::CBC: case EncryptionMode::CBC_PKCS7: - // Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7 return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, iv, iv_len, true, cipher); case EncryptionMode::CBC_NONE: - // CBC without padding return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, iv, iv_len, false, cipher); case EncryptionMode::GCM: return aes_encrypt_gcm(plaintext, plaintext_len, key, key_len, iv, iv_len, fifth_argument, fifth_argument_len, cipher); + case EncryptionMode::NULL_VALUE: + throw std::runtime_error(BuildUnsupportedModeError("encryption", "NULL", 4)); case EncryptionMode::UNKNOWN: - default: { - std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", "); - std::ostringstream oss; - oss << "Unsupported encryption mode: " << mode_str - << ". Supported modes: " << modes_str; - throw std::runtime_error(oss.str()); - } + default: + throw std::runtime_error(BuildUnsupportedModeError("encryption", mode, mode_len)); } } int32_t EncryptModeDispatcher::decrypt( - const char* ciphertext, int32_t ciphertext_len, const char* key, - int32_t key_len, const char* mode, int32_t mode_len, const char* iv, - int32_t iv_len, const char* fifth_argument, int32_t fifth_argument_len, - unsigned char* plaintext) { - std::string mode_str = - arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); + const char* ciphertext, int32_t ciphertext_len, + const char* key, int32_t key_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv, int32_t iv_len, bool iv_validity, + const char* fifth_argument, int32_t fifth_argument_len, + bool fifth_argument_validity, unsigned char* plaintext) { + // If key is NULL (validity flag is false), throw error + if (!key_validity) { + throw std::runtime_error("Decryption key cannot be NULL"); + } - switch (ParseEncryptionMode(mode_str)) { + switch (ParseEncryptionMode(mode, mode_len, mode_validity)) { case EncryptionMode::ECB: case EncryptionMode::ECB_PKCS7: - // Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7 padding return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, true, plaintext); case EncryptionMode::ECB_NONE: - // ECB without padding return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, false, plaintext); case EncryptionMode::CBC: case EncryptionMode::CBC_PKCS7: - // Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7 return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, iv, iv_len, true, plaintext); case EncryptionMode::CBC_NONE: - // CBC without padding + // CBC mode without padding return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, iv, iv_len, false, plaintext); case EncryptionMode::GCM: return aes_decrypt_gcm(ciphertext, ciphertext_len, key, key_len, iv, iv_len, fifth_argument, fifth_argument_len, plaintext); case EncryptionMode::UNKNOWN: - default: { - std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", "); - std::ostringstream oss; - oss << "Unsupported decryption mode: " << mode_str - << ". Supported modes: " << modes_str; - throw std::runtime_error(oss.str()); - } + default: + if (!mode_validity) { + throw std::runtime_error(BuildUnsupportedModeError("decryption", "NULL", 4)); + } + throw std::runtime_error(BuildUnsupportedModeError("decryption", mode, mode_len)); } } diff --git a/cpp/src/gandiva/encrypt_mode_dispatcher.h b/cpp/src/gandiva/encrypt_mode_dispatcher.h index 20326845bd0..d19c76422fb 100644 --- a/cpp/src/gandiva/encrypt_mode_dispatcher.h +++ b/cpp/src/gandiva/encrypt_mode_dispatcher.h @@ -33,47 +33,57 @@ class EncryptModeDispatcher { * * @param plaintext The data to encrypt * @param plaintext_len Length of plaintext in bytes - * @param key The encryption key + * @param key The encryption key (16, 24, or 32 bytes for AES-128/192/256) * @param key_len Length of key in bytes - * @param mode Mode string + * @param key_validity Whether key is valid + * @param mode Mode string (case-insensitive) * @param mode_len Length of mode string in bytes - * @param iv The initialization vector (optional, only for modes that support it) + * @param mode_validity Whether mode is valid + * @param iv The initialization vector * @param iv_len Length of the IV in bytes - * @param fifth_argument Additional parameter (optional, only for modes that support it) + * @param iv_validity Whether IV is valid + * @param fifth_argument Additional parameter (e.g. AAD for the GCM mode) * @param fifth_argument_len Length of fifth_argument in bytes + * @param fifth_argument_validity Whether fifth_argument is valid * @param cipher Output buffer for encrypted data * @return Length of encrypted data in bytes - * @throws std::runtime_error on encryption failure or unsupported mode + * @throws std::runtime_error on encryption failure, unsupported mode, or invalid parameters */ static int32_t encrypt(const char* plaintext, int32_t plaintext_len, - const char* key, int32_t key_len, - const char* mode, int32_t mode_len, - const char* iv, int32_t iv_len, + const char* key, int32_t key_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv, int32_t iv_len, bool iv_validity, const char* fifth_argument, int32_t fifth_argument_len, + bool fifth_argument_validity, unsigned char* cipher); /** * Decrypt data using the specified mode * - * @param ciphertext The data to decrypt + * @param ciphertext The data to decrypt (format depends on mode and IV parameter) * @param ciphertext_len Length of ciphertext in bytes - * @param key The decryption key + * @param key The decryption key (16, 24, or 32 bytes for AES-128/192/256) * @param key_len Length of key in bytes - * @param mode Mode string + * @param key_validity Whether key is valid + * @param mode Mode string (case-insensitive) * @param mode_len Length of mode string in bytes - * @param iv The initialization vector (optional, only for modes that support it) + * @param mode_validity Whether mode is valid + * @param iv The initialization vector * @param iv_len Length of the IV in bytes - * @param fifth_argument Additional parameter (optional, only for modes that support it) + * @param iv_validity Whether IV is valid + * @param fifth_argument Additional parameter (e.g. AAD for the GCM mode) * @param fifth_argument_len Length of fifth_argument in bytes + * @param fifth_argument_validity Whether fifth_argument is valid * @param plaintext Output buffer for decrypted data - * @return Length of decrypted data in bytes - * @throws std::runtime_error on decryption failure or unsupported mode + * @return Length of decrypted data in bytes (plaintext only, IV and tag removed) + * @throws std::runtime_error on decryption failure, unsupported mode, invalid parameters, or authentication failure */ static int32_t decrypt(const char* ciphertext, int32_t ciphertext_len, - const char* key, int32_t key_len, - const char* mode, int32_t mode_len, - const char* iv, int32_t iv_len, + const char* key, int32_t key_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv, int32_t iv_len, bool iv_validity, const char* fifth_argument, int32_t fifth_argument_len, + bool fifth_argument_validity, unsigned char* plaintext); }; diff --git a/cpp/src/gandiva/encrypt_utils_cbc.cc b/cpp/src/gandiva/encrypt_utils_cbc.cc index 04eb60c96a7..3c4bf14993e 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc.cc +++ b/cpp/src/gandiva/encrypt_utils_cbc.cc @@ -17,6 +17,7 @@ #include "gandiva/encrypt_utils_cbc.h" #include "gandiva/encrypt_utils_common.h" +#include "gandiva/encrypt_utils_iv.h" #include #include #include @@ -45,18 +46,45 @@ const EVP_CIPHER* get_cbc_cipher_algo(int32_t key_length) { } } +void validate_iv_length_cbc(int32_t iv_len) { + if (iv_len != CBC_IV_LENGTH) { + std::ostringstream oss; + oss << "Invalid IV length for AES-CBC: " << iv_len + << " bytes. IV must be exactly " << CBC_IV_LENGTH << " bytes"; + throw std::runtime_error(oss.str()); + } +} + +void validate_ciphertext_with_embedded_iv_cbc(int32_t ciphertext_len) { + constexpr int32_t MIN_CIPHERTEXT_LEN = CBC_IV_LENGTH + 16; // IV + minimum one block + if (ciphertext_len < MIN_CIPHERTEXT_LEN) { + std::ostringstream oss; + oss << "Ciphertext too short for AES-CBC with embedded IV: " << ciphertext_len + << " bytes. Must be at least " << MIN_CIPHERTEXT_LEN + << " bytes (16-byte IV + minimum 16-byte block)"; + throw std::runtime_error(oss.str()); + } +} + } // namespace GANDIVA_EXPORT int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, bool use_padding, unsigned char* cipher) { - // Validate IV length - if (iv_len != 16) { - std::ostringstream oss; - oss << "Invalid IV length for AES-CBC: " << iv_len - << " bytes. IV must be exactly 16 bytes"; - throw std::runtime_error(oss.str()); + // Buffer for IV (either user-supplied or auto-generated) + unsigned char iv_buffer[CBC_IV_LENGTH]; + const unsigned char* actual_iv = nullptr; + bool iv_auto_generated = false; + + if (iv == nullptr || iv_len == 0) { + generate_random_iv(iv_buffer, CBC_IV_LENGTH); + actual_iv = iv_buffer; + iv_auto_generated = true; + } else { + validate_iv_length_cbc(iv_len); + actual_iv = reinterpret_cast(iv); + iv_auto_generated = false; } int32_t cipher_len = 0; @@ -69,9 +97,17 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char get_openssl_error_string()); } + // Only prepend IV to output if it was auto-generated + // Auto-generated IV: [16-byte IV][ciphertext] + // User-supplied IV: [ciphertext] + if (iv_auto_generated) { + std::memcpy(cipher, actual_iv, CBC_IV_LENGTH); + cipher_len = CBC_IV_LENGTH; + } + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, reinterpret_cast(key), - reinterpret_cast(iv))) { + actual_iv)) { EVP_CIPHER_CTX_free(en_ctx); throw std::runtime_error("Could not initialize EVP cipher context for encryption: " + get_openssl_error_string()); @@ -84,7 +120,8 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char get_openssl_error_string()); } - if (!EVP_EncryptUpdate(en_ctx, cipher, &len, + // Encrypt plaintext (write after IV) + if (!EVP_EncryptUpdate(en_ctx, cipher + cipher_len, &len, reinterpret_cast(plaintext), plaintext_len)) { EVP_CIPHER_CTX_free(en_ctx); @@ -94,7 +131,7 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char cipher_len += len; - if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) { + if (!EVP_EncryptFinal_ex(en_ctx, cipher + cipher_len, &len)) { EVP_CIPHER_CTX_free(en_ctx); throw std::runtime_error("Could not finalize EVP cipher context for encryption: " + get_openssl_error_string()); @@ -110,12 +147,21 @@ GANDIVA_EXPORT int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, bool use_padding, unsigned char* plaintext) { - // Validate IV length - if (iv_len != 16) { - std::ostringstream oss; - oss << "Invalid IV length for AES-CBC: " << iv_len - << " bytes. IV must be exactly 16 bytes"; - throw std::runtime_error(oss.str()); + // Buffer for extracted IV (if needed) + unsigned char iv_buffer[CBC_IV_LENGTH]; + const unsigned char* actual_iv = nullptr; + const char* actual_ciphertext = ciphertext; + int32_t actual_ciphertext_len = ciphertext_len; + + if (iv == nullptr) { + validate_ciphertext_with_embedded_iv_cbc(ciphertext_len); + extract_iv_from_ciphertext(ciphertext, ciphertext_len, CBC_IV_LENGTH, + iv_buffer, &actual_ciphertext, + &actual_ciphertext_len); + actual_iv = iv_buffer; + } else { + validate_iv_length_cbc(iv_len); + actual_iv = reinterpret_cast(iv); } int32_t plaintext_len = 0; @@ -130,7 +176,7 @@ int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const ch if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, reinterpret_cast(key), - reinterpret_cast(iv))) { + actual_iv)) { EVP_CIPHER_CTX_free(de_ctx); throw std::runtime_error("Could not initialize EVP cipher context for decryption: " + get_openssl_error_string()); @@ -144,8 +190,8 @@ int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const ch } if (!EVP_DecryptUpdate(de_ctx, plaintext, &len, - reinterpret_cast(ciphertext), - ciphertext_len)) { + reinterpret_cast(actual_ciphertext), + actual_ciphertext_len)) { EVP_CIPHER_CTX_free(de_ctx); throw std::runtime_error("Could not update EVP cipher context for decryption: " + get_openssl_error_string()); diff --git a/cpp/src/gandiva/encrypt_utils_cbc.h b/cpp/src/gandiva/encrypt_utils_cbc.h index b083d6f0a2d..41acdab0cfb 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc.h +++ b/cpp/src/gandiva/encrypt_utils_cbc.h @@ -28,17 +28,29 @@ constexpr const char* AES_CBC_MODE = "AES-CBC"; constexpr const char* AES_CBC_PKCS7_MODE = "AES-CBC-PKCS7"; constexpr const char* AES_CBC_NONE_MODE = "AES-CBC-NONE"; +// CBC IV length in bytes +constexpr int32_t CBC_IV_LENGTH = 16; // 16 bytes (128 bits) - required for CBC + /** * Encrypt data using AES-CBC algorithm with explicit padding mode * + * Output format: + * - With NULL IV (auto-generated): [16-byte IV][ciphertext] + * - With user-supplied IV: [ciphertext] + * + * IV Handling: + * - If iv is NULL: A cryptographically secure random 16-byte IV + * is automatically generated using OpenSSL RAND_bytes and prepended to output + * - If iv is provided: It must be exactly 16 bytes, will be used as-is, and not prepended + * * @param plaintext The data to encrypt * @param plaintext_len Length of plaintext in bytes * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) * @param key_len Length of key in bytes - * @param iv The initialization vector (must be exactly 16 bytes) - * @param iv_len Length of IV in bytes (must be 16) + * @param iv The initialization vector (NULL for auto-generation, or exactly 16 bytes) + * @param iv_len Length of IV in bytes * @param use_padding Whether to use PKCS7 padding (true) or no padding (false) - * @param cipher Output buffer for encrypted data + * @param cipher Output buffer for encrypted data (must be at least plaintext_len + 32 bytes) * @return Length of encrypted data in bytes * @throws std::runtime_error on encryption failure or invalid parameters */ @@ -50,12 +62,20 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char /** * Decrypt data using AES-CBC algorithm with explicit padding mode * + * IV Handling: + * - If iv is NULL: IV is extracted from the first 16 bytes of ciphertext + * (expects format: [16-byte IV][ciphertext]) + * - If iv is provided: It must be exactly 16 bytes, and ciphertext should be + * [ciphertext] without embedded IV + * * @param ciphertext The data to decrypt - * @param ciphertext_len Length of ciphertext in bytes + * - With NULL IV: [16-byte IV][ciphertext] (min 32 bytes) + * - With provided IV: [ciphertext] + * @param ciphertext_len Length of ciphertext in bytes (includes IV if embedded) * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) * @param key_len Length of key in bytes - * @param iv The initialization vector (must be exactly 16 bytes) - * @param iv_len Length of IV in bytes (must be 16) + * @param iv The initialization vector (NULL for extraction, or exactly 16 bytes) + * @param iv_len Length of IV in bytes * @param use_padding Whether to use PKCS7 padding (true) or no padding (false) * @param plaintext Output buffer for decrypted data * @return Length of decrypted data in bytes diff --git a/cpp/src/gandiva/encrypt_utils_cbc_test.cc b/cpp/src/gandiva/encrypt_utils_cbc_test.cc index 8bf9227d65b..de6e1914a51 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc_test.cc +++ b/cpp/src/gandiva/encrypt_utils_cbc_test.cc @@ -21,7 +21,7 @@ #include #include -// Test PKCS#7 padding with 16-byte key +// Test PKCS#7 padding with 16-byte key (user-supplied IV) TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_16) { auto* key = "12345678abcdefgh"; auto* iv = "1234567890123456"; @@ -30,12 +30,21 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_16) { auto key_len = static_cast(strlen(key)); auto iv_len = static_cast(strlen(iv)); auto to_encrypt_len = static_cast(strlen(to_encrypt)); - unsigned char cipher[64]; + unsigned char cipher[128]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, true, cipher); - unsigned char decrypted[64]; + // Output format with user-supplied IV: [ciphertext] (IV NOT prepended) + // Ciphertext includes padding, so it's rounded up to next 16-byte block + EXPECT_GE(cipher_len, to_encrypt_len); // At least plaintext length + EXPECT_LE(cipher_len, to_encrypt_len + 16); // Should NOT include IV (at most one block of padding) + + // Verify IV is NOT prepended (ciphertext should not match IV) + EXPECT_NE(0, std::memcmp(cipher, iv, 16)); + + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, true, decrypted); @@ -44,7 +53,7 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_16) { std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test PKCS#7 padding with 24-byte key +// Test PKCS#7 padding with 24-byte key (user-supplied IV) TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_24) { auto* key = "12345678abcdefgh12345678"; auto* iv = "1234567890123456"; @@ -53,12 +62,20 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_24) { auto key_len = static_cast(strlen(key)); auto iv_len = static_cast(strlen(iv)); auto to_encrypt_len = static_cast(strlen(to_encrypt)); - unsigned char cipher[64]; + unsigned char cipher[128]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, true, cipher); - unsigned char decrypted[64]; + // Output format with user-supplied IV: [ciphertext] (IV NOT prepended) + EXPECT_GE(cipher_len, to_encrypt_len); + EXPECT_LE(cipher_len, to_encrypt_len + 16); // Should NOT include IV (at most one block of padding) + + // Verify IV is NOT prepended + EXPECT_NE(0, std::memcmp(cipher, iv, 16)); + + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, true, decrypted); @@ -67,7 +84,7 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_24) { std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test PKCS#7 padding with 32-byte key +// Test PKCS#7 padding with 32-byte key (user-supplied IV) TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_32) { auto* key = "12345678abcdefgh12345678abcdefgh"; auto* iv = "1234567890123456"; @@ -76,12 +93,20 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_32) { auto key_len = static_cast(strlen(key)); auto iv_len = static_cast(strlen(iv)); auto to_encrypt_len = static_cast(strlen(to_encrypt)); - unsigned char cipher[64]; + unsigned char cipher[128]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, true, cipher); - unsigned char decrypted[64]; + // Output format with user-supplied IV: [ciphertext] (IV NOT prepended) + EXPECT_GE(cipher_len, to_encrypt_len); + EXPECT_LT(cipher_len, to_encrypt_len + 16); // Should NOT include IV + + // Verify IV is NOT prepended + EXPECT_NE(0, std::memcmp(cipher, iv, 16)); + + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, true, decrypted); @@ -90,7 +115,7 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_32) { std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test no-padding mode with block-aligned data (16 bytes) +// Test no-padding mode with block-aligned data (16 bytes, user-supplied IV) TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptNoPadding_16) { auto* key = "12345678abcdefgh"; auto* iv = "1234567890123456"; @@ -99,12 +124,20 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptNoPadding_16) { auto key_len = static_cast(strlen(key)); auto iv_len = static_cast(strlen(iv)); auto to_encrypt_len = static_cast(strlen(to_encrypt)); - unsigned char cipher[64]; + unsigned char cipher[128]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, false, cipher); - unsigned char decrypted[64]; + // Output format with user-supplied IV: [ciphertext] (IV NOT prepended) + // No padding, so ciphertext is exactly 16 bytes + EXPECT_EQ(cipher_len, 16); + + // Verify IV is NOT prepended + EXPECT_NE(0, std::memcmp(cipher, iv, 16)); + + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, false, decrypted); @@ -153,5 +186,97 @@ TEST(TestAesCbcEncryptUtils, TestInvalidKeyLength) { } } +// Test NULL IV with auto-generation (encrypt and decrypt round-trip) +TEST(TestAesCbcEncryptUtils, TestNullIvAutoGeneration) { + auto* key = "12345678abcdefgh"; + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + // Encrypt with NULL IV (auto-generate) + int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + nullptr, 0, true, cipher); + + // Output format: [16-byte IV][ciphertext with padding] + EXPECT_GE(cipher_len, to_encrypt_len + 16); + + // Decrypt with NULL IV (extract from ciphertext) + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), + cipher_len, key, key_len, nullptr, 0, + true, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test NULL IV with no padding +TEST(TestAesCbcEncryptUtils, TestNullIvNoPadding) { + auto* key = "12345678abcdefgh"; + auto* to_encrypt = "1234567890123456"; // Exactly 16 bytes + + auto key_len = static_cast(strlen(key)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + // Encrypt with NULL IV and no padding + int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + nullptr, 0, false, cipher); + + // Output format: [16-byte IV][16-byte ciphertext] + EXPECT_EQ(cipher_len, 16 + 16); + + // Decrypt with NULL IV and no padding + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), + cipher_len, key, key_len, nullptr, 0, + false, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test that user-supplied IV encrypt requires same IV for decrypt +TEST(TestAesCbcEncryptUtils, TestSuppliedIvEncryptRequiresSameIvDecrypt) { + auto* key = "12345678abcdefgh"; + auto* iv = "1234567890123456"; + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + // Encrypt with user-supplied IV (IV will NOT be prepended) + int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, true, cipher); + + // Decrypt with the same IV (required since IV was not prepended) + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + true, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test decrypt with too-short ciphertext (NULL IV case) +TEST(TestAesCbcEncryptUtils, TestDecryptTooShortCiphertext) { + auto* key = "12345678abcdefgh"; + auto key_len = static_cast(strlen(key)); + + // Ciphertext too short: only 20 bytes (needs at least 32: 16 IV + 16 min block) + unsigned char short_cipher[20] = {0}; + unsigned char decrypted[128]; + + EXPECT_THROW(gandiva::aes_decrypt_cbc(reinterpret_cast(short_cipher), + 20, key, key_len, nullptr, 0, + true, decrypted), + std::runtime_error); +} + diff --git a/cpp/src/gandiva/encrypt_utils_gcm.cc b/cpp/src/gandiva/encrypt_utils_gcm.cc index f028243da59..05e8d656e90 100644 --- a/cpp/src/gandiva/encrypt_utils_gcm.cc +++ b/cpp/src/gandiva/encrypt_utils_gcm.cc @@ -17,6 +17,7 @@ #include "gandiva/encrypt_utils_gcm.h" #include "gandiva/encrypt_utils_common.h" +#include "gandiva/encrypt_utils_iv.h" #include #include #include @@ -44,6 +45,40 @@ const EVP_CIPHER* get_gcm_cipher_algo(int32_t key_length) { } } +void validate_iv_length_gcm(int32_t iv_len) { + if (iv_len != GCM_IV_LENGTH) { + std::ostringstream oss; + oss << "Invalid IV length for AES-GCM: " << iv_len + << " bytes. IV must be exactly " << GCM_IV_LENGTH << " bytes"; + throw std::runtime_error(oss.str()); + } +} + +void validate_aad_length_gcm(int32_t aad_len) { + if (aad_len <= 0) { + throw std::runtime_error("AAD length must be positive when AAD is provided"); + } +} + +void validate_ciphertext_with_embedded_iv_gcm(int32_t ciphertext_len) { + constexpr int32_t MIN_CIPHERTEXT_LEN = GCM_IV_LENGTH + GCM_TAG_LENGTH; + + if (ciphertext_len < MIN_CIPHERTEXT_LEN) { + std::ostringstream oss; + oss << "Ciphertext too short for AES-GCM with embedded IV: " << ciphertext_len + << " bytes. Must be at least " << MIN_CIPHERTEXT_LEN + << " bytes (12-byte IV + 16-byte tag)"; + throw std::runtime_error(oss.str()); + } +} + +void validate_ciphertext_with_tag(int32_t ciphertext_len) { + if (ciphertext_len < GCM_TAG_LENGTH) { + throw std::runtime_error( + "Ciphertext too short for AES-GCM: must be at least 16 bytes for tag"); + } +} + } // namespace GANDIVA_EXPORT @@ -51,9 +86,16 @@ int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, const char* aad, int32_t aad_len, unsigned char* cipher) { - if (iv_len <= 0) { - throw std::runtime_error( - "Invalid IV length for AES-GCM: IV length must be greater than 0"); + unsigned char iv_buffer[GCM_IV_LENGTH]; + const unsigned char* actual_iv = nullptr; + bool iv_auto_generated = iv == nullptr; + + if (iv_auto_generated) { + generate_random_iv(iv_buffer, GCM_IV_LENGTH); + actual_iv = iv_buffer; + } else { + validate_iv_length_gcm(iv_len); + actual_iv = reinterpret_cast(iv); } int32_t cipher_len = 0; @@ -67,22 +109,32 @@ int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, } try { + // Only prepend IV to output if it was auto-generated + // Auto-generated IV: [12-byte IV][ciphertext][16-byte tag] + // User-supplied IV: [ciphertext][16-byte tag] + if (iv_auto_generated) { + std::memcpy(cipher, actual_iv, GCM_IV_LENGTH); + cipher_len = GCM_IV_LENGTH; + } + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, reinterpret_cast(key), - reinterpret_cast(iv))) { + actual_iv)) { throw std::runtime_error( "Could not initialize EVP cipher context for encryption: " + get_openssl_error_string()); } // Set IV length for GCM mode - if (!EVP_CIPHER_CTX_ctrl(en_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, nullptr)) { + if (!EVP_CIPHER_CTX_ctrl(en_ctx, EVP_CTRL_GCM_SET_IVLEN, GCM_IV_LENGTH, nullptr)) { throw std::runtime_error("Could not set GCM IV length: " + get_openssl_error_string()); } // Process AAD if provided - if (aad != nullptr && aad_len > 0) { + if (aad != nullptr) { + validate_aad_length_gcm(aad_len); + if (!EVP_EncryptUpdate(en_ctx, nullptr, &len, reinterpret_cast(aad), aad_len)) { throw std::runtime_error("Could not process AAD for encryption: " + @@ -90,8 +142,8 @@ int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, } } - // Encrypt plaintext - if (!EVP_EncryptUpdate(en_ctx, cipher, &len, + // Encrypt plaintext (write after IV) + if (!EVP_EncryptUpdate(en_ctx, cipher + cipher_len, &len, reinterpret_cast(plaintext), plaintext_len)) { throw std::runtime_error("Could not update EVP cipher context for encryption: " + @@ -101,7 +153,7 @@ int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, cipher_len += len; // Finalize encryption - if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) { + if (!EVP_EncryptFinal_ex(en_ctx, cipher + cipher_len, &len)) { throw std::runtime_error("Could not finalize EVP cipher context for encryption: " + get_openssl_error_string()); } @@ -129,14 +181,21 @@ int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, const char* aad, int32_t aad_len, unsigned char* plaintext) { - if (iv_len <= 0) { - throw std::runtime_error( - "Invalid IV length for AES-GCM: IV length must be greater than 0"); - } - - if (ciphertext_len < GCM_TAG_LENGTH) { - throw std::runtime_error( - "Ciphertext too short for AES-GCM: must be at least 16 bytes for tag"); + unsigned char iv_buffer[GCM_IV_LENGTH]; + const unsigned char* actual_iv = nullptr; + const char* actual_ciphertext = ciphertext; + int32_t actual_ciphertext_with_tag_len = ciphertext_len; + + if (iv == nullptr) { + validate_ciphertext_with_embedded_iv_gcm(ciphertext_len); + extract_iv_from_ciphertext(ciphertext, ciphertext_len, GCM_IV_LENGTH, + iv_buffer, &actual_ciphertext, + &actual_ciphertext_with_tag_len); + actual_iv = iv_buffer; + } else { + validate_iv_length_gcm(iv_len); + validate_ciphertext_with_tag(ciphertext_len); + actual_iv = reinterpret_cast(iv); } int32_t plaintext_len = 0; @@ -152,20 +211,22 @@ int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, try { if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, reinterpret_cast(key), - reinterpret_cast(iv))) { + actual_iv)) { throw std::runtime_error( "Could not initialize EVP cipher context for decryption: " + get_openssl_error_string()); } // Set IV length for GCM mode - if (!EVP_CIPHER_CTX_ctrl(de_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, nullptr)) { + if (!EVP_CIPHER_CTX_ctrl(de_ctx, EVP_CTRL_GCM_SET_IVLEN, GCM_IV_LENGTH, nullptr)) { throw std::runtime_error("Could not set GCM IV length: " + get_openssl_error_string()); } // Process AAD if provided - if (aad != nullptr && aad_len > 0) { + if (aad != nullptr) { + validate_aad_length_gcm(aad_len); + if (!EVP_DecryptUpdate(de_ctx, nullptr, &len, reinterpret_cast(aad), aad_len)) { throw std::runtime_error("Could not process AAD for decryption: " + @@ -173,10 +234,10 @@ int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, } } - // Extract tag from end of ciphertext - int32_t actual_ciphertext_len = ciphertext_len - GCM_TAG_LENGTH; - const unsigned char* tag = - reinterpret_cast(ciphertext + actual_ciphertext_len); + // GCM always has a tag appended, regardless of whether AAD was used + int32_t ciphertext_without_tag_len = actual_ciphertext_with_tag_len - GCM_TAG_LENGTH; + const unsigned char* tag = reinterpret_cast( + actual_ciphertext + ciphertext_without_tag_len); // Set the authentication tag if (!EVP_CIPHER_CTX_ctrl(de_ctx, EVP_CTRL_GCM_SET_TAG, GCM_TAG_LENGTH, @@ -187,8 +248,8 @@ int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, // Decrypt ciphertext if (!EVP_DecryptUpdate(de_ctx, plaintext, &len, - reinterpret_cast(ciphertext), - actual_ciphertext_len)) { + reinterpret_cast(actual_ciphertext), + ciphertext_without_tag_len)) { throw std::runtime_error("Could not update EVP cipher context for decryption: " + get_openssl_error_string()); } diff --git a/cpp/src/gandiva/encrypt_utils_gcm.h b/cpp/src/gandiva/encrypt_utils_gcm.h index 07a597af0b6..6b117251681 100644 --- a/cpp/src/gandiva/encrypt_utils_gcm.h +++ b/cpp/src/gandiva/encrypt_utils_gcm.h @@ -26,22 +26,34 @@ namespace gandiva { // GCM mode identifier constexpr const char* AES_GCM_MODE = "AES-GCM"; +// GCM IV length in bytes +constexpr int32_t GCM_IV_LENGTH = 12; // 12 bytes (96 bits) - recommended for GCM but agreed to enforce it + // GCM authentication tag length in bytes constexpr int32_t GCM_TAG_LENGTH = 16; /** * Encrypt data using AES-GCM algorithm * + * Output format: + * - With NULL IV (auto-generated): [12-byte IV][ciphertext][16-byte authentication tag] + * - With user-supplied IV: [ciphertext][16-byte authentication tag] + * + * IV Handling: + * - If iv is NULL: A cryptographically secure random 12-byte IV + * is automatically generated using OpenSSL RAND_bytes and prepended to output + * - If iv is provided: It must be exactly 12 bytes, will be used as-is, and not prepended + * * @param plaintext The data to encrypt * @param plaintext_len Length of plaintext in bytes * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) * @param key_len Length of key in bytes - * @param iv The initialization vector (variable length, typically 12 bytes) + * @param iv The initialization vector (NULL for auto-generation, or exactly 12 bytes) * @param iv_len Length of IV in bytes * @param aad Optional additional authenticated data (can be null) - * @param aad_len Length of AAD in bytes (0 if aad is null) - * @param cipher Output buffer for encrypted data (must be at least plaintext_len + 16 bytes) - * @return Length of encrypted data in bytes (plaintext_len + 16 for the tag) + * @param aad_len Length of AAD in bytes + * @param cipher Output buffer for encrypted data (must be at least plaintext_len + 28 bytes) + * @return Length of encrypted data in bytes (12 + plaintext_len + 16) * @throws std::runtime_error on encryption failure or invalid parameters */ GANDIVA_EXPORT @@ -52,16 +64,24 @@ int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, const char /** * Decrypt data using AES-GCM algorithm * - * @param ciphertext The data to decrypt (includes 16-byte authentication tag at the end) - * @param ciphertext_len Length of ciphertext in bytes (includes tag) + * IV Handling: + * - If iv is NULL or iv_len is 0: IV is extracted from the first 12 bytes of ciphertext + * (expects format: [12-byte IV][ciphertext][16-byte tag]) + * - If iv is provided: It must be exactly 12 bytes, and ciphertext should be + * [ciphertext][16-byte tag] without embedded IV + * + * @param ciphertext The data to decrypt + * - With NULL IV: [12-byte IV][ciphertext][16-byte tag] (min 28 bytes) + * - With provided IV: [ciphertext][16-byte tag] (min 16 bytes) + * @param ciphertext_len Length of ciphertext in bytes (includes IV if embedded, and tag) * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) * @param key_len Length of key in bytes - * @param iv The initialization vector (variable length, typically 12 bytes) - * @param iv_len Length of IV in bytes + * @param iv The initialization vector (NULL for extraction, or exactly 12 bytes) + * @param iv_len Length of IV in bytes (0 for extraction, or 12) * @param aad Optional additional authenticated data (can be null) * @param aad_len Length of AAD in bytes (0 if aad is null) * @param plaintext Output buffer for decrypted data - * @return Length of decrypted data in bytes (ciphertext_len - 16) + * @return Length of decrypted data in bytes * @throws std::runtime_error on decryption failure, invalid parameters, or tag verification failure */ GANDIVA_EXPORT diff --git a/cpp/src/gandiva/encrypt_utils_gcm_test.cc b/cpp/src/gandiva/encrypt_utils_gcm_test.cc index 2156132bc62..8325b24c16d 100644 --- a/cpp/src/gandiva/encrypt_utils_gcm_test.cc +++ b/cpp/src/gandiva/encrypt_utils_gcm_test.cc @@ -21,7 +21,7 @@ #include #include -// Test IV-only GCM with 16-byte key +// Test IV-only GCM with 16-byte key (user-supplied IV) TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_16) { auto* key = "12345678abcdefgh"; auto* iv = "123456789012"; // 12-byte IV @@ -35,10 +35,14 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_16) { int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, nullptr, 0, cipher); - // Ciphertext should be plaintext_len + 16 (tag) + // Output format with user-supplied IV: [ciphertext][16-byte tag] (IV NOT prepended) EXPECT_EQ(cipher_len, to_encrypt_len + 16); + // Verify IV is NOT prepended (ciphertext should not match IV) + EXPECT_NE(0, std::memcmp(cipher, iv, 12)); + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, nullptr, 0, decrypted); @@ -47,7 +51,7 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_16) { std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test IV + AAD GCM with 16-byte key +// Test IV + AAD GCM with 16-byte key (user-supplied IV) TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptWithAad_16) { auto* key = "12345678abcdefgh"; auto* iv = "123456789012"; @@ -63,9 +67,14 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptWithAad_16) { int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, aad, aad_len, cipher); + // Output format with user-supplied IV: [ciphertext][16-byte tag] (IV NOT prepended) EXPECT_EQ(cipher_len, to_encrypt_len + 16); + // Verify IV is NOT prepended + EXPECT_NE(0, std::memcmp(cipher, iv, 12)); + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, aad, aad_len, decrypted); @@ -74,7 +83,7 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptWithAad_16) { std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test IV-only GCM with 24-byte key +// Test IV-only GCM with 24-byte key (user-supplied IV) TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_24) { auto* key = "12345678abcdefgh12345678"; auto* iv = "123456789012"; @@ -88,7 +97,14 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_24) { int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, nullptr, 0, cipher); + // Output format with user-supplied IV: [ciphertext][16-byte tag] (IV NOT prepended) + EXPECT_EQ(cipher_len, to_encrypt_len + 16); + + // Verify IV is NOT prepended + EXPECT_NE(0, std::memcmp(cipher, iv, 12)); + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, nullptr, 0, decrypted); @@ -97,7 +113,7 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_24) { std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test IV-only GCM with 32-byte key +// Test IV-only GCM with 32-byte key (user-supplied IV) TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_32) { auto* key = "12345678abcdefgh12345678abcdefgh"; auto* iv = "123456789012"; @@ -111,7 +127,14 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_32) { int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, nullptr, 0, cipher); + // Output format with user-supplied IV: [ciphertext][16-byte tag] (IV NOT prepended) + EXPECT_EQ(cipher_len, to_encrypt_len + 16); + + // Verify IV is NOT prepended + EXPECT_NE(0, std::memcmp(cipher, iv, 12)); + unsigned char decrypted[128]; + // Pass the same IV to decrypt (since encrypt did NOT prepend it) int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, nullptr, 0, decrypted); @@ -120,7 +143,7 @@ TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_32) { std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test tag verification failure +// Test tag verification failure (user-supplied IV) TEST(TestAesGcmEncryptUtils, TestTagVerificationFailure) { auto* key = "12345678abcdefgh"; auto* iv = "123456789012"; @@ -144,10 +167,10 @@ TEST(TestAesGcmEncryptUtils, TestTagVerificationFailure) { std::runtime_error); } -// Test invalid IV length +// Test invalid IV length (non-12-byte IV should fail) TEST(TestAesGcmEncryptUtils, TestInvalidIvLength) { auto* key = "12345678abcdefgh"; - auto* iv = ""; // Empty IV + auto* iv = "1234567890"; // 10-byte IV (invalid, must be exactly 12) auto* to_encrypt = "some test string"; auto key_len = static_cast(strlen(key)); @@ -160,3 +183,97 @@ TEST(TestAesGcmEncryptUtils, TestInvalidIvLength) { std::runtime_error); } +// Test NULL IV with auto-generation (encrypt and decrypt round-trip) +TEST(TestAesGcmEncryptUtils, TestNullIvAutoGeneration) { + auto* key = "12345678abcdefgh"; + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + // Encrypt with NULL IV (auto-generate) + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + nullptr, 0, nullptr, 0, cipher); + + // Output format: [12-byte IV][ciphertext][16-byte tag] + EXPECT_EQ(cipher_len, to_encrypt_len + 12 + 16); + + // Decrypt with NULL IV (extract from ciphertext) + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, nullptr, 0, + nullptr, 0, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test NULL IV with AAD +TEST(TestAesGcmEncryptUtils, TestNullIvWithAad) { + auto* key = "12345678abcdefgh"; + auto* to_encrypt = "some test string"; + auto* aad = "additional authenticated data"; + + auto key_len = static_cast(strlen(key)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + auto aad_len = static_cast(strlen(aad)); + unsigned char cipher[128]; + + // Encrypt with NULL IV and AAD + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + nullptr, 0, aad, aad_len, cipher); + + // Output format: [12-byte IV][ciphertext][16-byte tag] + EXPECT_EQ(cipher_len, to_encrypt_len + 12 + 16); + + // Decrypt with NULL IV and AAD + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, nullptr, 0, + aad, aad_len, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test that user-supplied IV encrypt requires same IV for decrypt +TEST(TestAesGcmEncryptUtils, TestSuppliedIvEncryptRequiresSameIvDecrypt) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + // Encrypt with user-supplied IV (IV will NOT be prepended) + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, nullptr, 0, cipher); + + // Decrypt with the same IV (required since IV was not prepended) + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + nullptr, 0, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test decrypt with too-short ciphertext (NULL IV case) +TEST(TestAesGcmEncryptUtils, TestDecryptTooShortCiphertext) { + auto* key = "12345678abcdefgh"; + auto key_len = static_cast(strlen(key)); + + // Ciphertext too short: only 20 bytes (needs at least 28: 12 IV + 16 tag) + unsigned char short_cipher[20] = {0}; + unsigned char decrypted[128]; + + EXPECT_THROW(gandiva::aes_decrypt_gcm(reinterpret_cast(short_cipher), + 20, key, key_len, nullptr, 0, + nullptr, 0, decrypted), + std::runtime_error); +} + diff --git a/cpp/src/gandiva/encrypt_utils_iv.cc b/cpp/src/gandiva/encrypt_utils_iv.cc new file mode 100644 index 00000000000..338ece1b8c2 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_iv.cc @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_utils_iv.h" +#include +#include +#include +#include + +namespace gandiva { + +void generate_random_iv(unsigned char* iv_buffer, int32_t iv_length) { + if (iv_buffer == nullptr) { + throw std::runtime_error("IV buffer cannot be null"); + } + + if (iv_length <= 0) { + std::ostringstream oss; + oss << "Invalid IV length: " << iv_length << ". IV length must be positive"; + throw std::runtime_error(oss.str()); + } + + // Generate cryptographically secure random bytes using OpenSSL + int result = RAND_bytes(iv_buffer, iv_length); + if (result != 1) { + throw std::runtime_error( + "Failed to generate random IV: OpenSSL RAND_bytes failed"); + } +} + +void extract_iv_from_ciphertext(const char* ciphertext_with_iv, int32_t ciphertext_len, + int32_t iv_length, unsigned char* extracted_iv, + const char** actual_ciphertext, + int32_t* actual_ciphertext_len) { + if (ciphertext_with_iv == nullptr) { + throw std::runtime_error("Ciphertext cannot be null"); + } + + if (extracted_iv == nullptr) { + throw std::runtime_error("Extracted IV buffer cannot be null"); + } + + if (actual_ciphertext == nullptr) { + throw std::runtime_error("Actual ciphertext output pointer cannot be null"); + } + + if (actual_ciphertext_len == nullptr) { + throw std::runtime_error("Actual ciphertext length output pointer cannot be null"); + } + + if (iv_length <= 0) { + std::ostringstream oss; + oss << "Invalid IV length: " << iv_length << ". IV length must be positive"; + throw std::runtime_error(oss.str()); + } + + if (ciphertext_len < iv_length) { + std::ostringstream oss; + oss << "Ciphertext too short to contain IV: ciphertext is " << ciphertext_len + << " bytes but IV requires " << iv_length << " bytes"; + throw std::runtime_error(oss.str()); + } + + // Extract IV from the beginning of ciphertext + std::memcpy(extracted_iv, ciphertext_with_iv, iv_length); + + // Set pointer to actual ciphertext (after IV) + *actual_ciphertext = ciphertext_with_iv + iv_length; + *actual_ciphertext_len = ciphertext_len - iv_length; +} + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_iv.h b/cpp/src/gandiva/encrypt_utils_iv.h new file mode 100644 index 00000000000..c3bdcf0e13a --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_iv.h @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef GANDIVA_ENCRYPT_UTILS_IV_H +#define GANDIVA_ENCRYPT_UTILS_IV_H + +#include +#include "gandiva/visibility.h" + +namespace gandiva { + +/** + * Generate a cryptographically secure random initialization vector (IV) + * using OpenSSL's RAND_bytes. + * + * @param iv_buffer Output buffer to store the generated IV + * @param iv_length Length of IV to generate in bytes (typically 12 for GCM, 16 for CBC) + * @throws std::runtime_error if random number generation fails + */ +GANDIVA_EXPORT +void generate_random_iv(unsigned char* iv_buffer, int32_t iv_length); + +/** + * Extract IV from the beginning of ciphertext and return pointer to actual ciphertext. + * This is a helper function for decrypt operations when IV is embedded in the ciphertext. + * + * @param ciphertext_with_iv Pointer to ciphertext with IV prepended + * @param ciphertext_len Total length including IV + * @param iv_length Expected IV length (12 for GCM, 16 for CBC) + * @param extracted_iv Output buffer to store extracted IV (must be at least iv_length bytes) + * @param actual_ciphertext Output pointer to the actual ciphertext (after IV) + * @param actual_ciphertext_len Output length of actual ciphertext (without IV) + * @throws std::runtime_error if ciphertext is too short to contain IV + */ +GANDIVA_EXPORT +void extract_iv_from_ciphertext(const char* ciphertext_with_iv, int32_t ciphertext_len, + int32_t iv_length, unsigned char* extracted_iv, + const char** actual_ciphertext, + int32_t* actual_ciphertext_len); + +} // namespace gandiva + +#endif // GANDIVA_ENCRYPT_UTILS_IV_H + diff --git a/cpp/src/gandiva/encrypt_utils_iv_test.cc b/cpp/src/gandiva/encrypt_utils_iv_test.cc new file mode 100644 index 00000000000..0eb1af9a110 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_iv_test.cc @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_utils_iv.h" +#include "gandiva/encrypt_utils_gcm.h" +#include "gandiva/encrypt_utils_cbc.h" + +#include +#include +#include + +// Test that multiple IV generations produce different values +TEST(TestIVUtils, TestGenerateRandomIvUniqueness) { + unsigned char iv1[gandiva::GCM_IV_LENGTH]; + unsigned char iv2[gandiva::GCM_IV_LENGTH]; + + gandiva::generate_random_iv(iv1, gandiva::GCM_IV_LENGTH); + gandiva::generate_random_iv(iv2, gandiva::GCM_IV_LENGTH); + + EXPECT_NE(0, std::memcmp(iv1, iv2, gandiva::GCM_IV_LENGTH)); +} + +// Test that generated IV has the correct length +TEST(TestIVUtils, TestGenerateRandomIvLength) { + unsigned char iv[gandiva::GCM_IV_LENGTH]; + + // Generate IV with GCM_IV_LENGTH (12 bytes) + ASSERT_NO_THROW(gandiva::generate_random_iv(iv, gandiva::GCM_IV_LENGTH)); + + // Verify that the function accepts the correct length without throwing + // The actual length verification is implicit - if the buffer is correctly + // filled, no out-of-bounds access occurs + + // Also test with CBC_IV_LENGTH (16 bytes) + unsigned char iv_cbc[gandiva::CBC_IV_LENGTH]; + ASSERT_NO_THROW(gandiva::generate_random_iv(iv_cbc, gandiva::CBC_IV_LENGTH)); +} + +// Test error handling for null buffer +TEST(TestIVUtils, TestGenerateRandomIvNullBuffer) { + EXPECT_THROW(gandiva::generate_random_iv(nullptr, gandiva::GCM_IV_LENGTH), + std::runtime_error); +} + +// Test error handling for invalid length +TEST(TestIVUtils, TestGenerateRandomIvInvalidLength) { + unsigned char iv[16]; + EXPECT_THROW(gandiva::generate_random_iv(iv, 0), std::runtime_error); + EXPECT_THROW(gandiva::generate_random_iv(iv, -1), std::runtime_error); +} + +// Test extracting GCM IV from ciphertext +TEST(TestIVUtils, TestExtractIvFromCiphertextGcm) { + // Create test data: [12-byte IV][ciphertext] + const char test_data[] = "123456789012CIPHERTEXT_DATA"; + const int32_t total_len = 28; // 12 + 16 + + unsigned char extracted_iv[gandiva::GCM_IV_LENGTH]; + const char* actual_ciphertext = nullptr; + int32_t actual_ciphertext_len = 0; + + ASSERT_NO_THROW(gandiva::extract_iv_from_ciphertext( + test_data, total_len, gandiva::GCM_IV_LENGTH, extracted_iv, + &actual_ciphertext, &actual_ciphertext_len)); + + // Verify IV was extracted correctly + EXPECT_EQ(0, std::memcmp(extracted_iv, "123456789012", gandiva::GCM_IV_LENGTH)); + + // Verify ciphertext pointer and length + EXPECT_EQ(actual_ciphertext, test_data + gandiva::GCM_IV_LENGTH); + EXPECT_EQ(actual_ciphertext_len, 16); + EXPECT_EQ(0, std::memcmp(actual_ciphertext, "CIPHERTEXT_DATA", 15)); +} + +// Test extracting CBC IV from ciphertext +TEST(TestIVUtils, TestExtractIvFromCiphertextCbc) { + // Create test data: [16-byte IV][ciphertext] + const char test_data[] = "1234567890123456CIPHERTEXT_DATA_HERE"; + const int32_t total_len = 37; // 16 + 21 + + unsigned char extracted_iv[gandiva::CBC_IV_LENGTH]; + const char* actual_ciphertext = nullptr; + int32_t actual_ciphertext_len = 0; + + ASSERT_NO_THROW(gandiva::extract_iv_from_ciphertext( + test_data, total_len, gandiva::CBC_IV_LENGTH, extracted_iv, + &actual_ciphertext, &actual_ciphertext_len)); + + // Verify IV was extracted correctly + EXPECT_EQ(0, std::memcmp(extracted_iv, "1234567890123456", gandiva::CBC_IV_LENGTH)); + + // Verify ciphertext pointer and length + EXPECT_EQ(actual_ciphertext, test_data + gandiva::CBC_IV_LENGTH); + EXPECT_EQ(actual_ciphertext_len, 21); + EXPECT_EQ(0, std::memcmp(actual_ciphertext, "CIPHERTEXT_DATA_HERE", 20)); +} + +// Test error handling for ciphertext too short +TEST(TestIVUtils, TestExtractIvFromCiphertextTooShort) { + const char test_data[] = "SHORT"; + unsigned char extracted_iv[gandiva::GCM_IV_LENGTH]; + const char* actual_ciphertext = nullptr; + int32_t actual_ciphertext_len = 0; + + EXPECT_THROW(gandiva::extract_iv_from_ciphertext( + test_data, 5, gandiva::GCM_IV_LENGTH, extracted_iv, + &actual_ciphertext, &actual_ciphertext_len), + std::runtime_error); +} + +// Test error handling for null inputs +TEST(TestIVUtils, TestExtractIvFromCiphertextNullInputs) { + const char test_data[] = "1234567890123456CIPHERTEXT"; + unsigned char extracted_iv[16]; + const char* actual_ciphertext = nullptr; + int32_t actual_ciphertext_len = 0; + + // Null ciphertext + EXPECT_THROW(gandiva::extract_iv_from_ciphertext( + nullptr, 27, 16, extracted_iv, &actual_ciphertext, &actual_ciphertext_len), + std::runtime_error); + + // Null extracted_iv buffer + EXPECT_THROW(gandiva::extract_iv_from_ciphertext( + test_data, 27, 16, nullptr, &actual_ciphertext, &actual_ciphertext_len), + std::runtime_error); + + // Null actual_ciphertext pointer + EXPECT_THROW(gandiva::extract_iv_from_ciphertext( + test_data, 27, 16, extracted_iv, nullptr, &actual_ciphertext_len), + std::runtime_error); + + // Null actual_ciphertext_len pointer + EXPECT_THROW(gandiva::extract_iv_from_ciphertext( + test_data, 27, 16, extracted_iv, &actual_ciphertext, nullptr), + std::runtime_error); +} + diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 7750421360e..6bab661ccb2 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -505,30 +505,35 @@ std::vector GetStringFunctionRegistry() { NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // Parameters: data, key, mode (e.g. ECB mode) + // Uses kResultNullInternal to allow NULL data while failing on NULL key/mode NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_encrypt_dispatcher_3args", + kResultNullInternal, "gdv_fn_encrypt_dispatcher_3args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_decrypt_dispatcher_3args", + kResultNullInternal, "gdv_fn_decrypt_dispatcher_3args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // Parameters: data, key, mode, iv (e.g. CBC mode) + // Note: IV can be NULL for CBC/GCM modes (auto-generates random IV) + // Uses kResultNullInternal to allow NULL IV while failing on NULL key/mode NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_encrypt_dispatcher_4args", + kResultNullInternal, "gdv_fn_encrypt_dispatcher_4args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_decrypt_dispatcher_4args", + kResultNullInternal, "gdv_fn_decrypt_dispatcher_4args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // Parameters: data, key, mode, iv, fifth_argument (e.g. GCM mode) + // Note: IV and AAD can be NULL (auto-generates random IV, no AAD) + // Uses kResultNullInternal to allow NULL IV/AAD while failing on NULL key/mode NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_encrypt_dispatcher_5args", + kResultNullInternal, "gdv_fn_encrypt_dispatcher_5args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_decrypt_dispatcher_5args", + kResultNullInternal, "gdv_fn_decrypt_dispatcher_5args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 55f6dd43cd1..2fc940e3f6f 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -778,12 +778,16 @@ const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data, // This function is ECB-only, so we enforce the mode const char* mode = "AES-ECB"; int32_t mode_len = 7; + bool out_valid = true; + + // Passing `true` for validity parameters because this function is marked as kResultNullIfNull, + // so if we're here the inputs are guaranteed to be non-NULL const char* result = gdv_fn_encrypt_dispatcher_3args( - context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); + context, data, data_len, true, key_data, key_data_len, true, mode, mode_len, true, + &out_valid, out_len); // Add null terminator for string compatibility - // Note: This may not be valid UTF-8, but it's needed for string handling - if (result != nullptr) { + if (result != nullptr && out_valid) { char* mutable_result = const_cast(result); mutable_result[*out_len] = '\0'; } @@ -805,12 +809,16 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, // This function is ECB-only, so we enforce the mode const char* mode = "AES-ECB"; int32_t mode_len = 7; + bool out_valid = true; + + // Passing `true` for validity parameters because this function is marked as kResultNullIfNull, + // so if we're here the inputs are guaranteed to be non-NULL const char* result = gdv_fn_decrypt_dispatcher_3args( - context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); + context, data, data_len, true, key_data, key_data_len, true, mode, mode_len, true, + &out_valid, out_len); // Add null terminator for string compatibility - // Note: This may not be valid UTF-8, but it's needed for string handling - if (result != nullptr) { + if (result != nullptr && out_valid) { char* mutable_result = const_cast(result); mutable_result[*out_len] = '\0'; } @@ -821,69 +829,96 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, // The 3- and 4-arg signatures exist to support optional IV and other arguments extern "C" GANDIVA_EXPORT const char* gdv_fn_encrypt_dispatcher_3args( - int64_t context, const char* data, int32_t data_len, const char* key_data, - int32_t key_data_len, const char* mode, int32_t mode_len, - int32_t* out_len) { + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + bool* out_valid, int32_t* out_len) { return gdv_fn_encrypt_dispatcher_5args( - context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr, - 0, nullptr, 0, out_len); + context, data, data_len, data_validity, key_data, key_data_len, key_validity, + mode, mode_len, mode_validity, nullptr, 0, false, nullptr, 0, false, + out_valid, out_len); } extern "C" GANDIVA_EXPORT const char* gdv_fn_decrypt_dispatcher_3args( - int64_t context, const char* data, int32_t data_len, const char* key_data, - int32_t key_data_len, const char* mode, int32_t mode_len, - int32_t* out_len) { + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + bool* out_valid, int32_t* out_len) { return gdv_fn_decrypt_dispatcher_5args( - context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr, - 0, nullptr, 0, out_len); + context, data, data_len, data_validity, key_data, key_data_len, key_validity, + mode, mode_len, mode_validity, nullptr, 0, false, nullptr, 0, false, + out_valid, out_len); } extern "C" GANDIVA_EXPORT const char* gdv_fn_encrypt_dispatcher_4args( - int64_t context, const char* data, int32_t data_len, const char* key_data, - int32_t key_data_len, const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, int32_t* out_len) { + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + bool* out_valid, int32_t* out_len) { return gdv_fn_encrypt_dispatcher_5args( - context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data, - iv_data_len, nullptr, 0, out_len); + context, data, data_len, data_validity, key_data, key_data_len, key_validity, + mode, mode_len, mode_validity, iv_data, iv_data_len, iv_validity, + nullptr, 0, false, out_valid, out_len); } extern "C" GANDIVA_EXPORT const char* gdv_fn_decrypt_dispatcher_4args( - int64_t context, const char* data, int32_t data_len, const char* key_data, - int32_t key_data_len, const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, int32_t* out_len) { + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + bool* out_valid, int32_t* out_len) { return gdv_fn_decrypt_dispatcher_5args( - context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data, - iv_data_len, nullptr, 0, out_len); + context, data, data_len, data_validity, key_data, key_data_len, key_validity, + mode, mode_len, mode_validity, iv_data, iv_data_len, iv_validity, + nullptr, 0, false, out_valid, out_len); } extern "C" GANDIVA_EXPORT const char* gdv_fn_encrypt_dispatcher_5args( - int64_t context, const char* data, int32_t data_len, const char* key_data, - int32_t key_data_len, const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, const char* fifth_argument, - int32_t fifth_argument_len, int32_t* out_len) { + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + const char* fifth_argument, int32_t fifth_argument_len, bool fifth_argument_validity, + bool* out_valid, int32_t* out_len) { + // Check if plaintext is NULL - this is the only case where we return NULL + if (!data_validity) { + *out_valid = false; + *out_len = 0; + return nullptr; + } + + *out_valid = true; + try { - // Allocate extra 16 bytes for AES block padding (PKCS7 padding can add - // up to 16 bytes for a 128-bit block cipher) - // In cases of no-padding modes, this extra space is not used + // Calculate buffer size based on mode: + // - ECB: data_len + 16 (padding only, no IV) + // - CBC: data_len + 16 (IV) + 16 (padding) = data_len + 32 + // - GCM: data_len + 12 (IV) + 16 (tag) = data_len + 28 + // Use maximum to handle all modes safely + int32_t buffer_size = data_len + 32; + auto* output = reinterpret_cast( - gdv_fn_context_arena_malloc(context, data_len + 16)); + gdv_fn_context_arena_malloc(context, buffer_size)); if (output == nullptr) { throw std::runtime_error( "Memory allocation failed for encryption output"); } int32_t cipher_len = EncryptModeDispatcher::encrypt( - data, data_len, key_data, key_data_len, mode, mode_len, iv_data, - iv_data_len, fifth_argument, fifth_argument_len, output); + data, data_len, key_data, key_data_len, key_validity, + mode, mode_len, mode_validity, iv_data, iv_data_len, iv_validity, + fifth_argument, fifth_argument_len, fifth_argument_validity, output); *out_len = cipher_len; return reinterpret_cast(output); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); + *out_valid = false; *out_len = 0; return nullptr; } @@ -891,11 +926,26 @@ const char* gdv_fn_encrypt_dispatcher_5args( extern "C" GANDIVA_EXPORT const char* gdv_fn_decrypt_dispatcher_5args( - int64_t context, const char* data, int32_t data_len, const char* key_data, - int32_t key_data_len, const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, const char* fifth_argument, - int32_t fifth_argument_len, int32_t* out_len) { + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + const char* fifth_argument, int32_t fifth_argument_len, bool fifth_argument_validity, + bool* out_valid, int32_t* out_len) { + // Check if ciphertext is NULL - this is the only case where we return NULL + if (!data_validity) { + *out_valid = false; + *out_len = 0; + return nullptr; + } + + *out_valid = true; + try { + // Buffer size for decryption output is data_len: + // - Input may contain IV + ciphertext + tag/padding + // - Output is plaintext only (IV and tag/padding are removed) + // - Plaintext is always <= input size, so data_len is sufficient auto* output = reinterpret_cast( gdv_fn_context_arena_malloc(context, data_len)); if (output == nullptr) { @@ -904,13 +954,15 @@ const char* gdv_fn_decrypt_dispatcher_5args( } int32_t plaintext_len = EncryptModeDispatcher::decrypt( - data, data_len, key_data, key_data_len, mode, mode_len, iv_data, - iv_data_len, fifth_argument, fifth_argument_len, output); + data, data_len, key_data, key_data_len, key_validity, + mode, mode_len, mode_validity, iv_data, iv_data_len, iv_validity, + fifth_argument, fifth_argument_len, fifth_argument_validity, output); *out_len = plaintext_len; return reinterpret_cast(output); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); + *out_valid = false; *out_len = 0; return nullptr; } @@ -1152,14 +1204,19 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; // gdv_fn_encrypt_dispatcher_3args (data, key, mode) + // Note: kResultNullInternal functions receive validity for each argument args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length + types->i1_type(), // data_validity types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i1_type(), // key_validity types->i8_ptr_type(), // mode (binary string) types->i32_type(), // mode_length + types->i1_type(), // mode_validity + types->ptr_type(types->i1_type()), // out_valid types->i32_ptr_type() // out_length }; @@ -1169,14 +1226,19 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_encrypt_dispatcher_3args)); // gdv_fn_decrypt_dispatcher_3args (data, key, mode) + // Note: kResultNullInternal functions receive validity for each argument args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length + types->i1_type(), // data_validity types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i1_type(), // key_validity types->i8_ptr_type(), // mode (binary string) types->i32_type(), // mode_length + types->i1_type(), // mode_validity + types->ptr_type(types->i1_type()), // out_valid types->i32_ptr_type() // out_length }; @@ -1186,16 +1248,22 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_decrypt_dispatcher_3args)); // gdv_fn_encrypt_dispatcher_4args (data, key, mode, iv) + // Note: kResultNullInternal functions receive validity for each argument args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length + types->i1_type(), // data_validity types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i1_type(), // key_validity types->i8_ptr_type(), // mode (binary string) types->i32_type(), // mode_length + types->i1_type(), // mode_validity types->i8_ptr_type(), // iv (binary string) types->i32_type(), // iv_length + types->i1_type(), // iv_validity + types->ptr_type(types->i1_type()), // out_valid types->i32_ptr_type() // out_length }; @@ -1205,16 +1273,22 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_encrypt_dispatcher_4args)); // gdv_fn_decrypt_dispatcher_4args (data, key, mode, iv) + // Note: kResultNullInternal functions receive validity for each argument args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length + types->i1_type(), // data_validity types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i1_type(), // key_validity types->i8_ptr_type(), // mode (binary string) types->i32_type(), // mode_length + types->i1_type(), // mode_validity types->i8_ptr_type(), // iv (binary string) types->i32_type(), // iv_length + types->i1_type(), // iv_validity + types->ptr_type(types->i1_type()), // out_valid types->i32_ptr_type() // out_length }; @@ -1225,18 +1299,25 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { // gdv_fn_encrypt_dispatcher_5args (data, key, mode, iv, // fifth_argument) + // Note: kResultNullInternal functions receive validity for each argument args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length + types->i1_type(), // data_validity types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i1_type(), // key_validity types->i8_ptr_type(), // mode (binary string) types->i32_type(), // mode_length + types->i1_type(), // mode_validity types->i8_ptr_type(), // iv (binary string) types->i32_type(), // iv_length + types->i1_type(), // iv_validity types->i8_ptr_type(), // fifth_argument (binary string) types->i32_type(), // fifth_argument_length + types->i1_type(), // fifth_argument_validity + types->ptr_type(types->i1_type()), // out_valid types->i32_ptr_type() // out_length }; @@ -1247,18 +1328,25 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { // gdv_fn_decrypt_dispatcher_5args (data, key, mode, iv, // fifth_argument) + // Note: kResultNullInternal functions receive validity for each argument args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length + types->i1_type(), // data_validity types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i1_type(), // key_validity types->i8_ptr_type(), // mode (binary string) types->i32_type(), // mode_length + types->i1_type(), // mode_validity types->i8_ptr_type(), // iv (binary string) types->i32_type(), // iv_length + types->i1_type(), // iv_validity types->i8_ptr_type(), // fifth_argument (binary string) types->i32_type(), // fifth_argument_length + types->i1_type(), // fifth_argument_validity + types->ptr_type(types->i1_type()), // out_valid types->i32_ptr_type() // out_length }; diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 54480ac7f6f..3d9daf7cd71 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -205,49 +205,58 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int32_t* out_len); // 3-argument dispatcher: (data, key, mode) +// Note: kResultNullInternal functions receive validity for each argument GANDIVA_EXPORT const char* gdv_fn_encrypt_dispatcher_3args( - int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, const char* mode, - int32_t mode_len, int32_t* out_len); + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + bool* out_valid, int32_t* out_len); GANDIVA_EXPORT const char* gdv_fn_decrypt_dispatcher_3args( - int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, const char* mode, - int32_t mode_len, int32_t* out_len); + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + bool* out_valid, int32_t* out_len); // 4-argument dispatcher: (data, key, mode, iv) +// Note: kResultNullInternal functions receive validity for each argument GANDIVA_EXPORT const char* gdv_fn_encrypt_dispatcher_4args( - int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, const char* mode, - int32_t mode_len, const char* iv_data, int32_t iv_data_len, - int32_t* out_len); + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + bool* out_valid, int32_t* out_len); GANDIVA_EXPORT const char* gdv_fn_decrypt_dispatcher_4args( - int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, const char* mode, - int32_t mode_len, const char* iv_data, int32_t iv_data_len, - int32_t* out_len); + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + bool* out_valid, int32_t* out_len); // 5-argument dispatcher: (data, key, mode, iv, fifth_argument) +// Note: kResultNullInternal functions receive validity for each argument GANDIVA_EXPORT const char* gdv_fn_encrypt_dispatcher_5args( - int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, const char* mode, - int32_t mode_len, const char* iv_data, int32_t iv_data_len, - const char* fifth_argument, int32_t fifth_argument_len, - int32_t* out_len); + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + const char* fifth_argument, int32_t fifth_argument_len, bool fifth_argument_validity, + bool* out_valid, int32_t* out_len); GANDIVA_EXPORT const char* gdv_fn_decrypt_dispatcher_5args( - int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, const char* mode, - int32_t mode_len, const char* iv_data, int32_t iv_data_len, - const char* fifth_argument, int32_t fifth_argument_len, - int32_t* out_len); + int64_t context, const char* data, int32_t data_len, bool data_validity, + const char* key_data, int32_t key_data_len, bool key_validity, + const char* mode, int32_t mode_len, bool mode_validity, + const char* iv_data, int32_t iv_data_len, bool iv_validity, + const char* fifth_argument, int32_t fifth_argument_len, bool fifth_argument_validity, + bool* out_valid, int32_t* out_len); GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index bfb34eeb31d..7e33c0e41d0 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1356,20 +1356,15 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = AES_ECB_MODE; - auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_encrypt_dispatcher_3args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &decrypted_len); + const char* cipher = gdv_fn_aes_encrypt_ecb_legacy(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, &cipher_len); + const char* decrypted_value = gdv_fn_aes_decrypt_ecb_legacy( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, &decrypted_len); EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); + std::string(reinterpret_cast(decrypted_value), decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { @@ -1380,21 +1375,16 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = AES_ECB_MODE; - auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_encrypt_dispatcher_3args( - ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, mode.c_str(), - mode_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_ecb_legacy(ctx_ptr, data.c_str(), data_len, key24.c_str(), + key24_len, &cipher_len); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( - ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, mode.c_str(), - mode_len, &decrypted_len); + const char* decrypted_value = gdv_fn_aes_decrypt_ecb_legacy( + ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, &decrypted_len); EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); + std::string(reinterpret_cast(decrypted_value), decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { @@ -1405,21 +1395,16 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = AES_ECB_MODE; - auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_encrypt_dispatcher_3args( - ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, mode.c_str(), - mode_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_ecb_legacy(ctx_ptr, data.c_str(), data_len, key32.c_str(), + key32_len, &cipher_len); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( - ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, mode.c_str(), - mode_len, &decrypted_len); + const char* decrypted_value = gdv_fn_aes_decrypt_ecb_legacy( + ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, &decrypted_len); EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); + std::string(reinterpret_cast(decrypted_value), decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { @@ -1429,48 +1414,50 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = AES_ECB_MODE; - auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); std::string cipher = "12345678abcdefgh12345678abcdefghb"; auto cipher_len = static_cast(cipher.length()); - gdv_fn_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, - key33.c_str(), key33_len, mode.c_str(), - mode_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Unsupported key length for AES-ECB")); + gdv_fn_aes_encrypt_ecb_legacy(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, + &cipher_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported key length for AES-ECB: 33 bytes. Supported lengths: 16, 24, 32 bytes")); ctx.Reset(); - gdv_fn_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len, - key33.c_str(), key33_len, mode.c_str(), - mode_len, &decrypted_len); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Unsupported key length for AES-ECB")); + gdv_fn_aes_decrypt_ecb_legacy(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, + &decrypted_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported key length for AES-ECB: 33 bytes. Supported lengths: 16, 24, 32 bytes")); ctx.Reset(); } // Tests for new mode-aware AES functions TEST(TestGdvFnStubs, TestAesEncryptDecryptModeEcb) { gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; + int64_t ctx_ptr = reinterpret_cast(&ctx); + std::string data = "test string"; auto data_len = static_cast(data.length()); + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string mode = AES_ECB_MODE; auto mode_len = static_cast(mode.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + + bool encrypt_valid = true; const char* cipher = gdv_fn_encrypt_dispatcher_3args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &cipher_len); + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); + bool decrypt_valid = true; const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &decrypted_len); + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); @@ -1478,21 +1465,26 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeEcb) { TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; + int64_t ctx_ptr = reinterpret_cast(&ctx); + std::string data = "test string"; auto data_len = static_cast(data.length()); + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string invalid_mode = "AES-INVALID"; auto invalid_mode_len = static_cast(invalid_mode.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); + + int32_t cipher_len = 0; + int32_t decrypted_len = 0; // Test encrypt with invalid mode - gdv_fn_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, - key16.c_str(), key16_len, - invalid_mode.c_str(), invalid_mode_len, - &cipher_len); + bool encrypt_valid = true; + gdv_fn_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, true, + key16.c_str(), key16_len, true, + invalid_mode.c_str(), invalid_mode_len, true, + &encrypt_valid, &cipher_len); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported encryption mode")); ctx.Reset(); @@ -1500,214 +1492,777 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { // Test decrypt with invalid mode std::string cipher = "12345678abcdefgh12345678abcdefgh"; auto cipher_len_val = static_cast(cipher.length()); - gdv_fn_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len_val, - key16.c_str(), key16_len, - invalid_mode.c_str(), invalid_mode_len, - &decrypted_len); + bool decrypt_valid = true; + gdv_fn_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len_val, true, + key16.c_str(), key16_len, true, + invalid_mode.c_str(), invalid_mode_len, true, + &decrypt_valid, &decrypted_len); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported decryption mode")); ctx.Reset(); } // Tests for AES-GCM mode -TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmIvOnly) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithUserSuppliedIv) { gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + std::string data = "A long-ish test string to make sure the ciphertext is long enough for GCM"; + auto data_len = static_cast(data.length()); + std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); + std::string mode = AES_GCM_MODE; auto mode_len = static_cast(mode.length()); + std::string iv = "123456789012"; auto iv_len = static_cast(iv.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + + bool encrypt_valid = true; const char* cipher = gdv_fn_encrypt_dispatcher_5args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, nullptr, 0, &cipher_len); + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, iv.c_str(), iv_len, true, + nullptr, 0, false, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); + // Positive test + bool decrypt_valid = true; const char* decrypted_value = gdv_fn_decrypt_dispatcher_5args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, nullptr, 0, &decrypted_len); - + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, iv.c_str(), iv_len, true, + nullptr, 0, false, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + + // Negative test: IV not supplied to decrypt + ctx.Reset(); + decrypt_valid = true; + gdv_fn_decrypt_dispatcher_5args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, + nullptr, 0, false, &decrypt_valid, &decrypted_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("GCM tag verification failed or decryption error: Unknown OpenSSL error")); + ctx.Reset(); } -TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithAad) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithAutoGeneratedIv) { gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; - std::string data = "test string"; + + std::string mode = AES_GCM_MODE; + auto mode_len = static_cast(mode.length()); + + std::string iv = "123456789012"; + auto iv_len = static_cast(iv.length()); + + bool encrypt_valid = true; + bool decrypt_valid = true; + + // Ideally, we would want to test all combinations of encrypt/decrypt variants, but + // that would result in 3*3 = 9 combinations. Instead, we assume the respective variants + // are well tested elsewhere and only test the encrypt/decrypt functions out of pairs. + + // Encrypting with the 3-args variant and decrypting with the 4-args variant + int32_t cipher_from_3args_len = 0; + const char* cipher_from_3args = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, &encrypt_valid, &cipher_from_3args_len); + + int32_t decrypted_from_4args_len = 0; + const char* decrypted_from_4args = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher_from_3args, cipher_from_3args_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, &decrypt_valid, &decrypted_from_4args_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_from_4args), + decrypted_from_4args_len)); + + // Encrypting with the 4-args variant and decrypting with the 5-args variant + int32_t cipher_from_4args_len = 0; + const char* cipher_from_4args = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, &encrypt_valid, &cipher_from_4args_len); + + int32_t decrypted_from_5args_len = 0; + const char* decrypted_from_5args = gdv_fn_decrypt_dispatcher_5args( + ctx_ptr, cipher_from_4args, cipher_from_4args_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, + nullptr, 0, false, &decrypt_valid, &decrypted_from_5args_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_from_5args), + decrypted_from_5args_len)); + + // Encrypting with the 5-args variant and decrypting with the 3-args variant + int32_t cipher_from_5args_len = 0; + const char* cipher_from_5args = gdv_fn_encrypt_dispatcher_5args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, + nullptr, 0, false, &encrypt_valid, &cipher_from_5args_len); + + int32_t decrypted_from_3args_len = 0; + const char* decrypted_from_3args = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher_from_3args, cipher_from_3args_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, &decrypt_valid, &decrypted_from_3args_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_from_3args), + decrypted_from_3args_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithAad) { + gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + std::string data = "A long-ish test string to make sure the ciphertext is long enough for GCM"; auto data_len = static_cast(data.length()); + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string mode = AES_GCM_MODE; auto mode_len = static_cast(mode.length()); + std::string iv = "123456789012"; auto iv_len = static_cast(iv.length()); + std::string aad = "additional authenticated data"; auto aad_len = static_cast(aad.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + + bool encrypt_valid = true; const char* cipher = gdv_fn_encrypt_dispatcher_5args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, aad.c_str(), aad_len, &cipher_len); + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, iv.c_str(), iv_len, true, + aad.c_str(), aad_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); + // Positive test + bool decrypt_valid = true; const char* decrypted_value = gdv_fn_decrypt_dispatcher_5args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, aad.c_str(), aad_len, &decrypted_len); - + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, iv.c_str(), iv_len, true, + aad.c_str(), aad_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + + // Negative test: AAD not supplied to decrypt + ctx.Reset(); + decrypt_valid = true; + gdv_fn_decrypt_dispatcher_5args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, iv.c_str(), iv_len, true, + nullptr, 0, false, &decrypt_valid, &decrypted_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("GCM tag verification failed or decryption error: Unknown OpenSSL error")); + ctx.Reset(); + + // Negative test: a different AAD is supplied to decrypt + ctx.Reset(); + decrypt_valid = true; + std::string different_aad = "different aad"; + auto different_aad_len = static_cast(different_aad.length()); + gdv_fn_decrypt_dispatcher_5args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, iv.c_str(), iv_len, true, + different_aad.c_str(), different_aad_len, true, &decrypt_valid, &decrypted_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("GCM tag verification failed or decryption error: Unknown OpenSSL error")); } -// Tests for shorthand mode: AES-ECB (defaults to PKCS7) -TEST(TestGdvFnStubs, TestAesEncryptDecryptShorthandEcb) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptCbcWithAutoGeneratedIv) { gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; + int64_t ctx_ptr = reinterpret_cast(&ctx); + std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = AES_ECB_MODE; // Shorthand mode + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + + std::string mode = AES_CBC_MODE; auto mode_len = static_cast(mode.length()); + + bool encrypt_valid = true; + bool decrypt_valid = true; + + // Ideally, we would want to test all combinations of encrypt/decrypt variants, but + // that would result in 3*3 = 9 combinations. Instead, we assume the respective variants + // are well tested elsewhere and only test the encrypt/decrypt functions out of pairs. + + // Encrypting with the 3-args variant and decrypting with the 4-args variant + // Note: 3-args doesn't support IV, so this won't work for CBC. Skip this combination. + + // Encrypting with the 4-args variant (NULL IV) and decrypting with the 5-args variant (NULL IV) + int32_t cipher_from_4args_len = 0; + const char* cipher_from_4args = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, &encrypt_valid, &cipher_from_4args_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_from_4args_len, 0); + + int32_t decrypted_from_5args_len = 0; + const char* decrypted_from_5args = gdv_fn_decrypt_dispatcher_5args( + ctx_ptr, cipher_from_4args, cipher_from_4args_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, + nullptr, 0, false, &decrypt_valid, &decrypted_from_5args_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_from_5args), + decrypted_from_5args_len)); + + // Encrypting with the 5-args variant (NULL IV) and decrypting with the 4-args variant (NULL IV) + int32_t cipher_from_5args_len = 0; + const char* cipher_from_5args = gdv_fn_encrypt_dispatcher_5args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, + nullptr, 0, false, &encrypt_valid, &cipher_from_5args_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_from_5args_len, 0); + + int32_t decrypted_from_4args_len = 0; + const char* decrypted_from_4args = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher_from_5args, cipher_from_5args_len, true, key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, nullptr, 0, false, &decrypt_valid, &decrypted_from_4args_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_from_4args), + decrypted_from_4args_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptEcbWithBlockAlignedData) { + gandiva::ExecutionContext ctx; int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_encrypt_dispatcher_3args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &cipher_len); + // Data with length that is a multiple of 16 + std::string data = "12345678901234561234567890123456"; + auto data_len = static_cast(data.length()); + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + + int32_t cipher_len = 0; + bool encrypt_valid = true; + + // Test AES-ECB (shorthand, with PKCS7 padding) + std::string mode_ecb = AES_ECB_MODE; + auto mode_ecb_len = static_cast(mode_ecb.length()); + const char* cipher_ecb = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb.c_str(), mode_ecb_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &decrypted_len); + // Test AES-ECB-PKCS7 (explicit PKCS7 padding) + std::string mode_ecb_pkcs7 = AES_ECB_PKCS7_MODE; + auto mode_ecb_pkcs7_len = static_cast(mode_ecb_pkcs7.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_ecb_pkcs7 = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb_pkcs7.c_str(), mode_ecb_pkcs7_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); - EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); + // Test AES-ECB-NONE (no padding) + std::string mode_ecb_none = AES_ECB_NONE_MODE; + auto mode_ecb_none_len = static_cast(mode_ecb_none.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_ecb_none = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb_none.c_str(), mode_ecb_none_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); } -// Tests for explicit mode: AES-ECB-PKCS7 -TEST(TestGdvFnStubs, TestAesEncryptDecryptExplicitEcbPkcs7) { +TEST(TestGdvFnStubs, TestAesDecryptEcbWithBlockAlignedData) { gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Data with length that is a multiple of 16 + std::string data = "12345678901234561234567890123456"; + auto data_len = static_cast(data.length()); + std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); + + // Encrypt once with AES-ECB-PKCS7 to get ciphertext + std::string mode_ecb_pkcs7 = AES_ECB_PKCS7_MODE; + auto mode_ecb_pkcs7_len = static_cast(mode_ecb_pkcs7.length()); int32_t cipher_len = 0; + bool encrypt_valid = true; + const char* cipher = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb_pkcs7.c_str(), mode_ecb_pkcs7_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); + int32_t decrypted_len = 0; + bool decrypt_valid = true; + + // Test AES-ECB (shorthand, with PKCS7 padding) + std::string mode_ecb = AES_ECB_MODE; + auto mode_ecb_len = static_cast(mode_ecb.length()); + const char* decrypted_ecb = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_ecb.c_str(), mode_ecb_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_GT(decrypted_len, 0); + + // Test AES-ECB-PKCS7 (explicit PKCS7 padding) + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_ecb_pkcs7 = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_ecb_pkcs7.c_str(), mode_ecb_pkcs7_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_GT(decrypted_len, 0); + + // Test AES-ECB-NONE (no padding) + std::string mode_ecb_none = AES_ECB_NONE_MODE; + auto mode_ecb_none_len = static_cast(mode_ecb_none.length()); + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_ecb_none = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_ecb_none.c_str(), mode_ecb_none_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_GT(decrypted_len, 0); +} + +TEST(TestGdvFnStubs, TestAesEncryptEcbWithNonBlockAlignedData) { + gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Data with length that is not a multiple of 16 std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = AES_ECB_PKCS7_MODE; // Explicit mode - auto mode_len = static_cast(mode.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_encrypt_dispatcher_3args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &cipher_len); + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + + int32_t cipher_len = 0; + bool encrypt_valid = true; + + // Test AES-ECB (shorthand, with PKCS7 padding) - should succeed + std::string mode_ecb = AES_ECB_MODE; + auto mode_ecb_len = static_cast(mode_ecb.length()); + const char* cipher_ecb = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb.c_str(), mode_ecb_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, &decrypted_len); + // Test AES-ECB-PKCS7 (explicit PKCS7 padding) - should succeed + std::string mode_ecb_pkcs7 = AES_ECB_PKCS7_MODE; + auto mode_ecb_pkcs7_len = static_cast(mode_ecb_pkcs7.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_ecb_pkcs7 = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb_pkcs7.c_str(), mode_ecb_pkcs7_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); - EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); + // Test AES-ECB-NONE (no padding) - should fail because data is not block-aligned + std::string mode_ecb_none = AES_ECB_NONE_MODE; + auto mode_ecb_none_len = static_cast(mode_ecb_none.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_ecb_none = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb_none.c_str(), mode_ecb_none_len, true, &encrypt_valid, &cipher_len); + EXPECT_FALSE(encrypt_valid); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Could not finalize EVP cipher context for encryption")); } -// Tests for shorthand mode: AES-CBC (defaults to PKCS7) -TEST(TestGdvFnStubs, TestAesEncryptDecryptShorthandCbc) { +TEST(TestGdvFnStubs, TestAesDecryptEcbWithNonBlockAlignedData) { gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Data with length that is not a multiple of 16 + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); + + // Encrypt once with AES-ECB-PKCS7 to get ciphertext + std::string mode_ecb_pkcs7 = AES_ECB_PKCS7_MODE; + auto mode_ecb_pkcs7_len = static_cast(mode_ecb_pkcs7.length()); int32_t cipher_len = 0; + bool encrypt_valid = true; + const char* cipher = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_ecb_pkcs7.c_str(), mode_ecb_pkcs7_len, true, &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); + int32_t decrypted_len = 0; - std::string data = "test string"; + bool decrypt_valid = true; + + // Test AES-ECB (shorthand, with PKCS7 padding) - should succeed + std::string mode_ecb = AES_ECB_MODE; + auto mode_ecb_len = static_cast(mode_ecb.length()); + const char* decrypted_ecb = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_ecb.c_str(), mode_ecb_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(decrypted_len, data_len); // Returns original data length (padding removed) + + // Test AES-ECB-PKCS7 (explicit PKCS7 padding) - should succeed + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_ecb_pkcs7 = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_ecb_pkcs7.c_str(), mode_ecb_pkcs7_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(decrypted_len, data_len); // Returns original data length (padding removed) + + // Test AES-ECB-NONE (no padding) - should succeed but return padded data + std::string mode_ecb_none = AES_ECB_NONE_MODE; + auto mode_ecb_none_len = static_cast(mode_ecb_none.length()); + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_ecb_none = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_ecb_none.c_str(), mode_ecb_none_len, true, &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(decrypted_len, cipher_len); // Returns full block including padding +} + +// CBC mode tests with block-aligned and non-block-aligned data + +TEST(TestGdvFnStubs, TestAesEncryptCbcWithBlockAlignedData) { + gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Data with length that is a multiple of 16 (32 bytes) + std::string data = "12345678901234561234567890123456"; auto data_len = static_cast(data.length()); - std::string mode = AES_CBC_MODE; // Shorthand mode - auto mode_len = static_cast(mode.length()); + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; auto iv_len = static_cast(iv.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_encrypt_dispatcher_4args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, &cipher_len); + int32_t cipher_len = 0; + bool encrypt_valid = true; + + // Test AES-CBC (shorthand, with PKCS7 padding) + std::string mode_cbc = AES_CBC_MODE; + auto mode_cbc_len = static_cast(mode_cbc.length()); + const char* cipher_cbc = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc.c_str(), mode_cbc_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_4args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, &decrypted_len); + // Test AES-CBC-PKCS7 (explicit PKCS7 padding) + std::string mode_cbc_pkcs7 = AES_CBC_PKCS7_MODE; + auto mode_cbc_pkcs7_len = static_cast(mode_cbc_pkcs7.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_cbc_pkcs7 = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc_pkcs7.c_str(), mode_cbc_pkcs7_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); - EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); + // Test AES-CBC-NONE (no padding) + std::string mode_cbc_none = AES_CBC_NONE_MODE; + auto mode_cbc_none_len = static_cast(mode_cbc_none.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_cbc_none = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc_none.c_str(), mode_cbc_none_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); } -// Tests for explicit mode: AES-CBC-PKCS7 -TEST(TestGdvFnStubs, TestAesEncryptDecryptExplicitCbcPkcs7) { +TEST(TestGdvFnStubs, TestAesDecryptCbcWithBlockAlignedData) { gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Data with length that is a multiple of 16 (32 bytes) + std::string data = "12345678901234561234567890123456"; + auto data_len = static_cast(data.length()); + std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); + + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + + // Encrypt once with AES-CBC-PKCS7 to get ciphertext + std::string mode_cbc_pkcs7 = AES_CBC_PKCS7_MODE; + auto mode_cbc_pkcs7_len = static_cast(mode_cbc_pkcs7.length()); int32_t cipher_len = 0; + bool encrypt_valid = true; + const char* cipher = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc_pkcs7.c_str(), mode_cbc_pkcs7_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); + int32_t decrypted_len = 0; + bool decrypt_valid = true; + + // Test AES-CBC (shorthand, with PKCS7 padding) + std::string mode_cbc = AES_CBC_MODE; + auto mode_cbc_len = static_cast(mode_cbc.length()); + const char* decrypted_cbc = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_cbc.c_str(), mode_cbc_len, true, iv.c_str(), iv_len, true, + &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_GT(decrypted_len, 0); + + // Test AES-CBC-PKCS7 (explicit PKCS7 padding) + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_cbc_pkcs7 = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_cbc_pkcs7.c_str(), mode_cbc_pkcs7_len, true, iv.c_str(), iv_len, true, + &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_GT(decrypted_len, 0); + + // Test AES-CBC-NONE (no padding) + std::string mode_cbc_none = AES_CBC_NONE_MODE; + auto mode_cbc_none_len = static_cast(mode_cbc_none.length()); + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_cbc_none = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_cbc_none.c_str(), mode_cbc_none_len, true, iv.c_str(), iv_len, true, + &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_GT(decrypted_len, 0); +} + +TEST(TestGdvFnStubs, TestAesEncryptCbcWithNonBlockAlignedData) { + gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Data with length that is NOT a multiple of 16 (11 bytes) std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = AES_CBC_PKCS7_MODE; // Explicit mode - auto mode_len = static_cast(mode.length()); + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; auto iv_len = static_cast(iv.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_encrypt_dispatcher_4args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, &cipher_len); + int32_t cipher_len = 0; + bool encrypt_valid = true; + + // Test AES-CBC (shorthand, with PKCS7 padding) - should succeed + std::string mode_cbc = AES_CBC_MODE; + auto mode_cbc_len = static_cast(mode_cbc.length()); + const char* cipher_cbc = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc.c_str(), mode_cbc_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_4args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, &decrypted_len); + // Test AES-CBC-PKCS7 (explicit PKCS7 padding) - should succeed + std::string mode_cbc_pkcs7 = AES_CBC_PKCS7_MODE; + auto mode_cbc_pkcs7_len = static_cast(mode_cbc_pkcs7.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_cbc_pkcs7 = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc_pkcs7.c_str(), mode_cbc_pkcs7_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); + EXPECT_GT(cipher_len, 0); - EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); + // Test AES-CBC-NONE (no padding) - should fail because data is not block-aligned + std::string mode_cbc_none = AES_CBC_NONE_MODE; + auto mode_cbc_none_len = static_cast(mode_cbc_none.length()); + encrypt_valid = true; + cipher_len = 0; + const char* cipher_cbc_none = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc_none.c_str(), mode_cbc_none_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_FALSE(encrypt_valid); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Could not finalize EVP cipher context for encryption")); } -// Tests for explicit mode: AES-CBC-NONE (no padding) -TEST(TestGdvFnStubs, TestAesEncryptDecryptCbcNone) { +TEST(TestGdvFnStubs, TestAesDecryptCbcWithNonBlockAlignedData) { gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Data with length that is not a multiple of 16 + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; - // Use exactly 16 bytes (one block) for no-padding mode - std::string data = "1234567890123456"; - auto data_len = static_cast(data.length()); - std::string mode = AES_CBC_NONE_MODE; // No padding mode - auto mode_len = static_cast(mode.length()); + std::string iv = "1234567890123456"; auto iv_len = static_cast(iv.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); + // Encrypt once with AES-CBC-PKCS7 to get ciphertext + std::string mode_cbc_pkcs7 = AES_CBC_PKCS7_MODE; + auto mode_cbc_pkcs7_len = static_cast(mode_cbc_pkcs7.length()); + int32_t cipher_len = 0; + bool encrypt_valid = true; const char* cipher = gdv_fn_encrypt_dispatcher_4args( - ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, &cipher_len); + ctx_ptr, data.c_str(), data_len, true, key16.c_str(), key16_len, true, + mode_cbc_pkcs7.c_str(), mode_cbc_pkcs7_len, true, iv.c_str(), iv_len, true, + &encrypt_valid, &cipher_len); + EXPECT_TRUE(encrypt_valid); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_decrypt_dispatcher_4args( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), - mode_len, iv.c_str(), iv_len, &decrypted_len); + int32_t decrypted_len = 0; + bool decrypt_valid = true; + + // Test AES-CBC (shorthand, with PKCS7 padding) - should succeed + std::string mode_cbc = AES_CBC_MODE; + auto mode_cbc_len = static_cast(mode_cbc.length()); + const char* decrypted_cbc = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_cbc.c_str(), mode_cbc_len, true, iv.c_str(), iv_len, true, + &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(decrypted_len, data_len); // Returns original data length (padding removed) + + // Test AES-CBC-PKCS7 (explicit PKCS7 padding) - should succeed + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_cbc_pkcs7 = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_cbc_pkcs7.c_str(), mode_cbc_pkcs7_len, true, iv.c_str(), iv_len, true, + &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(decrypted_len, data_len); // Returns original data length (padding removed) + + // Test AES-CBC-NONE (no padding) - should succeed but return padded data + std::string mode_cbc_none = AES_CBC_NONE_MODE; + auto mode_cbc_none_len = static_cast(mode_cbc_none.length()); + decrypt_valid = true; + decrypted_len = 0; + const char* decrypted_cbc_none = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, true, key16.c_str(), key16_len, true, + mode_cbc_none.c_str(), mode_cbc_none_len, true, iv.c_str(), iv_len, true, + &decrypt_valid, &decrypted_len); + EXPECT_TRUE(decrypt_valid); + EXPECT_EQ(decrypted_len, cipher_len); // Returns full block including padding +} - EXPECT_EQ(data, - std::string(reinterpret_cast(decrypted_value), - decrypted_len)); +// Validation tests + +TEST(TestGdvFnStubs, TestAesEncrypt3ArgsValidation) { + gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + + std::string data = "test string"; + auto data_len = static_cast(data.length()); + + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + + std::string mode = AES_ECB_MODE; + auto mode_len = static_cast(mode.length()); + + int32_t cipher_len = 0; + bool encrypt_valid = true; + + // Test 1: NULL plaintext should return NULL ciphertext + const char* result = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, false, // data_validity = false (NULL plaintext) + key16.c_str(), key16_len, true, + mode.c_str(), mode_len, true, + &encrypt_valid, &cipher_len); + EXPECT_FALSE(encrypt_valid); + EXPECT_EQ(result, nullptr); + EXPECT_EQ(cipher_len, 0); + + // Test 2: NULL key should fail with error + ctx.Reset(); + encrypt_valid = true; + cipher_len = 0; + result = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, + key16.c_str(), key16_len, false, // key_validity = false (NULL key) + mode.c_str(), mode_len, true, + &encrypt_valid, &cipher_len); + EXPECT_FALSE(encrypt_valid); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("key cannot be NULL")); + + // Test 3: NULL mode should fail with error + ctx.Reset(); + encrypt_valid = true; + cipher_len = 0; + result = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, + key16.c_str(), key16_len, true, + mode.c_str(), mode_len, false, // mode_validity = false (NULL mode) + &encrypt_valid, &cipher_len); + EXPECT_FALSE(encrypt_valid); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported encryption mode: NULL. Supported modes: AES-ECB, AES-ECB-PKCS7, AES-ECB-NONE, AES-CBC, AES-CBC-PKCS7, AES-CBC-NONE, AES-GCM")); + + // Test 4: Invalid mode string should fail with error + ctx.Reset(); + encrypt_valid = true; + cipher_len = 0; + std::string invalid_mode = "AES-INVALID"; + auto invalid_mode_len = static_cast(invalid_mode.length()); + result = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, + key16.c_str(), key16_len, true, + invalid_mode.c_str(), invalid_mode_len, true, + &encrypt_valid, &cipher_len); + EXPECT_FALSE(encrypt_valid); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported encryption mode: AES-INVALID. Supported modes: AES-ECB, AES-ECB-PKCS7, AES-ECB-NONE, AES-CBC, AES-CBC-PKCS7, AES-CBC-NONE, AES-GCM")); + + // Test 5: Invalid key length should fail with error + ctx.Reset(); + encrypt_valid = true; + cipher_len = 0; + std::string short_key = "short"; + auto short_key_len = static_cast(short_key.length()); + result = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, true, + short_key.c_str(), short_key_len, true, + mode.c_str(), mode_len, true, + &encrypt_valid, &cipher_len); + EXPECT_FALSE(encrypt_valid); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported key length for AES-ECB")); } } // namespace gandiva