Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/gandiva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ set(SRC_FILES
engine.cc
date_utils.cc
encrypt_utils_common.cc
encrypt_utils_iv.cc
encrypt_utils_ecb.cc
encrypt_utils_cbc.cc
encrypt_utils_gcm.cc
Expand Down Expand Up @@ -269,6 +270,7 @@ add_gandiva_test(internals-test
encrypt_utils_ecb_test.cc
encrypt_utils_cbc_test.cc
encrypt_utils_gcm_test.cc
encrypt_utils_iv_test.cc
encrypt_utils_common_test.cc
expr_decomposer_test.cc
exported_funcs_registry_test.cc
Expand Down
88 changes: 50 additions & 38 deletions cpp/src/gandiva/encrypt_mode_dispatcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

namespace gandiva {

// Supported encryption modes
static const std::vector<std::string_view> SUPPORTED_MODES = {
AES_ECB_MODE, AES_ECB_PKCS7_MODE, AES_ECB_NONE_MODE,
AES_CBC_MODE, AES_CBC_PKCS7_MODE, AES_CBC_NONE_MODE,
Expand All @@ -42,95 +41,108 @@ enum class EncryptionMode {
CBC_PKCS7,
CBC_NONE,
GCM,
NULL_VALUE,
UNKNOWN
};

EncryptionMode ParseEncryptionMode(std::string_view mode_str) {
// Translates a mode string into the corresponding EncryptionMode value.
//
// mode / mode_len describe the mode text; they are only read when
// mode_validity is true. When mode_validity is false the function returns
// EncryptionMode::NULL_VALUE without looking at the text.
//
// The text is upper-cased (ASCII) before it is compared with the supported
// mode names, so matching is case-insensitive. Text that matches none of the
// known names yields EncryptionMode::UNKNOWN.
EncryptionMode ParseEncryptionMode(const char* mode, int32_t mode_len, bool mode_validity) {
  if (!mode_validity) {
    return EncryptionMode::NULL_VALUE;
  }

  // Lookup table pairing each supported mode name with its enumerator.
  struct ModeEntry {
    std::string_view name;
    EncryptionMode value;
  };
  const ModeEntry known_modes[] = {
      {AES_ECB_MODE, EncryptionMode::ECB},
      {AES_ECB_PKCS7_MODE, EncryptionMode::ECB_PKCS7},
      {AES_ECB_NONE_MODE, EncryptionMode::ECB_NONE},
      {AES_CBC_MODE, EncryptionMode::CBC},
      {AES_CBC_PKCS7_MODE, EncryptionMode::CBC_PKCS7},
      {AES_CBC_NONE_MODE, EncryptionMode::CBC_NONE},
      {AES_GCM_MODE, EncryptionMode::GCM},
  };

  // Normalize once so every comparison below is case-insensitive.
  const std::string upper_mode =
      arrow::internal::AsciiToUpper(std::string_view(mode, mode_len));

  for (const ModeEntry& entry : known_modes) {
    if (upper_mode == entry.name) {
      return entry.value;
    }
  }

  return EncryptionMode::UNKNOWN;
}

// Builds the error text reported when a mode string is not recognized.
//
// operation is inserted verbatim after "Unsupported " (the callers in this
// file pass "encryption" or "decryption"). mode / mode_len describe the
// offending mode text, which is echoed back unchanged. The message ends with
// the comma-separated list of SUPPORTED_MODES.
std::string BuildUnsupportedModeError(const char* operation, const char* mode, int32_t mode_len) {
  const std::string_view rejected_mode(mode, mode_len);

  std::ostringstream message;
  message << "Unsupported " << operation << " mode: " << rejected_mode;
  message << ". Supported modes: " << arrow::internal::JoinStrings(SUPPORTED_MODES, ", ");
  return message.str();
}

int32_t EncryptModeDispatcher::encrypt(
const char* plaintext, int32_t plaintext_len, const char* key,
int32_t key_len, const char* mode, int32_t mode_len, const char* iv,
int32_t iv_len, const char* fifth_argument, int32_t fifth_argument_len,
unsigned char* cipher) {
std::string mode_str =
arrow::internal::AsciiToUpper(std::string_view(mode, mode_len));
const char* plaintext, int32_t plaintext_len,
const char* key, int32_t key_len, bool key_validity,
const char* mode, int32_t mode_len, bool mode_validity,
const char* iv, int32_t iv_len, bool iv_validity,
const char* fifth_argument, int32_t fifth_argument_len,
bool fifth_argument_validity, unsigned char* cipher) {
if (!key_validity) {
throw std::runtime_error("Encryption key cannot be NULL");
}

switch (ParseEncryptionMode(mode_str)) {
switch (ParseEncryptionMode(mode, mode_len, mode_validity)) {
case EncryptionMode::ECB:
case EncryptionMode::ECB_PKCS7:
// Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7 padding
return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, true, cipher);
case EncryptionMode::ECB_NONE:
// ECB without padding
return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, false, cipher);
case EncryptionMode::CBC:
case EncryptionMode::CBC_PKCS7:
// Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7
return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len,
iv, iv_len, true, cipher);
case EncryptionMode::CBC_NONE:
// CBC without padding
return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len,
iv, iv_len, false, cipher);
case EncryptionMode::GCM:
return aes_encrypt_gcm(plaintext, plaintext_len, key, key_len,
iv, iv_len, fifth_argument, fifth_argument_len, cipher);
case EncryptionMode::NULL_VALUE:
throw std::runtime_error(BuildUnsupportedModeError("encryption", "NULL", 4));
case EncryptionMode::UNKNOWN:
default: {
std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", ");
std::ostringstream oss;
oss << "Unsupported encryption mode: " << mode_str
<< ". Supported modes: " << modes_str;
throw std::runtime_error(oss.str());
}
default:
throw std::runtime_error(BuildUnsupportedModeError("encryption", mode, mode_len));
}
}

int32_t EncryptModeDispatcher::decrypt(
const char* ciphertext, int32_t ciphertext_len, const char* key,
int32_t key_len, const char* mode, int32_t mode_len, const char* iv,
int32_t iv_len, const char* fifth_argument, int32_t fifth_argument_len,
unsigned char* plaintext) {
std::string mode_str =
arrow::internal::AsciiToUpper(std::string_view(mode, mode_len));
const char* ciphertext, int32_t ciphertext_len,
const char* key, int32_t key_len, bool key_validity,
const char* mode, int32_t mode_len, bool mode_validity,
const char* iv, int32_t iv_len, bool iv_validity,
const char* fifth_argument, int32_t fifth_argument_len,
bool fifth_argument_validity, unsigned char* plaintext) {
// If key is NULL (validity flag is false), throw error
if (!key_validity) {
throw std::runtime_error("Decryption key cannot be NULL");
}

switch (ParseEncryptionMode(mode_str)) {
switch (ParseEncryptionMode(mode, mode_len, mode_validity)) {
case EncryptionMode::ECB:
case EncryptionMode::ECB_PKCS7:
// Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7 padding
return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, true, plaintext);
case EncryptionMode::ECB_NONE:
// ECB without padding
return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, false, plaintext);
case EncryptionMode::CBC:
case EncryptionMode::CBC_PKCS7:
// Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7
return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, true, plaintext);
case EncryptionMode::CBC_NONE:
// CBC without padding
// CBC mode without padding
return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, false, plaintext);
case EncryptionMode::GCM:
return aes_decrypt_gcm(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, fifth_argument, fifth_argument_len, plaintext);
case EncryptionMode::UNKNOWN:
default: {
std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", ");
std::ostringstream oss;
oss << "Unsupported decryption mode: " << mode_str
<< ". Supported modes: " << modes_str;
throw std::runtime_error(oss.str());
}
default:
if (!mode_validity) {
throw std::runtime_error(BuildUnsupportedModeError("decryption", "NULL", 4));
}
throw std::runtime_error(BuildUnsupportedModeError("decryption", mode, mode_len));
}
}

Expand Down
46 changes: 28 additions & 18 deletions cpp/src/gandiva/encrypt_mode_dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,47 +33,57 @@ class EncryptModeDispatcher {
*
* @param plaintext The data to encrypt
* @param plaintext_len Length of plaintext in bytes
* @param key The encryption key
* @param key The encryption key (16, 24, or 32 bytes for AES-128/192/256)
* @param key_len Length of key in bytes
* @param mode Mode string
* @param key_validity Whether key is valid
* @param mode Mode string (case-insensitive)
* @param mode_len Length of mode string in bytes
* @param iv The initialization vector (optional, only for modes that support it)
* @param mode_validity Whether mode is valid
* @param iv The initialization vector
* @param iv_len Length of the IV in bytes
* @param fifth_argument Additional parameter (optional, only for modes that support it)
* @param iv_validity Whether IV is valid
* @param fifth_argument Additional parameter (e.g. AAD for the GCM mode)
* @param fifth_argument_len Length of fifth_argument in bytes
* @param fifth_argument_validity Whether fifth_argument is valid
* @param cipher Output buffer for encrypted data
* @return Length of encrypted data in bytes
* @throws std::runtime_error on encryption failure or unsupported mode
* @throws std::runtime_error on encryption failure, unsupported mode, or invalid parameters
*/
static int32_t encrypt(const char* plaintext, int32_t plaintext_len,
const char* key, int32_t key_len,
const char* mode, int32_t mode_len,
const char* iv, int32_t iv_len,
const char* key, int32_t key_len, bool key_validity,
const char* mode, int32_t mode_len, bool mode_validity,
const char* iv, int32_t iv_len, bool iv_validity,
const char* fifth_argument, int32_t fifth_argument_len,
bool fifth_argument_validity,
unsigned char* cipher);

/**
* Decrypt data using the specified mode
*
* @param ciphertext The data to decrypt
* @param ciphertext The data to decrypt (format depends on mode and IV parameter)
* @param ciphertext_len Length of ciphertext in bytes
* @param key The decryption key
* @param key The decryption key (16, 24, or 32 bytes for AES-128/192/256)
* @param key_len Length of key in bytes
* @param mode Mode string
* @param key_validity Whether key is valid
* @param mode Mode string (case-insensitive)
* @param mode_len Length of mode string in bytes
* @param iv The initialization vector (optional, only for modes that support it)
* @param mode_validity Whether mode is valid
* @param iv The initialization vector
* @param iv_len Length of the IV in bytes
* @param fifth_argument Additional parameter (optional, only for modes that support it)
* @param iv_validity Whether IV is valid
* @param fifth_argument Additional parameter (e.g. AAD for the GCM mode)
* @param fifth_argument_len Length of fifth_argument in bytes
* @param fifth_argument_validity Whether fifth_argument is valid
* @param plaintext Output buffer for decrypted data
* @return Length of decrypted data in bytes
* @throws std::runtime_error on decryption failure or unsupported mode
* @return Length of decrypted data in bytes (plaintext only, IV and tag removed)
* @throws std::runtime_error on decryption failure, unsupported mode, invalid parameters, or authentication failure
*/
static int32_t decrypt(const char* ciphertext, int32_t ciphertext_len,
const char* key, int32_t key_len,
const char* mode, int32_t mode_len,
const char* iv, int32_t iv_len,
const char* key, int32_t key_len, bool key_validity,
const char* mode, int32_t mode_len, bool mode_validity,
const char* iv, int32_t iv_len, bool iv_validity,
const char* fifth_argument, int32_t fifth_argument_len,
bool fifth_argument_validity,
unsigned char* plaintext);
};

Expand Down
82 changes: 64 additions & 18 deletions cpp/src/gandiva/encrypt_utils_cbc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "gandiva/encrypt_utils_cbc.h"
#include "gandiva/encrypt_utils_common.h"
#include "gandiva/encrypt_utils_iv.h"
#include <openssl/aes.h>
#include <openssl/err.h>
#include <stdexcept>
Expand Down Expand Up @@ -45,18 +46,45 @@ const EVP_CIPHER* get_cbc_cipher_algo(int32_t key_length) {
}
}

// Checks that a caller-supplied IV has the length AES-CBC requires.
//
// Returns normally when iv_len equals CBC_IV_LENGTH; otherwise throws
// std::runtime_error with a message naming the actual and expected lengths.
void validate_iv_length_cbc(int32_t iv_len) {
  if (iv_len == CBC_IV_LENGTH) {
    return;
  }

  std::ostringstream message;
  message << "Invalid IV length for AES-CBC: " << iv_len;
  message << " bytes. IV must be exactly " << CBC_IV_LENGTH << " bytes";
  throw std::runtime_error(message.str());
}

// Checks that a ciphertext carrying an embedded IV is long enough to hold the
// IV plus at least one further 16-byte block.
//
// Returns normally when ciphertext_len is at least CBC_IV_LENGTH + 16;
// otherwise throws std::runtime_error describing the shortfall.
void validate_ciphertext_with_embedded_iv_cbc(int32_t ciphertext_len) {
  // IV + minimum one block
  constexpr int32_t MIN_CIPHERTEXT_LEN = CBC_IV_LENGTH + 16;

  const bool long_enough = ciphertext_len >= MIN_CIPHERTEXT_LEN;
  if (long_enough) {
    return;
  }

  std::ostringstream message;
  message << "Ciphertext too short for AES-CBC with embedded IV: " << ciphertext_len;
  message << " bytes. Must be at least " << MIN_CIPHERTEXT_LEN;
  message << " bytes (16-byte IV + minimum 16-byte block)";
  throw std::runtime_error(message.str());
}

} // namespace

GANDIVA_EXPORT
int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key,
int32_t key_len, const char* iv, int32_t iv_len,
bool use_padding, unsigned char* cipher) {
// Validate IV length
if (iv_len != 16) {
std::ostringstream oss;
oss << "Invalid IV length for AES-CBC: " << iv_len
<< " bytes. IV must be exactly 16 bytes";
throw std::runtime_error(oss.str());
// Buffer for IV (either user-supplied or auto-generated)
unsigned char iv_buffer[CBC_IV_LENGTH];
const unsigned char* actual_iv = nullptr;
bool iv_auto_generated = false;

if (iv == nullptr || iv_len == 0) {
generate_random_iv(iv_buffer, CBC_IV_LENGTH);
actual_iv = iv_buffer;
iv_auto_generated = true;
} else {
validate_iv_length_cbc(iv_len);
actual_iv = reinterpret_cast<const unsigned char*>(iv);
iv_auto_generated = false;
}

int32_t cipher_len = 0;
Expand All @@ -69,9 +97,17 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char
get_openssl_error_string());
}

// Only prepend IV to output if it was auto-generated
// Auto-generated IV: [16-byte IV][ciphertext]
// User-supplied IV: [ciphertext]
if (iv_auto_generated) {
std::memcpy(cipher, actual_iv, CBC_IV_LENGTH);
cipher_len = CBC_IV_LENGTH;
}

if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key),
reinterpret_cast<const unsigned char*>(iv))) {
actual_iv)) {
EVP_CIPHER_CTX_free(en_ctx);
throw std::runtime_error("Could not initialize EVP cipher context for encryption: " +
get_openssl_error_string());
Expand All @@ -84,7 +120,8 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char
get_openssl_error_string());
}

if (!EVP_EncryptUpdate(en_ctx, cipher, &len,
// Encrypt plaintext (write after IV)
if (!EVP_EncryptUpdate(en_ctx, cipher + cipher_len, &len,
reinterpret_cast<const unsigned char*>(plaintext),
plaintext_len)) {
EVP_CIPHER_CTX_free(en_ctx);
Expand All @@ -94,7 +131,7 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char

cipher_len += len;

if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) {
if (!EVP_EncryptFinal_ex(en_ctx, cipher + cipher_len, &len)) {
EVP_CIPHER_CTX_free(en_ctx);
throw std::runtime_error("Could not finalize EVP cipher context for encryption: " +
get_openssl_error_string());
Expand All @@ -110,12 +147,21 @@ GANDIVA_EXPORT
int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key,
int32_t key_len, const char* iv, int32_t iv_len,
bool use_padding, unsigned char* plaintext) {
// Validate IV length
if (iv_len != 16) {
std::ostringstream oss;
oss << "Invalid IV length for AES-CBC: " << iv_len
<< " bytes. IV must be exactly 16 bytes";
throw std::runtime_error(oss.str());
// Buffer for extracted IV (if needed)
unsigned char iv_buffer[CBC_IV_LENGTH];
const unsigned char* actual_iv = nullptr;
const char* actual_ciphertext = ciphertext;
int32_t actual_ciphertext_len = ciphertext_len;

if (iv == nullptr) {
validate_ciphertext_with_embedded_iv_cbc(ciphertext_len);
extract_iv_from_ciphertext(ciphertext, ciphertext_len, CBC_IV_LENGTH,
iv_buffer, &actual_ciphertext,
&actual_ciphertext_len);
actual_iv = iv_buffer;
} else {
validate_iv_length_cbc(iv_len);
actual_iv = reinterpret_cast<const unsigned char*>(iv);
}

int32_t plaintext_len = 0;
Expand All @@ -130,7 +176,7 @@ int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const ch

if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key),
reinterpret_cast<const unsigned char*>(iv))) {
actual_iv)) {
EVP_CIPHER_CTX_free(de_ctx);
throw std::runtime_error("Could not initialize EVP cipher context for decryption: " +
get_openssl_error_string());
Expand All @@ -144,8 +190,8 @@ int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const ch
}

if (!EVP_DecryptUpdate(de_ctx, plaintext, &len,
reinterpret_cast<const unsigned char*>(ciphertext),
ciphertext_len)) {
reinterpret_cast<const unsigned char*>(actual_ciphertext),
actual_ciphertext_len)) {
EVP_CIPHER_CTX_free(de_ctx);
throw std::runtime_error("Could not update EVP cipher context for decryption: " +
get_openssl_error_string());
Expand Down
Loading
Loading