diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ParserBenchmark.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ParserBenchmark.java new file mode 100644 index 000000000..41a2d6220 --- /dev/null +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ParserBenchmark.java @@ -0,0 +1,94 @@ +package software.amazon.jdbc.benchmarks; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import software.amazon.jdbc.plugin.encryption.parser.PostgreSqlParser; + +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Benchmark) +@Fork(1) +@Warmup(iterations = 3, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +public class ParserBenchmark { + + private PostgreSqlParser parser; + + @Setup + public void setup() { + parser = new PostgreSqlParser(); + } + + @Benchmark + public void parseSimpleSelect() { + parser.parse("SELECT * FROM users"); + } + + @Benchmark + public void parseSelectWithWhere() { + parser.parse("SELECT id, name FROM users WHERE age > 25"); + } + + @Benchmark + public void parseSelectWithOrderBy() { + parser.parse("SELECT * FROM products ORDER BY price DESC"); + } + + @Benchmark + public void parseComplexSelect() { + parser.parse("SELECT u.name, o.total FROM users u, orders o WHERE u.id = o.user_id AND o.total > 100"); + } + + @Benchmark + public void parseInsert() { + parser.parse("INSERT INTO users (name, age, email) VALUES ('John', 30, 'john@example.com')"); + } + + @Benchmark + public void parseInsertWithPlaceholders() { + parser.parse("INSERT INTO users (name, age, email) VALUES (?, ?, ?)"); + } + + @Benchmark + public void parseUpdate() { + parser.parse("UPDATE users SET name = 'Jane', age = 25 WHERE id = 1"); + } + + @Benchmark + public void parseUpdateWithPlaceholders() { + parser.parse("UPDATE users SET name = ?, age = ? WHERE id = ?"); + } + + @Benchmark + public void parseDelete() { + parser.parse("DELETE FROM users WHERE age < 18"); + } + + @Benchmark + public void parseCreateTable() { + parser.parse("CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR NOT NULL, price DECIMAL)"); + } + + @Benchmark + public void parseComplexExpression() { + parser.parse("SELECT * FROM orders WHERE (total > 100 AND status = 'pending') OR (total > 500 AND status = 'shipped')"); + } + + @Benchmark + public void parseScientificNotation() { + parser.parse("INSERT INTO measurements VALUES (42, 3.14159, 2.5e10)"); + } + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(ParserBenchmark.class.getSimpleName()) + .build(); + + new Runner(opt).run(); + } +} diff --git a/docs/using-the-jdbc-driver/using-plugins/UsingTheKmsEncryptionPlugin.md b/docs/using-the-jdbc-driver/using-plugins/UsingTheKmsEncryptionPlugin.md new file mode 100644 index 000000000..436f3ea94 --- /dev/null +++ b/docs/using-the-jdbc-driver/using-plugins/UsingTheKmsEncryptionPlugin.md @@ -0,0 +1,337 @@ +# Using the KMS Encryption Plugin + +The KMS Encryption Plugin provides transparent client-side encryption using AWS Key Management Service (KMS). This plugin automatically encrypts sensitive data before storing it in the database and decrypts it when retrieving data, based on metadata configuration. + +## Features + +- **Transparent Encryption**: Automatically encrypts and decrypts data without changing your application code +- **AWS KMS Integration**: Uses AWS KMS for secure key management and encryption operations +- **Metadata-Driven**: Configurable encryption based on table and column metadata +- **Audit Logging**: Optional audit logging for encryption operations +- **Minimal Performance Impact**: Efficient encryption with caching and optimized operations + +## Prerequisites + +- AWS KMS key with appropriate permissions +- Database table to store encryption metadata +- AWS credentials configured (via IAM roles, profiles, or environment variables) +- **JSqlParser 4.5.x dependency** - Required for SQL parsing and analysis + +### Creating AWS KMS Master Key + +1. **Create a KMS Key** in AWS Console or using AWS CLI: +```bash +aws kms create-key --description "Database encryption master key" --key-usage ENCRYPT_DECRYPT +``` + +2. **Note the Key ARN** from the response - you'll need this for the `kms.MasterKeyArn` property. + +3. **Set Key Permissions** - Ensure your application has the following KMS permissions: + - `kms:Encrypt` + - `kms:Decrypt` + - `kms:GenerateDataKey` + - `kms:DescribeKey` + +### Data Key Management + +The plugin automatically manages data keys: +- **Data keys are generated** automatically using the master key when encrypting new data +- **Data keys are cached** in memory for performance (configurable via `dataKeyCache.*` properties) +- **Data keys are encrypted** with the master key and stored alongside encrypted data +- **No manual data key creation** is required + +### Metadata Storage + +Create the required metadata tables to store encryption configuration: + +```sql +-- Key storage table (must be created first due to foreign key) +CREATE TABLE key_storage ( + id SERIAL PRIMARY KEY, + key_id VARCHAR(255) UNIQUE NOT NULL, + name VARCHAR(255) NOT NULL, + master_key_arn VARCHAR(512) NOT NULL, + encrypted_data_key TEXT NOT NULL, + key_spec VARCHAR(50) DEFAULT 'AES_256', + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + last_used_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP +); + +-- Encryption metadata table +CREATE TABLE encryption_metadata ( + table_name VARCHAR(255) NOT NULL, + column_name VARCHAR(255) NOT NULL, + encryption_algorithm VARCHAR(50) NOT NULL, + key_id INTEGER NOT NULL, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (table_name, column_name), + FOREIGN KEY (key_id) REFERENCES key_storage(id) +); +``` + +### Setting Up Encryption Metadata + +Use the KeyManagementUtility to properly configure encryption for your columns: + +```java +// Initialize KeyManagementUtility +KeyManagementUtility keyManagementUtility = new KeyManagementUtility( + keyManager, metadataManager, dataSource, kmsClient); + +// Configure encryption for a column +String keyId = keyManagementUtility.initializeEncryptionForColumn( + "users", // table name + "ssn", // column name + masterKeyArn, // KMS master key ARN + "AES-256-GCM" // encryption algorithm +); +``` + +**Alternative: Direct metadata insertion (not recommended for production):** +```sql +INSERT INTO encryption_metadata (table_name, column_name, encryption_algorithm, key_id) +VALUES ('users', 'ssn', 'AES-256-GCM', 'your-generated-key-id'); +``` + +### Adding JSqlParser Dependency + +The KMS Encryption Plugin requires JSqlParser 4.5.x for SQL statement analysis. Add this dependency to your project: + +**Maven:** +```xml + + com.github.jsqlparser + jsqlparser + 4.5 + +``` + +**Gradle:** +```gradle +implementation 'com.github.jsqlparser:jsqlparser:4.5' +``` + +## Configuration + +### Connection Properties + +| Property | Description | Required | Default | +|----------|-------------|----------|---------| +| `kms.region` | AWS KMS region for encryption operations | Yes | None | +| `kms.MasterKeyArn` | Master key ARN for encryption | Yes | None | +| `key.rotationDays` | Number of days for key rotation | No | `30` | +| `metadataCache.enabled` | Enable/disable metadata caching | No | `true` | +| `metadataCache.expirationMinutes` | Metadata cache expiration time in minutes | No | `60` | +| `metadataCache.refreshIntervalMs` | Metadata cache refresh interval in milliseconds | No | `300000` | +| `keyManagement.maxRetries` | Maximum number of retries for key management operations | No | `3` | +| `keyManagement.retryBackoffBaseMs` | Base backoff time in milliseconds for key management retries | No | `100` | +| `audit.loggingEnabled` | Enable/disable audit logging | No | `false` | +| `kms.connectionTimeoutMs` | KMS connection timeout in milliseconds | No | `5000` | +| `dataKeyCache.enabled` | Enable/disable data key caching | No | `true` | +| `dataKeyCache.maxSize` | Maximum size of data key cache | No | `1000` | +| `dataKeyCache.expirationMs` | Data key cache expiration in milliseconds | No | `3600000` | + +### Example Connection String + +```java +String url = "jdbc:aws-wrapper:postgresql://your-cluster.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb"; +Properties props = new Properties(); +props.setProperty("user", "username"); +props.setProperty("password", "password"); +props.setProperty("wrapperPlugins", "kmsEncryption"); +props.setProperty("kms.MasterKeyArn", "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012"); +props.setProperty("kms.region", "us-east-1"); +props.setProperty("audit.loggingEnabled", "true"); + +Connection conn = DriverManager.getConnection(url, props); +``` + +## Setup + +### 1. Create Encryption Metadata Table + +First, create the required tables to store encryption metadata and keys: + +```sql +-- Key storage table (must be created first due to foreign key) +CREATE TABLE key_storage ( + id SERIAL PRIMARY KEY, + key_id VARCHAR(255) UNIQUE NOT NULL, + name VARCHAR(255) NOT NULL, + master_key_arn VARCHAR(512) NOT NULL, + encrypted_data_key TEXT NOT NULL, + key_spec VARCHAR(50) DEFAULT 'AES_256', + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + last_used_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP +); + +-- Encryption metadata table +CREATE TABLE encryption_metadata ( + table_name VARCHAR(255) NOT NULL, + column_name VARCHAR(255) NOT NULL, + encryption_algorithm VARCHAR(50) NOT NULL, + key_id INTEGER NOT NULL, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (table_name, column_name), + FOREIGN KEY (key_id) REFERENCES key_storage(id) +); +``` + +### 2. Configure Column Encryption + +**Recommended: Use KeyManagementUtility for proper key management:** + +```java +KeyManagementUtility keyManagementUtility = new KeyManagementUtility( + keyManager, metadataManager, dataSource, kmsClient); + +// Configure encryption for sensitive columns +keyManagementUtility.initializeEncryptionForColumn("customers", "ssn", masterKeyArn); +keyManagementUtility.initializeEncryptionForColumn("customers", "credit_card", masterKeyArn); +keyManagementUtility.initializeEncryptionForColumn("customers", "phone", masterKeyArn); +keyManagementUtility.initializeEncryptionForColumn("customers", "address", masterKeyArn); +``` + +**Alternative: Direct SQL insertion (for testing only):** +```sql +-- Configure encryption for sensitive columns in the customers table +INSERT INTO encryption_metadata (table_name, column_name, encryption_algorithm, key_id) +VALUES + ('customers', 'ssn', 'AES-256-GCM', 'generated-key-id-1'), + ('customers', 'credit_card', 'AES-256-GCM', 'generated-key-id-2'), + ('customers', 'phone', 'AES-256-GCM', 'generated-key-id-3'), + ('customers', 'address', 'AES-256-GCM', 'generated-key-id-4'); +``` + +### 3. Create Your Application Tables + +Create your application tables normally: + +```sql +CREATE TABLE customers ( + customer_id SERIAL PRIMARY KEY, + first_name VARCHAR(100), + last_name VARCHAR(100), + email VARCHAR(255), + phone BYTEA, -- Will be encrypted + ssn BYTEA, -- Will be encrypted + credit_card BYTEA, -- Will be encrypted + address BYTEA, -- Will be encrypted + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); +``` + +## Usage + +Once configured, the plugin works transparently: + +```java +// Insert data - sensitive fields are automatically encrypted +String sql = "INSERT INTO customers (first_name, last_name, email, phone, ssn, credit_card, address) VALUES (?, ?, ?, ?, ?, ?, ?)"; +try (PreparedStatement stmt = connection.prepareStatement(sql)) { + stmt.setString(1, "John"); + stmt.setString(2, "Doe"); + stmt.setString(3, "john.doe@example.com"); + stmt.setString(4, "555-123-4567"); // Automatically encrypted + stmt.setString(5, "123-45-6789"); // Automatically encrypted + stmt.setString(6, "4111-1111-1111-1111"); // Automatically encrypted + stmt.setString(7, "123 Main St, City, ST 12345"); // Automatically encrypted + stmt.executeUpdate(); +} + +// Query data - encrypted fields are automatically decrypted +String query = "SELECT * FROM customers WHERE customer_id = ?"; +try (PreparedStatement stmt = connection.prepareStatement(query)) { + stmt.setInt(1, customerId); + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + String phone = rs.getString("phone"); // Automatically decrypted + String ssn = rs.getString("ssn"); // Automatically decrypted + String creditCard = rs.getString("credit_card"); // Automatically decrypted + String address = rs.getString("address"); // Automatically decrypted + + // Use the decrypted data normally + System.out.println("Phone: " + phone); + } + } +} +``` + +## Security Considerations + +### KMS Key Permissions + +Ensure your application has the necessary KMS permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "kms:Encrypt", + "kms:Decrypt", + "kms:GenerateDataKey" + ], + "Resource": "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012" + } + ] +} +``` + +### Data Protection + +- Encrypted data is stored as binary data in the database +- The original data never leaves your application - encryption/decryption happens locally using data keys from KMS +- Only encryption keys are managed by AWS KMS, not the actual data +- Consider using different KMS keys for different environments (dev, staging, prod) + +### Performance Considerations + +- KMS calls are only needed for data key generation/decryption, not for each data encryption/decryption +- Data key caching significantly reduces KMS API calls for repeated operations +- Consider the impact on performance for high-throughput applications during key rotation +- KMS has rate limits that may affect very high-volume key operations +- The plugin caches both metadata and data keys to minimize external calls + +## Troubleshooting + +### Common Issues + +1. **Missing KMS Permissions**: Ensure your AWS credentials have the necessary KMS permissions +2. **Metadata Table Not Found**: Verify the encryption metadata table exists and is accessible +3. **Region Mismatch**: Ensure the KMS region matches where your key is located +4. **Invalid Key ID**: Verify the KMS key ID or ARN is correct and accessible + +### Debugging + +Enable audit logging to track encryption operations: + +```java +props.setProperty("enableAuditLogging", "true"); +``` + +Check the application logs for encryption-related messages. + +## Limitations + +- Currently supports string data types for encryption +- Requires metadata configuration for each encrypted column +- Performance impact mainly during data key operations, mitigated by caching +- Limited to INSERT and UPDATE operations for automatic encryption + +## Best Practices + +1. **Use IAM Roles**: Use IAM roles instead of hardcoded credentials when possible +2. **Separate Keys**: Use different KMS keys for different environments +3. **Monitor Usage**: Monitor KMS usage and costs +4. **Test Performance**: Test the performance impact in your specific use case +5. **Backup Metadata**: Ensure the encryption metadata table is included in backups +6. **Key Rotation**: Implement a strategy for KMS key rotation + +## Example Application + +See the [KmsEncryptionExample.java](../../../examples/AWSDriverExample/src/main/java/software/amazon/KmsEncryptionExample.java) for a complete working example. diff --git a/environment.txt b/environment.txt new file mode 100644 index 000000000..848941aea --- /dev/null +++ b/environment.txt @@ -0,0 +1,2 @@ +AWS_KMS_KEY_ARN=arn:aws:kms:us-east-1:000579002577:key/d69090ec-8a8c-48ca-a1bc-36333d551e01 +TEST_ENV_INFO_JSON={"request":{"features":[]},"databaseInfo":{"username":"postgres","password":"password","defaultDbName":"postgres","clusterEndpoint":"database-1.cgnh50a2ovor.us-east-1.rds.amazonaws.com","clusterEndpointPort":5432,"instances":[]},"region":"us-east-1","databaseEngine":"postgresql"} diff --git a/gradle.properties b/gradle.properties index c1bf6b8ec..ce1f17f3c 100644 --- a/gradle.properties +++ b/gradle.properties @@ -15,7 +15,7 @@ aws-advanced-jdbc-wrapper.version.major=2 aws-advanced-jdbc-wrapper.version.minor=6 aws-advanced-jdbc-wrapper.version.subminor=4 -snapshot=false +snapshot=true nexus.publish=true org.gradle.jvmargs=-Xmx16384m -Xms8096m -XshowSettings:all diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 48873b114..01f512440 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -39,6 +39,7 @@ dependencies { optionalImplementation("software.amazon.awssdk:http-client-spi:2.33.5") // Required for IAM (light implementation) optionalImplementation("software.amazon.awssdk:sts:2.33.5") optionalImplementation("software.amazon.awssdk:secretsmanager:2.33.5") + optionalImplementation("software.amazon.awssdk:kms:2.33.5") optionalImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") optionalImplementation("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8 optionalImplementation("com.mchange:c3p0:0.11.0") @@ -49,6 +50,7 @@ dependencies { optionalImplementation("io.opentelemetry:opentelemetry-api:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") + optionalImplementation("com.github.jsqlparser:jsqlparser:4.5") // JSqlParser SQL parser (Java 8 compatible) compileOnly("org.checkerframework:checker-qual:3.49.5") compileOnly("com.mysql:mysql-connector-j:9.4.0") @@ -106,6 +108,7 @@ dependencies { testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 testImplementation("jakarta.persistence:jakarta.persistence-api:2.2.3") + testImplementation("software.amazon.awssdk:kms:2.33.5") testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.2") } @@ -434,6 +437,7 @@ tasks.register("test-all-multi-az") { tasks.register("test-all-pg-aurora") { group = "verification" filter.includeTestsMatching("integration.host.TestRunner.runTests") + filter.includeTestsMatching("integration.container.tests.KmsEncryptionIntegrationTest") doFirst { systemProperty("test-no-docker", "true") systemProperty("test-no-performance", "true") @@ -1046,3 +1050,30 @@ tasks.register("test-metrics-pg-multi-az") { systemProperty("test-no-mysql-engine", "true") } } + +tasks.register("test-kms-encryption") { + group = "verification" + filter.includeTestsMatching("integration.container.tests.KmsEncryptionPluginTest") + classpath = sourceSets.test.get().runtimeClasspath + dependsOn("jar") + systemProperty("java.util.logging.config.file", "${project.layout.buildDirectory.get()}/resources/test/logging-test.properties") + systemProperty("jdbc.drivers", "software.amazon.jdbc.Driver") +} + +tasks.register("test-kms-encryption-integration") { + group = "verification" + filter.includeTestsMatching("integration.container.tests.KmsEncryptionIntegrationTest") + classpath = sourceSets.test.get().runtimeClasspath + dependsOn("jar") + systemProperty("java.util.logging.config.file", "${project.layout.buildDirectory.get()}/resources/test/logging-test.properties") + systemProperty("jdbc.drivers", "software.amazon.jdbc.Driver") +} + +tasks.register("test-key-management-utility") { + group = "verification" + filter.includeTestsMatching("integration.container.tests.KeyManagementUtilityIntegrationTest") + classpath = sourceSets.test.get().runtimeClasspath + dependsOn("jar") + systemProperty("java.util.logging.config.file", "${project.layout.buildDirectory.get()}/resources/test/logging-test.properties") + systemProperty("jdbc.drivers", "software.amazon.jdbc.Driver") +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 411e40cd8..0f5d591fb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -42,6 +42,7 @@ import software.amazon.jdbc.plugin.customendpoint.CustomEndpointPluginFactory; import software.amazon.jdbc.plugin.dev.DeveloperConnectionPluginFactory; import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPluginFactory; +import software.amazon.jdbc.plugin.encryption.KmsEncryptionConnectionPluginFactory; import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory; import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory; import software.amazon.jdbc.plugin.federatedauth.OktaAuthPluginFactory; @@ -88,6 +89,7 @@ public class ConnectionPluginChainBuilder { put("initialConnection", new AuroraInitialConnectionStrategyPluginFactory()); put("limitless", new LimitlessConnectionPluginFactory()); put("bg", new BlueGreenConnectionPluginFactory()); + put("kmsEncryption", new KmsEncryptionConnectionPluginFactory()); } }; diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 8757e2c0a..f82308ee1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -40,6 +40,7 @@ import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin; import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin; +import software.amazon.jdbc.plugin.encryption.KmsEncryptionConnectionPlugin; import software.amazon.jdbc.plugin.failover.FailoverConnectionPlugin; import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin; import software.amazon.jdbc.plugin.federatedauth.OktaAuthPlugin; @@ -90,6 +91,7 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { put(DefaultConnectionPlugin.class, "plugin:targetDriver"); put(AuroraInitialConnectionStrategyPlugin.class, "plugin:initialConnection"); put(CustomEndpointPlugin.class, "plugin:customEndpoint"); + put(KmsEncryptionConnectionPlugin.class,"plugin.kmsEncryption"); } }; diff --git a/wrapper/src/main/java/software/amazon/jdbc/factory/EncryptingDataSourceFactory.java b/wrapper/src/main/java/software/amazon/jdbc/factory/EncryptingDataSourceFactory.java new file mode 100644 index 000000000..e6d1aedad --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/factory/EncryptingDataSourceFactory.java @@ -0,0 +1,281 @@ +package software.amazon.jdbc.factory; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.jdbc.plugin.encryption.wrapper.EncryptingDataSource; + +import javax.sql.DataSource; +import java.sql.SQLException; +import java.util.Properties; + +/** + * Factory for creating EncryptingDataSource instances that integrate with the AWS Advanced JDBC Wrapper. + * This factory provides convenient methods to wrap existing DataSources with encryption capabilities. + */ +public class EncryptingDataSourceFactory { + + private static final Logger logger = LoggerFactory.getLogger(EncryptingDataSourceFactory.class); + + /** + * Creates an EncryptingDataSource that wraps the provided DataSource with encryption capabilities. + * + * @param dataSource The underlying DataSource to wrap + * @param encryptionProperties Properties for configuring encryption + * @return An EncryptingDataSource instance + * @throws SQLException if encryption initialization fails + */ + public static EncryptingDataSource create(DataSource dataSource, Properties encryptionProperties) throws SQLException { + logger.info("Creating EncryptingDataSource with encryption properties"); + + // Validate required properties + validateEncryptionProperties(encryptionProperties); + + return new EncryptingDataSource(dataSource, encryptionProperties); + } + + /** + * Creates an EncryptingDataSource using AWS JDBC Wrapper with encryption. + * This method creates an AWS Wrapper DataSource and then wraps it with encryption. + * + * @param jdbcUrl The JDBC URL for the database + * @param username Database username + * @param password Database password + * @param encryptionProperties Properties for configuring encryption + * @return An EncryptingDataSource instance + * @throws SQLException if DataSource creation or encryption initialization fails + */ + public static EncryptingDataSource createWithAwsWrapper(String jdbcUrl, String username, String password, + Properties encryptionProperties) throws SQLException { + logger.info("Creating EncryptingDataSource with AWS JDBC Wrapper for URL: {}", jdbcUrl); + + try { + // Create properties for AWS JDBC Wrapper + Properties awsWrapperProperties = new Properties(); + awsWrapperProperties.setProperty("jdbcUrl", jdbcUrl); + awsWrapperProperties.setProperty("username", username); + awsWrapperProperties.setProperty("password", password); + + // Add any additional AWS wrapper properties from encryption properties + copyAwsWrapperProperties(encryptionProperties, awsWrapperProperties); + + // Create AWS Wrapper DataSource using reflection to avoid compile-time dependency + DataSource awsDataSource = createAwsWrapperDataSource(awsWrapperProperties); + + // Wrap with encryption + return create(awsDataSource, encryptionProperties); + + } catch (Exception e) { + logger.error("Failed to create EncryptingDataSource with AWS Wrapper", e); + throw new SQLException("Failed to create encrypted DataSource: " + e.getMessage(), e); + } + } + + /** + * Creates an EncryptingDataSource with default encryption properties. + * + * @param dataSource The underlying DataSource to wrap + * @param kmsKeyArn The KMS key ARN for encryption + * @param region The AWS region + * @return An EncryptingDataSource instance + * @throws SQLException if encryption initialization fails + */ + public static EncryptingDataSource createWithDefaults(DataSource dataSource, String kmsKeyArn, String region) throws SQLException { + Properties encryptionProperties = createDefaultEncryptionProperties(kmsKeyArn, region); + return create(dataSource, encryptionProperties); + } + + /** + * Validates that required encryption properties are present. + * + * @param properties The properties to validate + * @throws SQLException if required properties are missing + */ + private static void validateEncryptionProperties(Properties properties) throws SQLException { + if (properties == null) { + throw new SQLException("Encryption properties cannot be null"); + } + + // Check for required properties (these will be validated by EncryptionConfig) + logger.debug("Validating encryption properties"); + + // The actual validation is done by EncryptionConfig.validate() in the plugin + // We just do basic null checks here + } + + /** + * Copies AWS Wrapper specific properties from encryption properties. + * + * @param encryptionProperties Source properties + * @param awsWrapperProperties Target properties + */ + private static void copyAwsWrapperProperties(Properties encryptionProperties, Properties awsWrapperProperties) { + // Copy AWS wrapper specific properties + String[] awsWrapperKeys = { + "wrapperPlugins", + "wrapperLogUnclosedConnections", + "wrapperLoggerLevel", + "aws.region" + }; + + for (String key : awsWrapperKeys) { + String value = encryptionProperties.getProperty(key); + if (value != null) { + awsWrapperProperties.setProperty(key, value); + } + } + } + + /** + * Creates an AWS Wrapper DataSource using reflection to avoid compile-time dependency issues. + * + * @param properties Properties for the AWS Wrapper DataSource + * @return DataSource instance + * @throws Exception if DataSource creation fails + */ + private static DataSource createAwsWrapperDataSource(Properties properties) throws Exception { + try { + // Try to create AWS Wrapper DataSource using reflection + Class awsDataSourceClass = Class.forName("software.amazon.jdbc.AwsWrapperDataSource"); + return (DataSource) awsDataSourceClass.getConstructor(Properties.class).newInstance(properties); + } catch (ClassNotFoundException e) { + logger.warn("AWS JDBC Wrapper not found, falling back to direct PostgreSQL DataSource"); + return createPostgreSqlDataSource(properties); + } + } + + /** + * Creates a PostgreSQL DataSource as fallback when AWS Wrapper is not available. + * + * @param properties Properties for the DataSource + * @return DataSource instance + * @throws Exception if DataSource creation fails + */ + private static DataSource createPostgreSqlDataSource(Properties properties) throws Exception { + // Create a basic PostgreSQL DataSource + Class pgDataSourceClass = Class.forName("org.postgresql.ds.PGSimpleDataSource"); + DataSource dataSource = (DataSource) pgDataSourceClass.getDeclaredConstructor().newInstance(); + + // Set properties using reflection + String jdbcUrl = properties.getProperty("jdbcUrl"); + String username = properties.getProperty("username"); + String password = properties.getProperty("password"); + + if (jdbcUrl != null) { + // Parse URL to extract host, port, database + // This is a simplified implementation + pgDataSourceClass.getMethod("setUrl", String.class).invoke(dataSource, jdbcUrl); + } + + if (username != null) { + pgDataSourceClass.getMethod("setUser", String.class).invoke(dataSource, username); + } + + if (password != null) { + pgDataSourceClass.getMethod("setPassword", String.class).invoke(dataSource, password); + } + + return dataSource; + } + + /** + * Creates default encryption properties. + * + * @param kmsKeyArn The KMS key ARN + * @param region The AWS region + * @return Properties with default encryption settings + */ + private static Properties createDefaultEncryptionProperties(String kmsKeyArn, String region) { + Properties properties = new Properties(); + + // KMS configuration + properties.setProperty("kms.region", region != null ? region : "us-east-1"); + properties.setProperty("kms.keyArn", kmsKeyArn); + + // Cache configuration + properties.setProperty("cache.enabled", "true"); + properties.setProperty("cache.expirationMinutes", "30"); + properties.setProperty("cache.maxSize", "1000"); + + // Retry configuration + properties.setProperty("kms.maxRetries", "3"); + properties.setProperty("kms.retryBackoffBaseMs", "100"); + + // Metadata configuration + properties.setProperty("metadata.refreshIntervalMinutes", "5"); + + logger.debug("Created default encryption properties for KMS key: {}, region: {}", kmsKeyArn, region); + + return properties; + } + + /** + * Builder class for creating EncryptingDataSource with fluent API. + */ + public static class Builder { + private DataSource dataSource; + private String jdbcUrl; + private String username; + private String password; + private final Properties encryptionProperties = new Properties(); + + public Builder dataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + public Builder jdbcUrl(String jdbcUrl) { + this.jdbcUrl = jdbcUrl; + return this; + } + + public Builder username(String username) { + this.username = username; + return this; + } + + public Builder password(String password) { + this.password = password; + return this; + } + + public Builder kmsKeyArn(String kmsKeyArn) { + encryptionProperties.setProperty("kms.keyArn", kmsKeyArn); + return this; + } + + public Builder region(String region) { + encryptionProperties.setProperty("kms.region", region); + return this; + } + + public Builder cacheEnabled(boolean enabled) { + encryptionProperties.setProperty("cache.enabled", String.valueOf(enabled)); + return this; + } + + public Builder cacheExpirationMinutes(int minutes) { + encryptionProperties.setProperty("cache.expirationMinutes", String.valueOf(minutes)); + return this; + } + + public Builder cacheMaxSize(int maxSize) { + encryptionProperties.setProperty("cache.maxSize", String.valueOf(maxSize)); + return this; + } + + public Builder property(String key, String value) { + encryptionProperties.setProperty(key, value); + return this; + } + + public EncryptingDataSource build() throws SQLException { + if (dataSource != null) { + return create(dataSource, encryptionProperties); + } else if (jdbcUrl != null && username != null && password != null) { + return createWithAwsWrapper(jdbcUrl, username, password, encryptionProperties); + } else { + throw new SQLException("Either dataSource or (jdbcUrl, username, password) must be provided"); + } + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionConnectionPlugin.java new file mode 100644 index 000000000..99ba015d2 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionConnectionPlugin.java @@ -0,0 +1,251 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption; + +import java.util.logging.Logger; +import software.amazon.jdbc.*; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +/** + * ConnectionPlugin implementation that integrates KmsEncryptionPlugin with AWS JDBC Wrapper. + * This class acts as a bridge between the AWS JDBC Wrapper plugin system and our encryption functionality. + */ +public class KmsEncryptionConnectionPlugin implements ConnectionPlugin { + + private static final Logger LOGGER = Logger.getLogger(KmsEncryptionConnectionPlugin.class.getName()); + + private final KmsEncryptionPlugin encryptionPlugin; + private final PluginService pluginService; + + public static final String KMS_ENCRYPTION_PLUGIN_CODE = "kmsEncryption"; + + /** + * Constructor that creates the encryption plugin with PluginService. + * + * @param pluginService The PluginService instance from AWS JDBC Wrapper + * @param properties Configuration properties + */ + public KmsEncryptionConnectionPlugin(PluginService pluginService, Properties properties) { + this.pluginService = pluginService; + this.encryptionPlugin = new KmsEncryptionPlugin(pluginService); + + try { + this.encryptionPlugin.initialize(properties); + LOGGER.info(()->"KmsEncryptionConnectionPlugin initialized successfully"); + } catch (SQLException e) { + LOGGER.severe(()->String.format("Failed to initialize KmsEncryptionConnectionPlugin %s", e.getMessage())); + throw new RuntimeException("Failed to initialize encryption plugin", e); + } + } + + /** + * Returns the underlying encryption plugin. + * + * @return KmsEncryptionPlugin instance + */ + public KmsEncryptionPlugin getEncryptionPlugin() { + return encryptionPlugin; + } + + /** + * Executes JDBC method calls and applies encryption/decryption wrapping when needed. + * + * @param Return type + * @param Exception type + * @param methodClass Method class + * @param methodReturnType Return type class + * @param methodInvokeOn Object to invoke method on + * @param methodName Method name + * @param jdbcCallable Callable to execute + * @param args Method arguments + * @return Method result, potentially wrapped with encryption/decryption + * @throws E if method execution fails + */ + @Override + public T execute(Class methodClass, Class methodReturnType, Object methodInvokeOn, + String methodName, JdbcCallable jdbcCallable, Object... args) throws E { + // Execute the original method first + T result = jdbcCallable.call(); + + try { + // Apply encryption/decryption wrapping if needed + if (result instanceof java.sql.PreparedStatement && args.length > 0 && args[0] instanceof String) { + String sql = (String) args[0]; + @SuppressWarnings("unchecked") + T wrappedResult = (T) encryptionPlugin.wrapPreparedStatement((java.sql.PreparedStatement) result, sql); + return wrappedResult; + } else if (result instanceof java.sql.ResultSet) { + @SuppressWarnings("unchecked") + T wrappedResult = (T) encryptionPlugin.wrapResultSet((java.sql.ResultSet) result); + return wrappedResult; + } + } catch (SQLException e) { + // If E is SQLException or a superclass, we can throw it + if (methodReturnType.isAssignableFrom(SQLException.class)) { + @SuppressWarnings("unchecked") + E exception = (E) e; + throw exception; + } else { + // Otherwise wrap in RuntimeException + throw new RuntimeException("Failed to wrap JDBC object with encryption", e); + } + } + + return result; + } + + /** + * Delegates connection creation to the original function. + * + * @param driverProtocol Driver protocol + * @param hostSpec Host specification + * @param props Connection properties + * @param isInitialConnection Whether this is initial connection + * @param connectFunc Connection function + * @return Database connection + * @throws SQLException if connection fails + */ + @Override + public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, + boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { + // Delegate to the original connection function + return connectFunc.call(); + } + + /** + * Returns the set of JDBC methods this plugin subscribes to. + * + * @return Set of method names to intercept + */ + @Override + public Set getSubscribedMethods() { + // Subscribe to PreparedStatement and ResultSet creation methods + return new HashSet<>(Arrays.asList( + "Connection.prepareStatement", + "Connection.prepareCall", + "Statement.executeQuery", + "PreparedStatement.executeQuery" + )); + } + + /** + * Delegates host provider initialization to the original function. + * + * @param driverProtocol Driver protocol + * @param initialUrl Initial URL + * @param props Properties + * @param hostListProviderService Host list provider service + * @param initFunc Initialization function + * @throws SQLException if initialization fails + */ + @Override + public void initHostProvider(String driverProtocol, String initialUrl, Properties props, + HostListProviderService hostListProviderService, JdbcCallable initFunc) throws SQLException { + // Delegate to the original initialization + initFunc.call(); + } + + /** + * Handles node list change notifications (no action needed for encryption). + * + * @param changes Map of node changes + */ + @Override + public void notifyNodeListChanged(Map> changes) { + // No action needed for encryption plugin + } + + /** + * Accepts all strategies since encryption is transparent. + * + * @param role Host role + * @param strategy Strategy name + * @return Always true + */ + @Override + public boolean acceptsStrategy(HostRole role, String strategy) { + // Accept all strategies - encryption is transparent + return true; + } + + /** + * Not supported - encryption plugin does not provide host selection. + * + * @param role Host role + * @param strategy Strategy name + * @return Never returns + * @throws SQLException Always throws UnsupportedOperationException + */ + @Override + public HostSpec getHostSpecByStrategy(HostRole role, String strategy) throws SQLException { + throw new UnsupportedOperationException("Encryption plugin does not provide host selection"); + } + + + /** + * Not supported - encryption plugin does not provide host selection. + * + * @param hosts List of host specs + * @param role Host role + * @param strategy Strategy name + * @return Never returns + * @throws SQLException Always throws UnsupportedOperationException + */ + public HostSpec getHostSpecByStrategy(List hosts, HostRole role, String strategy) throws SQLException { + throw new UnsupportedOperationException("Encryption plugin does not provide host selection"); + } + + /** + * Forces connection creation by delegating to the original function. + * + * @param driverProtocol Driver protocol + * @param hostSpec Host specification + * @param props Connection properties + * @param isInitialConnection Whether this is initial connection + * @param connectFunc Connection function + * @return Database connection + * @throws SQLException if connection fails + */ + @Override + public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props, + boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { + // Delegate to the original connection function + return connectFunc.call(); + } + + /** + * Handles connection change notifications (no special action needed). + * + * @param changes Set of node change options + * @return NO_OPINION - no special action required + */ + @Override + public OldConnectionSuggestedAction notifyConnectionChanged(EnumSet changes) { + // No special action needed for connection changes + return OldConnectionSuggestedAction.NO_OPINION; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionConnectionPluginFactory.java new file mode 100644 index 000000000..fd2ff8b59 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionConnectionPluginFactory.java @@ -0,0 +1,47 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption; + +import java.util.logging.Logger; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; + +import java.util.Properties; + +/** + * Factory for creating KmsEncryptionConnectionPlugin instances. + * This factory is used by the AWS JDBC Wrapper to create plugin instances. + */ +public class KmsEncryptionConnectionPluginFactory implements ConnectionPluginFactory { + + private static final Logger LOGGER = Logger.getLogger(KmsEncryptionConnectionPluginFactory.class.getName()); + + /** + * Creates a new KmsEncryptionConnectionPlugin instance. + * + * @param pluginService The PluginService instance from AWS JDBC Wrapper + * @param properties Configuration properties for the plugin + * @return New plugin instance + */ + @Override + public ConnectionPlugin getInstance(PluginService pluginService, Properties properties) { + LOGGER.info(()->"Creating KmsEncryptionConnectionPlugin instance"); + return new KmsEncryptionConnectionPlugin(pluginService, properties); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionPlugin.java new file mode 100644 index 000000000..b1e169d8f --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/KmsEncryptionPlugin.java @@ -0,0 +1,514 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption; + + +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.plugin.encryption.factory.IndependentDataSource; +import software.amazon.jdbc.plugin.encryption.logging.AuditLogger; +import software.amazon.jdbc.plugin.encryption.metadata.MetadataManager; +import software.amazon.jdbc.plugin.encryption.metadata.MetadataException; +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; +import software.amazon.jdbc.plugin.encryption.key.KeyManager; +import software.amazon.jdbc.plugin.encryption.sql.SqlAnalysisService; +import software.amazon.jdbc.plugin.encryption.wrapper.EncryptingPreparedStatement; +import software.amazon.jdbc.plugin.encryption.service.EncryptionService; +import software.amazon.jdbc.plugin.encryption.wrapper.DecryptingResultSet; + +import java.util.logging.Logger; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.kms.KmsClient; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Main encryption plugin that integrates with the AWS Advanced JDBC Wrapper + * to provide transparent client-side encryption using AWS KMS. + * + * This plugin intercepts JDBC operations to automatically encrypt data before storage + * and decrypt data upon retrieval based on metadata configuration. + */ +public class KmsEncryptionPlugin { + + private static final Logger LOGGER = Logger.getLogger(KmsEncryptionPlugin.class.getName()); + + // Plugin configuration + private EncryptionConfig config; + private MetadataManager metadataManager; + private KeyManager keyManager; + private EncryptionService encryptionService; + private KmsClient kmsClient; + + // Plugin services + private PluginService pluginService; + private IndependentDataSource independentDataSource; + + // SQL Analysis + private SqlAnalysisService sqlAnalysisService; + + // Monitoring and metrics + private AuditLogger auditLogger; + + // Plugin lifecycle state + private final AtomicBoolean initialized = new AtomicBoolean(false); + private final AtomicBoolean closed = new AtomicBoolean(false); + + // Track connections where custom types have been registered + private final java.util.Map registeredConnections = + new java.util.WeakHashMap<>(); + + // Plugin properties + private Properties pluginProperties; + + /** + * Constructor that accepts PluginService for integration with AWS JDBC Wrapper. + * + * @param pluginService The PluginService instance from AWS JDBC Wrapper + */ + public KmsEncryptionPlugin(PluginService pluginService) { + this.pluginService = pluginService; + LOGGER.fine(() -> String.format("KmsEncryptionPlugin created with PluginService: %s", pluginService != null ? "available" : "null")); + } + + /** + * Default constructor for backward compatibility. + */ + public KmsEncryptionPlugin() { + this.pluginService = null; + LOGGER.warning("KmsEncryptionPlugin created without PluginService - connection parameter extraction may fail"); + } + + /** + * Sets the PluginService instance. This method can be called to provide + * the PluginService after construction if it wasn't available during construction. + * + * @param pluginService The PluginService instance from AWS JDBC Wrapper + */ + public void setPluginService(PluginService pluginService) { + if (this.pluginService == null) { + this.pluginService = pluginService; + LOGGER.info(() -> String.format("PluginService set after construction: %s", pluginService != null ? "available" : "null")); + } else { + LOGGER.warning("PluginService already set, ignoring new instance"); + } + } + + /** + * Initializes the plugin with the provided configuration. + * This method is called by the AWS JDBC Wrapper during plugin loading. + * + * @param properties Configuration properties for the plugin + * @throws SQLException if initialization fails + */ + public void initialize(Properties properties) throws SQLException { + if (initialized.get()) { + LOGGER.warning("Plugin already initialized, skipping re-initialization"); + return; + } + + LOGGER.info("Initializing KmsEncryptionPlugin"); + + try { + // Store properties for later use + this.pluginProperties = new Properties(); + this.pluginProperties.putAll(properties); + + // Load and validate configuration + this.config = loadConfiguration(properties); + config.validate(); + + // Initialize AWS KMS client + this.kmsClient = createKmsClient(config); + + // Initialize core services + this.encryptionService = new EncryptionService(); + + // Initialize audit LOGGER + this.auditLogger = new AuditLogger(config.isAuditLoggingEnabled()); + + LOGGER.info("KmsEncryptionPlugin initialized successfully"); + initialized.set(true); + + } catch (Exception e) { + LOGGER.severe(() -> String.format("Failed to initialize KmsEncryptionPlugin %s", e.getMessage())); + throw new SQLException("Plugin initialization failed: " + e.getMessage(), e); + } + } + + /** + * Initializes plugin components that require a database connection. + * This method uses PluginService to get connection parameters instead of extraction. + * + * @throws SQLException if initialization fails + */ + private void initializeWithDataSource() throws SQLException { + if (metadataManager != null) { + return; // Already initialized + } + + try { + if (pluginService != null) { + // Create independent DataSource using PluginService + this.independentDataSource = new IndependentDataSource(pluginService, pluginProperties); + + // Log success + auditLogger.logConnectionParameterExtraction("PluginService", "PLUGIN_SERVICE", true, null); + + // Initialize managers with PluginService + this.keyManager = new KeyManager(kmsClient, pluginService, config); + this.metadataManager = new MetadataManager(pluginService, config); + metadataManager.initialize(); + + // Initialize SQL analysis service + this.sqlAnalysisService = new SqlAnalysisService(pluginService, metadataManager); + + LOGGER.info("Plugin initialized with PluginService connection parameters"); + + } else { + LOGGER.severe("PluginService not available - cannot create independent connections"); + + auditLogger.logConnectionParameterExtraction("PluginService", "PLUGIN_SERVICE", false, "PluginService not available"); + + throw new SQLException("PluginService not available - cannot create independent connections"); + } + + } catch (MetadataException e) { + LOGGER.severe(()->String.format("Failed to initialize plugin components with database %s", e.getMessage())); + throw new SQLException("Failed to initialize plugin with database: " + e.getMessage(), e); + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to initialize plugin with PluginService %s", e.getMessage())); + throw new SQLException("Failed to initialize plugin: " + e.getMessage(), e); + } + } + + /** + * Registers custom PostgreSQL types with the JDBC driver for a specific connection. + * Only registers once per connection. + */ + private void registerPostgresTypesForConnection(java.sql.Connection conn) { + if (conn == null) { + return; + } + + synchronized (registeredConnections) { + if (registeredConnections.containsKey(conn)) { + return; // Already registered for this connection + } + + try { + org.postgresql.PGConnection pgConn = conn.unwrap(org.postgresql.PGConnection.class); + pgConn.addDataType("encrypted_data", software.amazon.jdbc.plugin.encryption.wrapper.EncryptedData.class); + registeredConnections.put(conn, Boolean.TRUE); + LOGGER.fine("Registered encrypted_data type for connection"); + } catch (Exception e) { + LOGGER.fine(() -> "Failed to register PostgreSQL custom types: " + e.getMessage()); + } + } + } + + /** + * Wraps a PreparedStatement to add encryption capabilities. + * + * @param statement The original PreparedStatement + * @param sql The SQL statement + * @return Wrapped PreparedStatement with encryption support + * @throws SQLException if wrapping fails + */ + public PreparedStatement wrapPreparedStatement(PreparedStatement statement, String sql) + throws SQLException { + if (!initialized.get()) { + throw new SQLException("Plugin not initialized"); + } + + // Initialize with DataSource if needed (lazy initialization) + if (metadataManager == null) { + try { + initializeWithDataSource(); + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to initialize plugin with connection %s", e.getMessage())); + throw new SQLException("Failed to initialize plugin: " + e.getMessage(), e); + } + } + + // Register custom types for this connection + registerPostgresTypesForConnection(statement.getConnection()); + + LOGGER.fine(()->String.format("Wrapping PreparedStatement for SQL: %s", sql)); + + // Analyze SQL to determine if encryption is needed + SqlAnalysisService.SqlAnalysisResult analysisResult; + if (sqlAnalysisService != null) { + analysisResult = sqlAnalysisService.analyzeSql(sql); + LOGGER.fine(()->String.format("SQL analysis result: %s", analysisResult)); + } else { + analysisResult = null; + } + + return new EncryptingPreparedStatement( + statement, + metadataManager, + encryptionService, + keyManager, + sqlAnalysisService, + sql + ); + } + + /** + * Wraps a ResultSet to add decryption capabilities. + * + * @param resultSet The original ResultSet + * @return Wrapped ResultSet with decryption support + * @throws SQLException if wrapping fails + */ + public ResultSet wrapResultSet(ResultSet resultSet) throws SQLException { + if (!initialized.get()) { + throw new SQLException("Plugin not initialized"); + } + + // Initialize with DataSource if needed (lazy initialization) + if (metadataManager == null) { + try { + initializeWithDataSource(); + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to initialize plugin with connection %s", e.getMessage())); + throw new SQLException("Failed to initialize plugin: " + e.getMessage(), e); + } + } + + // Register custom types for this connection + try { + registerPostgresTypesForConnection(resultSet.getStatement().getConnection()); + } catch (Exception e) { + LOGGER.fine(() -> "Could not register types for ResultSet connection: " + e.getMessage()); + } + + LOGGER.finest(()->"Wrapping ResultSet"); + + return new DecryptingResultSet( + resultSet, + metadataManager, + encryptionService, + keyManager + ); + } + + /** + * Returns the plugin name for identification. + * + * @return Plugin name + */ + public String getPluginName() { + return "KmsEncryptionPlugin"; + } + + /** + * Cleans up plugin resources. + * This method is called when the plugin is being unloaded. + */ + public void cleanup() { + if (closed.get()) { + return; + } + + LOGGER.info("Cleaning up KmsEncryptionPlugin resources"); + + // Log final connection status + if (independentDataSource != null) { + try { + independentDataSource.logHealthStatus(); + } catch (Exception e) { + LOGGER.warning(()->String.format("Error logging final DataSource health status %s", e.getMessage())); + } + } + + try { + if (kmsClient != null) { + kmsClient.close(); + } + } catch (Exception e) { + LOGGER.warning(()->String.format("Error closing KMS client %s", e.getMessage())); + } + + closed.set(true); + LOGGER.info("KmsEncryptionPlugin cleanup completed"); + } + + /** + * Loads configuration from properties. + * + * @param properties Configuration properties + * @return EncryptionConfig instance + * @throws SQLException if configuration is invalid + */ + private EncryptionConfig loadConfiguration(Properties properties) throws SQLException { + try { + // Set default region if not provided + if (!properties.containsKey("kms.region")) { + properties.setProperty("kms.region", "us-east-1"); + } + + EncryptionConfig config = EncryptionConfig.fromProperties(properties); + + LOGGER.info(()->String.format("Loaded encryption configuration: region=%s, cacheEnabled=%s, maxRetries=%s", + config.getKmsRegion(), config.isCacheEnabled(), config.getMaxRetries())); + + return config; + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to load configuration from properties %s", e.getMessage())); + throw new SQLException("Invalid configuration: " + e.getMessage(), e); + } + } + + /** + * Creates a KMS client with the specified configuration. + * + * @param config Encryption configuration + * @return Configured KMS client + */ + private KmsClient createKmsClient(EncryptionConfig config) { + LOGGER.fine(()->String.format("Creating KMS client for region: %s", config.getKmsRegion())); + + return KmsClient.builder() + .region(Region.of(config.getKmsRegion())) + .build(); + } + + + // Getters for testing and monitoring + + /** + * Returns the current configuration. + * + * @return EncryptionConfig instance + */ + public EncryptionConfig getConfig() { + return config; + } + + /** + * Returns the metadata manager. + * + * @return MetadataManager instance + */ + public MetadataManager getMetadataManager() { + return metadataManager; + } + + /** + * Returns the key manager. + * + * @return KeyManager instance + */ + public KeyManager getKeyManager() { + return keyManager; + } + + /** + * Returns the encryption service. + * + * @return EncryptionService instance + */ + public EncryptionService getEncryptionService() { + return encryptionService; + } + + /** + * Checks if the plugin is initialized. + * + * @return true if initialized, false otherwise + */ + public boolean isInitialized() { + return initialized.get(); + } + + /** + * Checks if the plugin is closed. + * + * @return true if closed, false otherwise + */ + public boolean isClosed() { + return closed.get(); + } + + /** + * Returns the plugin service. + * + * @return PluginService instance + */ + public PluginService getPluginService() { + return pluginService; + } + + /** + * Returns the independent DataSource used by MetadataManager. + * + * @return IndependentDataSource instance, or null if not initialized + */ + public IndependentDataSource getIndependentDataSource() { + return independentDataSource; + } + + /** + * Checks if the plugin is using independent connections. + * + * @return true if using independent connections, false otherwise + */ + public boolean isUsingIndependentConnections() { + return independentDataSource != null; + } + + + /** + * Creates a detailed status message about the current connection mode. + * + * @return a comprehensive status message + */ + public String getConnectionModeStatus() { + if (isUsingIndependentConnections()) { + return "Plugin is using independent connections via PluginService"; + } else { + return "Plugin connection mode is not yet determined"; + } + } + + /** + * Logs the current connection status and performance metrics. + * This method can be called for troubleshooting purposes. + */ + public void logCurrentStatus() { + LOGGER.info("=== KmsEncryptionPlugin Status Report ==="); + + // Log connection mode status + LOGGER.info(()->String.format("Connection Mode: %s", getConnectionModeStatus())); + + // Log DataSource health + if (independentDataSource != null) { + independentDataSource.logHealthStatus(); + } else { + LOGGER.info("Independent DataSource: Not configured"); + } + + LOGGER.info("=== End Status Report ==="); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/cache/DataKeyCache.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/cache/DataKeyCache.java new file mode 100644 index 000000000..a9d7b4f0e --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/cache/DataKeyCache.java @@ -0,0 +1,368 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.cache; + +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; +import java.util.logging.Logger; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * Thread-safe cache for data keys with configurable expiration and size limits. + * Provides metrics for cache performance monitoring. + */ +public class DataKeyCache { + + private static final Logger LOGGER = Logger.getLogger(DataKeyCache.class.getName()); + + private final Map cache; + private final ReadWriteLock cacheLock; + private final ScheduledExecutorService cleanupExecutor; + private final EncryptionConfig config; + + // Metrics + private final AtomicLong hitCount = new AtomicLong(0); + private final AtomicLong missCount = new AtomicLong(0); + private final AtomicLong evictionCount = new AtomicLong(0); + + public DataKeyCache(EncryptionConfig config) { + this.config = config; + this.cache = new ConcurrentHashMap<>(); + this.cacheLock = new ReentrantReadWriteLock(); + this.cleanupExecutor = Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "DataKeyCache-Cleanup"); + t.setDaemon(true); + return t; + }); + + // Schedule periodic cleanup of expired entries + long cleanupIntervalMs = Math.max(config.getDataKeyCacheExpiration().toMillis() / 4, 30000); + cleanupExecutor.scheduleAtFixedRate(this::cleanupExpiredEntries, + cleanupIntervalMs, cleanupIntervalMs, TimeUnit.MILLISECONDS); + + LOGGER.info(()->String.format("DataKeyCache initialized with maxSize=%s, expiration=%s, cleanupInterval=%sms", + config.getDataKeyCacheMaxSize(), config.getDataKeyCacheExpiration(), cleanupIntervalMs)); + } + + /** + * Retrieves a data key from the cache. + * + * @param keyId the key identifier + * @return decrypted data key bytes, or null if not found or expired + */ + public byte[] get(String keyId) { + if (!config.isDataKeyCacheEnabled() || keyId == null) { + return null; + } + + cacheLock.readLock().lock(); + try { + CacheEntry entry = cache.get(keyId); + if (entry == null) { + missCount.incrementAndGet(); + LOGGER.finest(()->String.format("Cache miss for key: %s", keyId)); + return null; + } + + if (entry.isExpired(config.getDataKeyCacheExpiration())) { + missCount.incrementAndGet(); + LOGGER.finest(()->String.format("Cache entry expired for key: %s", keyId)); + // Remove expired entry (will be cleaned up by background thread) + return null; + } + + hitCount.incrementAndGet(); + LOGGER.finest(()->String.format("Cache hit for key: %s", keyId)); + return entry.getDataKey(); + + } finally { + cacheLock.readLock().unlock(); + } + } + + /** + * Stores a data key in the cache. + * + * @param keyId the key identifier + * @param dataKey the decrypted data key bytes + */ + public void put(String keyId, byte[] dataKey) { + if (!config.isDataKeyCacheEnabled() || keyId == null || dataKey == null) { + return; + } + + cacheLock.writeLock().lock(); + try { + // Check if we need to evict entries to make room + if (cache.size() >= config.getDataKeyCacheMaxSize()) { + evictOldestEntry(); + } + + CacheEntry entry = new CacheEntry(dataKey.clone()); + cache.put(keyId, entry); + + LOGGER.finest(()->String.format("Cached data key for: %s", keyId)); + + } finally { + cacheLock.writeLock().unlock(); + } + } + + /** + * Removes a specific key from the cache. + * + * @param keyId the key identifier to remove + */ + public void remove(String keyId) { + if (keyId == null) { + return; + } + + cacheLock.writeLock().lock(); + try { + CacheEntry removed = cache.remove(keyId); + if (removed != null) { + removed.clear(); + LOGGER.finest(()->String.format("Removed key from cache: %s", keyId)); + } + } finally { + cacheLock.writeLock().unlock(); + } + } + + /** + * Clears all entries from the cache. + */ + public void clear() { + cacheLock.writeLock().lock(); + try { + // Clear sensitive data before removing entries + cache.values().forEach(CacheEntry::clear); + cache.clear(); + LOGGER.info("Cache cleared"); + } finally { + cacheLock.writeLock().unlock(); + } + } + + /** + * Returns cache statistics. + * + * @return CacheStats object with current metrics + */ + public CacheStats getStats() { + cacheLock.readLock().lock(); + try { + return new CacheStats( + cache.size(), + hitCount.get(), + missCount.get(), + evictionCount.get(), + calculateHitRate()); + } finally { + cacheLock.readLock().unlock(); + } + } + + /** + * Shuts down the cache and cleans up resources. + */ + public void shutdown() { + LOGGER.info("Shutting down DataKeyCache"); + + cleanupExecutor.shutdown(); + try { + if (!cleanupExecutor.awaitTermination(5, TimeUnit.SECONDS)) { + cleanupExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + cleanupExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + + clear(); + } + + /** + * Removes expired entries from the cache. + */ + private void cleanupExpiredEntries() { + if (!config.isDataKeyCacheEnabled()) { + return; + } + + cacheLock.writeLock().lock(); + try { + Duration expiration = config.getDataKeyCacheExpiration(); + int removedCount = 0; + + Iterator> iterator = cache.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (entry.getValue().isExpired(expiration)) { + entry.getValue().clear(); + iterator.remove(); + removedCount++; + } + } + + if (removedCount > 0) { + int finalRemovedCount = removedCount; + LOGGER.finest(()->String.format("Cleaned up %d expired cache entries", finalRemovedCount)); + } + + } finally { + cacheLock.writeLock().unlock(); + } + } + + /** + * Evicts the oldest entry from the cache to make room for new entries. + */ + private void evictOldestEntry() { + if (cache.isEmpty()) { + return; + } + + // Find the oldest entry + String oldestKey = null; + Instant oldestTime = Instant.MAX; + + for (Map.Entry entry : cache.entrySet()) { + if (entry.getValue().getCreatedAt().isBefore(oldestTime)) { + oldestTime = entry.getValue().getCreatedAt(); + oldestKey = entry.getKey(); + } + } + + if (oldestKey != null) { + CacheEntry removed = cache.remove(oldestKey); + if (removed != null) { + removed.clear(); + evictionCount.incrementAndGet(); + String finalOldestKey = oldestKey; + LOGGER.finest(()->String.format("Evicted oldest cache entry: %s", finalOldestKey)); + } + } + } + + /** + * Calculates the current cache hit rate. + */ + private double calculateHitRate() { + long hits = hitCount.get(); + long misses = missCount.get(); + long total = hits + misses; + + return total > 0 ? (double) hits / total : 0.0; + } + + /** + * Cache entry wrapper that tracks creation time and provides secure cleanup. + */ + private static class CacheEntry { + private final byte[] dataKey; + private final Instant createdAt; + private volatile boolean cleared = false; + + public CacheEntry(byte[] dataKey) { + this.dataKey = dataKey; + this.createdAt = Instant.now(); + } + + public byte[] getDataKey() { + if (cleared) { + return null; + } + return dataKey.clone(); // Return copy for security + } + + public Instant getCreatedAt() { + return createdAt; + } + + public boolean isExpired(Duration expiration) { + return Instant.now().isAfter(createdAt.plus(expiration)); + } + + public void clear() { + if (!cleared && dataKey != null) { + Arrays.fill(dataKey, (byte) 0); + cleared = true; + } + } + } + + /** + * Cache statistics data class. + */ + public static class CacheStats { + private final int size; + private final long hitCount; + private final long missCount; + private final long evictionCount; + private final double hitRate; + + public CacheStats(int size, long hitCount, long missCount, long evictionCount, double hitRate) { + this.size = size; + this.hitCount = hitCount; + this.missCount = missCount; + this.evictionCount = evictionCount; + this.hitRate = hitRate; + } + + public int getSize() { + return size; + } + + public long getHitCount() { + return hitCount; + } + + public long getMissCount() { + return missCount; + } + + public long getEvictionCount() { + return evictionCount; + } + + public double getHitRate() { + return hitRate; + } + + @Override + public String toString() { + return String.format("CacheStats{size=%d, hits=%d, misses=%d, evictions=%d, hitRate=%.2f%%}", + size, hitCount, missCount, evictionCount, hitRate * 100); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/AwsWrapperEncryptionExample.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/AwsWrapperEncryptionExample.java new file mode 100644 index 000000000..fedbee002 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/AwsWrapperEncryptionExample.java @@ -0,0 +1,290 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.example; + +import software.amazon.jdbc.factory.EncryptingDataSourceFactory; +import java.util.logging.Logger; +import software.amazon.jdbc.plugin.encryption.wrapper.EncryptingDataSource; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Properties; + +/** + * Example demonstrating how to use the encryption functionality with AWS Advanced JDBC Wrapper. + * This example shows different ways to configure and use encrypted database connections. + */ +public class AwsWrapperEncryptionExample { + + private static final Logger LOGGER = Logger.getLogger(AwsWrapperEncryptionExample.class.getName()); + + public static void main(String[] args) { + try { + // Example 1: Using builder pattern + demonstrateBuilderPattern(); + + // Example 2: Using factory with properties + demonstrateFactoryWithProperties(); + + // Example 3: Using existing DataSource + demonstrateWrappingExistingDataSource(); + + } catch (Exception e) { + LOGGER.severe(()->String.format("Example execution failed", e)); + } + } + + /** + * Demonstrates using the builder pattern to create an encrypted DataSource. + */ + private static void demonstrateBuilderPattern() throws SQLException { + LOGGER.info("=== Builder Pattern Example ==="); + + EncryptingDataSource dataSource = new EncryptingDataSourceFactory.Builder() + .jdbcUrl("jdbc:postgresql://localhost:5432/mydb") + .username("myuser") + .password("mypassword") + .kmsKeyArn("arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012") + .region("us-east-1") + .cacheEnabled(true) + .cacheExpirationMinutes(30) + .cacheMaxSize(1000) + .build(); + + // Use the DataSource + performDatabaseOperations(dataSource, "Builder Pattern"); + + // Clean up + dataSource.close(); + } + + /** + * Demonstrates using the factory with explicit properties. + */ + private static void demonstrateFactoryWithProperties() throws SQLException { + LOGGER.info("=== Factory with Properties Example ==="); + + Properties encryptionProperties = new Properties(); + + // KMS configuration + encryptionProperties.setProperty("kms.keyArn", "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012"); + encryptionProperties.setProperty("kms.region", "us-east-1"); + + // Cache configuration + encryptionProperties.setProperty("cache.enabled", "true"); + encryptionProperties.setProperty("cache.expirationMinutes", "30"); + encryptionProperties.setProperty("cache.maxSize", "1000"); + + // Retry configuration + encryptionProperties.setProperty("kms.maxRetries", "3"); + encryptionProperties.setProperty("kms.retryBackoffBaseMs", "100"); + + // AWS Wrapper configuration (optional) + encryptionProperties.setProperty("wrapperLogUnclosedConnections", "true"); + encryptionProperties.setProperty("wrapperLoggerLevel", "INFO"); + + EncryptingDataSource dataSource = EncryptingDataSourceFactory.createWithAwsWrapper( + "jdbc:postgresql://localhost:5432/mydb", + "myuser", + "mypassword", + encryptionProperties + ); + + // Use the DataSource + performDatabaseOperations(dataSource, "Factory with Properties"); + + // Clean up + dataSource.close(); + } + + /** + * Demonstrates wrapping an existing DataSource with encryption. + */ + private static void demonstrateWrappingExistingDataSource() throws SQLException { + LOGGER.info("=== Wrapping Existing DataSource Example ==="); + + // Create an existing DataSource (this could be from a connection pool, etc.) + DataSource existingDataSource = createExistingDataSource(); + + // Wrap it with encryption + EncryptingDataSource encryptingDataSource = EncryptingDataSourceFactory.createWithDefaults( + existingDataSource, + "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012", + "us-east-1" + ); + + // Use the encrypted DataSource + performDatabaseOperations(encryptingDataSource, "Wrapped Existing DataSource"); + + // Clean up + encryptingDataSource.close(); + } + + /** + * Performs sample database operations to demonstrate encryption/decryption. + */ + private static void performDatabaseOperations(DataSource dataSource, String exampleName) { + LOGGER.info(()->String.format("Performing database operations for: %s", exampleName)); + + try (Connection connection = dataSource.getConnection()) { + + // Create test table (if not exists) + createTestTable(connection); + + // Insert encrypted data + insertTestData(connection); + + // Query and decrypt data + queryTestData(connection); + + LOGGER.info(()->String.format("Database operations completed successfully for: %s", exampleName)); + + } catch (SQLException e) { + LOGGER.severe(()->String.format("Database operations failed for: " + exampleName, e)); + } + } + + /** + * Creates a test table for demonstration. + */ + private static void createTestTable(Connection connection) throws SQLException { + String createTableSql = "CREATE TABLE IF NOT EXISTS test_users (" + + "id SERIAL PRIMARY KEY, " + + "name VARCHAR(100) NOT NULL, " + + "email VARCHAR(100), " + + "ssn VARCHAR(20), " + + "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP" + + ")"; + + try (PreparedStatement stmt = connection.prepareStatement(createTableSql)) { + stmt.executeUpdate(); + LOGGER.finest(()->"Test table created or already exists"); + } + } + + /** + * Inserts test data that will be automatically encrypted for configured columns. + */ + private static void insertTestData(Connection connection) throws SQLException { + String insertSql = "INSERT INTO test_users (name, email, ssn) VALUES (?, ?, ?)"; + + try (PreparedStatement stmt = connection.prepareStatement(insertSql)) { + // Insert first user + stmt.setString(1, "John Doe"); + stmt.setString(2, "john.doe@example.com"); // Will be encrypted if configured + stmt.setString(3, "123-45-6789"); // Will be encrypted if configured + stmt.executeUpdate(); + + // Insert second user + stmt.setString(1, "Jane Smith"); + stmt.setString(2, "jane.smith@example.com"); // Will be encrypted if configured + stmt.setString(3, "987-65-4321"); // Will be encrypted if configured + stmt.executeUpdate(); + + LOGGER.info("Inserted test data with automatic encryption"); + } + } + + /** + * Queries test data that will be automatically decrypted for configured columns. + */ + private static void queryTestData(Connection connection) throws SQLException { + String selectSql = "SELECT id, name, email, ssn FROM test_users ORDER BY id"; + + try (PreparedStatement stmt = connection.prepareStatement(selectSql); + ResultSet rs = stmt.executeQuery()) { + + LOGGER.info("Querying test data with automatic decryption:"); + + while (rs.next()) { + int id = rs.getInt("id"); + String name = rs.getString("name"); + String email = rs.getString("email"); // Will be decrypted if configured + String ssn = rs.getString("ssn"); // Will be decrypted if configured + + LOGGER.info(()->String.format("User %s: Name=%s, Email=%s, SSN=%s", id, name, email, ssn)); + } + } + } + + /** + * Creates a sample existing DataSource for demonstration. + * In a real application, this might come from a connection pool or dependency injection. + */ + private static DataSource createExistingDataSource() { + // This is a simplified example - in practice you might use HikariCP, etc. + return new DataSource() { + @Override + public Connection getConnection() throws SQLException { + return java.sql.DriverManager.getConnection( + "jdbc:postgresql://localhost:5432/mydb", + "myuser", + "mypassword" + ); + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + return java.sql.DriverManager.getConnection( + "jdbc:postgresql://localhost:5432/mydb", + username, + password + ); + } + + // Other DataSource methods with default implementations + @Override + public java.io.PrintWriter getLogWriter() throws SQLException { + return null; + } + + @Override + public void setLogWriter(java.io.PrintWriter out) throws SQLException { + // No-op + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + // No-op + } + + @Override + public int getLoginTimeout() throws SQLException { + return 0; + } + + @Override + public java.util.logging.Logger getParentLogger() { + return java.util.logging.Logger.getLogger("javax.sql.DataSource"); + } + + @Override + public T unwrap(Class iface) throws SQLException { + throw new SQLException("Cannot unwrap to " + iface.getName()); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return false; + } + }; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/DataSourceLifecycleExample.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/DataSourceLifecycleExample.java new file mode 100644 index 000000000..270020c55 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/DataSourceLifecycleExample.java @@ -0,0 +1,253 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.example; + +import software.amazon.jdbc.factory.EncryptingDataSourceFactory; +import java.util.logging.Logger; +import software.amazon.jdbc.plugin.encryption.wrapper.EncryptingDataSource; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.SQLException; + +/** + * Example demonstrating proper DataSource lifecycle management with encryption. + * Shows how to handle connection failures and DataSource state management. + */ +public class DataSourceLifecycleExample { + + private static final Logger LOGGER = Logger.getLogger(DataSourceLifecycleExample.class.getName()); + + public static void main(String[] args) { + EncryptingDataSource dataSource = null; + + try { + // Create the DataSource + dataSource = createDataSource(); + + // Demonstrate proper usage patterns + demonstrateHealthyUsage(dataSource); + + // Demonstrate error handling + demonstrateErrorHandling(dataSource); + + // Demonstrate lifecycle management + demonstrateLifecycleManagement(dataSource); + + } catch (Exception e) { + LOGGER.severe(()->String.format("Example execution failed %s", e.getMessage())); + } finally { + // Always clean up resources + if (dataSource != null) { + dataSource.close(); + LOGGER.info("DataSource closed in finally block"); + } + } + } + + /** + * Creates an EncryptingDataSource for demonstration. + */ + private static EncryptingDataSource createDataSource() throws SQLException { + LOGGER.info("=== Creating EncryptingDataSource ==="); + + EncryptingDataSource dataSource = new EncryptingDataSourceFactory.Builder() + .jdbcUrl("jdbc:postgresql://localhost:5432/mydb") + .username("myuser") + .password("mypassword") + .kmsKeyArn("arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012") + .region("us-east-1") + .cacheEnabled(true) + .build(); + + LOGGER.info("EncryptingDataSource created successfully"); + return dataSource; + } + + /** + * Demonstrates healthy DataSource usage patterns. + */ + private static void demonstrateHealthyUsage(EncryptingDataSource dataSource) { + LOGGER.info("=== Demonstrating Healthy Usage ==="); + + // Check if DataSource is available before using + if (!dataSource.isConnectionAvailable()) { + LOGGER.warning("DataSource is not available - skipping operations"); + return; + } + + // Use try-with-resources for proper connection management + try (Connection connection = dataSource.getConnection()) { + LOGGER.info(()->String.format("Successfully obtained connection: %s", connection.getClass().getSimpleName())); + + // Verify connection is valid + if (connection.isValid(5)) { + LOGGER.info(()->"Connection is valid"); + } else { + LOGGER.warning(()->"Connection is not valid"); + } + + } catch (SQLException e) { + LOGGER.severe(()->String.format("Failed to get or use connection %s", e.getMessage())); + } + } + + /** + * Demonstrates error handling patterns. + */ + private static void demonstrateErrorHandling(EncryptingDataSource dataSource) { + LOGGER.info(()->"=== Demonstrating Error Handling ==="); + + // Attempt to get multiple connections to test resilience + for (int i = 0; i < 3; i++) { + try (Connection connection = dataSource.getConnection()) { + int finalI = i; + LOGGER.info(()->String.format("Connection attempt %d: Success", finalI + 1)); + + // Simulate some work + Thread.sleep(100); + + } catch (SQLException e) { + int finalI1 = i; + LOGGER.severe(()->String.format("Connection attempt %s failed: %s", finalI1 + 1, e.getMessage())); + + // Check if DataSource is still healthy + if (!dataSource.isConnectionAvailable()) { + LOGGER.severe("DataSource is no longer available - stopping attempts"); + break; + } + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } + } + + /** + * Demonstrates DataSource lifecycle management. + */ + private static void demonstrateLifecycleManagement(EncryptingDataSource dataSource) { + LOGGER.info("=== Demonstrating Lifecycle Management ==="); + + // Check initial state + LOGGER.info(()->String.format("DataSource closed: %s", dataSource.isClosed())); + LOGGER.info(()->String.format("Connection available: %s", dataSource.isConnectionAvailable())); + + // Get a connection before closing + try (Connection connection = dataSource.getConnection()) { + LOGGER.info(()->String.format("Got connection before close: %s", connection.getClass().getSimpleName())); + } catch (SQLException e) { + LOGGER.severe(()->String.format("Failed to get connection before close %s", e.getMessage())); + } + + // Close the DataSource + dataSource.close(); + LOGGER.info(()->String.format("DataSource closed: %s", dataSource.isClosed())); + LOGGER.info(()->String.format("Connection available after close: %s", dataSource.isConnectionAvailable())); + + // Try to get connection after close (should fail) + try (Connection connection = dataSource.getConnection()) { + LOGGER.severe(()->"Unexpectedly got connection after close!"); + } catch (SQLException e) { + LOGGER.info(()->String.format("Expected failure getting connection after close: %s", e.getMessage())); + } + + // Multiple close calls should be safe + dataSource.close(); + dataSource.close(); + LOGGER.info(()->"Multiple close calls completed safely"); + } + + /** + * Demonstrates connection validation and recovery patterns. + * + * @param originalDataSource Original data source to wrap + */ + public static void demonstrateConnectionRecovery(DataSource originalDataSource) { + LOGGER.info(()->"=== Demonstrating Connection Recovery ==="); + + EncryptingDataSource dataSource = null; + + try { + // Wrap the original DataSource + dataSource = EncryptingDataSourceFactory.createWithDefaults( + originalDataSource, + "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012", + "us-east-1" + ); + + // Implement retry logic for connection failures + Connection connection = getConnectionWithRetry(dataSource, 3, 1000); + + if (connection != null) { + try (Connection conn = connection) { + LOGGER.info(()->"Successfully recovered connection"); + } + } else { + LOGGER.severe(()->"Failed to recover connection after retries"); + } + + } catch (SQLException e) { + LOGGER.severe(()->String.format("Connection recovery demonstration failed %s", e.getMessage())); + } finally { + if (dataSource != null) { + dataSource.close(); + } + } + } + + /** + * Attempts to get a connection with retry logic. + */ + private static Connection getConnectionWithRetry(EncryptingDataSource dataSource, int maxRetries, long delayMs) { + for (int attempt = 1; attempt <= maxRetries; attempt++) { + int finalAttempt = attempt; + try { + LOGGER.info(()->String.format("Connection attempt %s of %s", finalAttempt, maxRetries)); + + if (!dataSource.isConnectionAvailable()) { + LOGGER.warning(()->String.format("DataSource not available on attempt %s", finalAttempt)); + Thread.sleep(delayMs); + continue; + } + + Connection connection = dataSource.getConnection(); + LOGGER.info(()->String.format("Successfully got connection on attempt %s", finalAttempt)); + return connection; + + } catch (SQLException e) { + LOGGER.warning(()->String.format("Connection attempt %s failed: %s", finalAttempt, e.getMessage())); + + if (attempt < maxRetries) { + try { + Thread.sleep(delayMs); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + break; + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } + + return null; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/PropertiesFileExample.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/PropertiesFileExample.java new file mode 100644 index 000000000..78e037868 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/example/PropertiesFileExample.java @@ -0,0 +1,170 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.example; + +import software.amazon.jdbc.factory.EncryptingDataSourceFactory; +import java.util.logging.Logger; +import software.amazon.jdbc.plugin.encryption.wrapper.EncryptingDataSource; + +import java.io.IOException; +import java.io.InputStream; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Properties; + +/** + * Example demonstrating how to use the encryption functionality with a properties file. + */ +public class PropertiesFileExample { + + private static final Logger LOGGER = Logger.getLogger(PropertiesFileExample.class.getName()); + + public static void main(String[] args) { + try { + // Load properties from file + Properties properties = loadPropertiesFromFile("example-jdbc-wrapper.properties"); + + // Create EncryptingDataSource using the properties + EncryptingDataSource dataSource = createDataSourceFromProperties(properties); + + // Use the DataSource + demonstrateEncryptedOperations(dataSource); + + // Clean up + dataSource.close(); + + } catch (Exception e) { + LOGGER.severe(()->String.format("Example execution failed %s", e.getMessage())); + } + } + + /** + * Loads properties from a file in the classpath. + */ + private static Properties loadPropertiesFromFile(String filename) throws IOException { + Properties properties = new Properties(); + + try (InputStream inputStream = PropertiesFileExample.class.getClassLoader() + .getResourceAsStream(filename)) { + + if (inputStream == null) { + throw new IOException("Properties file not found: " + filename); + } + + properties.load(inputStream); + LOGGER.info(()->String.format("Loaded properties from file: %s", filename)); + } + + return properties; + } + + /** + * Creates an EncryptingDataSource from properties. + */ + private static EncryptingDataSource createDataSourceFromProperties(Properties properties) throws SQLException { + String jdbcUrl = properties.getProperty("jdbcUrl"); + String username = properties.getProperty("username"); + String password = properties.getProperty("password"); + + if (jdbcUrl == null || username == null || password == null) { + throw new SQLException("Missing required database connection properties"); + } + + LOGGER.info(()->String.format("Creating EncryptingDataSource for URL: %s", jdbcUrl)); + + return EncryptingDataSourceFactory.createWithAwsWrapper(jdbcUrl, username, password, properties); + } + + /** + * Demonstrates encrypted database operations. + */ + private static void demonstrateEncryptedOperations(EncryptingDataSource dataSource) throws SQLException { + LOGGER.info(()->"Demonstrating encrypted database operations"); + + try (Connection connection = dataSource.getConnection()) { + + // Create test table + createTestTable(connection); + + // Insert encrypted data + insertTestData(connection); + + // Query and decrypt data + queryTestData(connection); + + LOGGER.info("Encrypted operations completed successfully"); + } + } + + /** + * Creates a test table for demonstration. + */ + private static void createTestTable(Connection connection) throws SQLException { + String createTableSql = "CREATE TABLE IF NOT EXISTS test_users (" + + "id SERIAL PRIMARY KEY, " + + "name VARCHAR(100) NOT NULL, " + + "email VARCHAR(100), " + + "ssn VARCHAR(20), " + + "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP" + + ")"; + + try (PreparedStatement stmt = connection.prepareStatement(createTableSql)) { + stmt.executeUpdate(); + LOGGER.finest(()->"Test table created or already exists"); + } + } + + /** + * Inserts test data that will be automatically encrypted for configured columns. + */ + private static void insertTestData(Connection connection) throws SQLException { + String insertSql = "INSERT INTO test_users (name, email, ssn) VALUES (?, ?, ?)"; + + try (PreparedStatement stmt = connection.prepareStatement(insertSql)) { + // Insert test user + stmt.setString(1, "Jane Doe"); + stmt.setString(2, "jane.doe@example.com"); // Will be encrypted if configured + stmt.setString(3, "987-65-4321"); // Will be encrypted if configured + stmt.executeUpdate(); + + LOGGER.info("Inserted test data with automatic encryption"); + } + } + + /** + * Queries test data that will be automatically decrypted for configured columns. + */ + private static void queryTestData(Connection connection) throws SQLException { + String selectSql = "SELECT id, name, email, ssn FROM test_users ORDER BY id DESC LIMIT 1"; + + try (PreparedStatement stmt = connection.prepareStatement(selectSql); + ResultSet rs = stmt.executeQuery()) { + + if (rs.next()) { + int id = rs.getInt("id"); + String name = rs.getString("name"); + String email = rs.getString("email"); // Will be decrypted if configured + String ssn = rs.getString("ssn"); // Will be decrypted if configured + + LOGGER.info(()->String.format("Retrieved user %s: Name=%s, Email=%s, SSN=%s", id, name, email, ssn)); + } + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/exception/IndependentConnectionException.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/exception/IndependentConnectionException.java new file mode 100644 index 000000000..7f44da22e --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/exception/IndependentConnectionException.java @@ -0,0 +1,217 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.exception; + +import software.amazon.jdbc.plugin.encryption.model.ConnectionParameters; + +import java.sql.SQLException; + +/** + * Exception thrown when independent connection creation fails. + * This exception provides detailed context about the connection creation failure, + * including the connection parameters that were attempted. + */ +public class IndependentConnectionException extends SQLException { + + private final ConnectionParameters attemptedParameters; + private final String connectionAttempt; + private final String failureReason; + + /** + * Creates a new IndependentConnectionException with a message and connection parameters. + * + * @param message the detailed error message + * @param attemptedParameters the connection parameters that failed to create a connection + */ + public IndependentConnectionException(String message, ConnectionParameters attemptedParameters) { + super(formatMessage(message, attemptedParameters, null)); + this.attemptedParameters = attemptedParameters; + this.connectionAttempt = null; + this.failureReason = null; + } + + /** + * Creates a new IndependentConnectionException with a message, cause, and connection parameters. + * + * @param message the detailed error message + * @param cause the underlying cause of the connection failure + * @param attemptedParameters the connection parameters that failed to create a connection + */ + public IndependentConnectionException(String message, Throwable cause, ConnectionParameters attemptedParameters) { + super(formatMessage(message, attemptedParameters, cause), cause); + this.attemptedParameters = attemptedParameters; + this.connectionAttempt = null; + this.failureReason = null; + } + + /** + * Creates a new IndependentConnectionException with detailed context. + * + * @param message the detailed error message + * @param attemptedParameters the connection parameters that failed to create a connection + * @param connectionAttempt description of what connection creation was attempted + * @param failureReason specific reason for the connection failure + */ + public IndependentConnectionException(String message, ConnectionParameters attemptedParameters, + String connectionAttempt, String failureReason) { + super(formatMessage(message, attemptedParameters, null, connectionAttempt, failureReason)); + this.attemptedParameters = attemptedParameters; + this.connectionAttempt = connectionAttempt; + this.failureReason = failureReason; + } + + /** + * Creates a new IndependentConnectionException with detailed context and cause. + * + * @param message the detailed error message + * @param cause the underlying cause of the connection failure + * @param attemptedParameters the connection parameters that failed to create a connection + * @param connectionAttempt description of what connection creation was attempted + * @param failureReason specific reason for the connection failure + */ + public IndependentConnectionException(String message, Throwable cause, ConnectionParameters attemptedParameters, + String connectionAttempt, String failureReason) { + super(formatMessage(message, attemptedParameters, cause, connectionAttempt, failureReason), cause); + this.attemptedParameters = attemptedParameters; + this.connectionAttempt = connectionAttempt; + this.failureReason = failureReason; + } + + /** + * Gets the connection parameters that failed to create a connection. + * + * @return the attempted connection parameters + */ + public ConnectionParameters getAttemptedParameters() { + return attemptedParameters; + } + + /** + * Gets the description of what connection creation was attempted. + * + * @return the connection attempt description, or null if not provided + */ + public String getConnectionAttempt() { + return connectionAttempt; + } + + /** + * Gets the specific reason for the connection failure. + * + * @return the failure reason, or null if not provided + */ + public String getFailureReason() { + return failureReason; + } + + /** + * Formats the error message with connection parameters and cause information. + */ + private static String formatMessage(String message, ConnectionParameters attemptedParameters, Throwable cause) { + StringBuilder sb = new StringBuilder(); + sb.append("Independent connection creation failed"); + + if (message != null && !message.isEmpty()) { + sb.append(" - ").append(message); + } + + if (attemptedParameters != null) { + sb.append(" (attempted URL: "); + String jdbcUrl = attemptedParameters.getJdbcUrl(); + if (jdbcUrl != null) { + // Mask sensitive information in URL + sb.append(maskSensitiveUrl(jdbcUrl)); + } else { + sb.append("null"); + } + sb.append(")"); + } + + if (cause != null) { + sb.append(" (caused by: ").append(cause.getClass().getSimpleName()); + if (cause.getMessage() != null) { + sb.append(": ").append(cause.getMessage()); + } + sb.append(")"); + } + + return sb.toString(); + } + + /** + * Formats the error message with detailed context information. + */ + private static String formatMessage(String message, ConnectionParameters attemptedParameters, Throwable cause, + String connectionAttempt, String failureReason) { + StringBuilder sb = new StringBuilder(); + sb.append("Independent connection creation failed"); + + if (connectionAttempt != null && !connectionAttempt.isEmpty()) { + sb.append(" while attempting: ").append(connectionAttempt); + } + + if (message != null && !message.isEmpty()) { + sb.append(" - ").append(message); + } + + if (failureReason != null && !failureReason.isEmpty()) { + sb.append(" (reason: ").append(failureReason).append(")"); + } + + if (attemptedParameters != null) { + sb.append(" (attempted URL: "); + String jdbcUrl = attemptedParameters.getJdbcUrl(); + if (jdbcUrl != null) { + sb.append(maskSensitiveUrl(jdbcUrl)); + } else { + sb.append("null"); + } + sb.append(")"); + } + + if (cause != null) { + sb.append(" (caused by: ").append(cause.getClass().getSimpleName()); + if (cause.getMessage() != null) { + sb.append(": ").append(cause.getMessage()); + } + sb.append(")"); + } + + return sb.toString(); + } + + /** + * Masks sensitive information in JDBC URLs for logging purposes. + * Removes passwords and other sensitive parameters while preserving + * useful debugging information. + */ + private static String maskSensitiveUrl(String jdbcUrl) { + if (jdbcUrl == null) { + return null; + } + + // Remove password parameters from URL + String masked = jdbcUrl.replaceAll("([?&]password=)[^&]*", "$1***"); + masked = masked.replaceAll("([?&]pwd=)[^&]*", "$1***"); + + // Remove user credentials from URL if present + masked = masked.replaceAll("://[^:/@]+:[^@]*@", "://***:***@"); + + return masked; + } +} \ No newline at end of file diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/factory/IndependentDataSource.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/factory/IndependentDataSource.java new file mode 100644 index 000000000..d73aebff9 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/factory/IndependentDataSource.java @@ -0,0 +1,358 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.factory; + +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.plugin.encryption.logging.ErrorContext; +import java.util.logging.Logger; +import org.slf4j.MDC; + +import javax.sql.DataSource; +import java.io.PrintWriter; +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicLong; + +/** + * DataSource implementation that creates independent connections using PluginService. + * This ensures that MetadataManager gets its own connections and doesn't share with client applications. + */ +public class IndependentDataSource implements DataSource { + + private static final Logger LOGGER = Logger.getLogger(IndependentDataSource.class.getName()); + + private final PluginService pluginService; + private final Properties connectionProperties; + private int loginTimeout = 0; + private PrintWriter logWriter; + + // Connection monitoring metrics + private final AtomicLong connectionRequestCount = new AtomicLong(0); + private final AtomicLong successfulConnectionCount = new AtomicLong(0); + private final AtomicLong failedConnectionCount = new AtomicLong(0); + private volatile long lastSuccessfulConnectionTime = 0; + private volatile long lastFailedConnectionTime = 0; + + /** + * Creates an IndependentDataSource with the given PluginService. + * + * @param pluginService the PluginService to use for creating connections + * @throws IllegalArgumentException if pluginService is null + */ + public IndependentDataSource(PluginService pluginService) { + this(pluginService, new Properties()); + } + + /** + * Creates an IndependentDataSource with PluginService and connection properties. + * + * @param pluginService the PluginService to use for creating connections + * @param connectionProperties additional connection properties + * @throws IllegalArgumentException if pluginService is null + */ + public IndependentDataSource(PluginService pluginService, Properties connectionProperties) { + if (pluginService == null) { + throw new IllegalArgumentException("PluginService cannot be null"); + } + + this.pluginService = pluginService; + this.connectionProperties = connectionProperties != null ? connectionProperties : new Properties(); + + LOGGER.info(()->"Created IndependentDataSource with PluginService"); + LOGGER.finest(()->String.format("IndependentDataSource configuration: PropertiesCount=%s", + this.connectionProperties.size())); + } + + @Override + public Connection getConnection() throws SQLException { + long requestId = connectionRequestCount.incrementAndGet(); + + MDC.put("operation", "GET_INDEPENDENT_CONNECTION"); + MDC.put("requestId", String.valueOf(requestId)); + + try { + LOGGER.finest(()->String.format("Connection request #%s - creating new independent connection via PluginService", requestId)); + return createNewConnection(); + } finally { + MDC.remove("operation"); + MDC.remove("requestId"); + } + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + long requestId = connectionRequestCount.incrementAndGet(); + + MDC.put("operation", "GET_INDEPENDENT_CONNECTION_WITH_CREDENTIALS"); + MDC.put("requestId", String.valueOf(requestId)); + + try { + LOGGER.finest(()->String.format("Connection request #%s - creating new independent connection with provided credentials", requestId)); + + // Create modified properties with the provided credentials + Properties modifiedProps = new Properties(connectionProperties); + modifiedProps.setProperty("user", username); + modifiedProps.setProperty("password", password); + + return createNewConnection(modifiedProps); + } finally { + MDC.remove("operation"); + MDC.remove("requestId"); + } + } + + /** + * Creates a new independent connection using the PluginService. + * + * @return a new database connection + * @throws SQLException if connection creation fails + */ + private Connection createNewConnection() throws SQLException { + return createNewConnection(connectionProperties); + } + + /** + * Creates a new independent connection using the PluginService with specified properties. + * + * @param props the connection properties to use + * @return a new database connection + * @throws SQLException if connection creation fails + */ + private Connection createNewConnection(Properties props) throws SQLException { + long startTime = System.currentTimeMillis(); + + LOGGER.finest(()->"Creating new independent connection via PluginService"); + + try { + // Get current host spec from PluginService + HostSpec hostSpec = pluginService.getCurrentHostSpec(); + + // Create connection using PluginService + Connection connection = pluginService.forceConnect(hostSpec, props); + + long duration = System.currentTimeMillis() - startTime; + successfulConnectionCount.incrementAndGet(); + lastSuccessfulConnectionTime = System.currentTimeMillis(); + + LOGGER.info(()->String.format("Successfully created independent connection via PluginService in %sms " + + "(total successful: %s, total failed: %s)", + duration, successfulConnectionCount.get(), failedConnectionCount.get())); + + return connection; + + } catch (SQLException e) { + long duration = System.currentTimeMillis() - startTime; + failedConnectionCount.incrementAndGet(); + lastFailedConnectionTime = System.currentTimeMillis(); + + LOGGER.severe(()->String.format("Failed to create independent connection via PluginService after %sms: %s " + + "(total successful: %d, total failed: %d)", + duration, e.getMessage(), + successfulConnectionCount.get(), failedConnectionCount.get())); + + // Create detailed error context for troubleshooting + String errorDetails = ErrorContext.builder() + .operation("CREATE_INDEPENDENT_CONNECTION_VIA_PLUGIN_SERVICE") + .buildMessage("Connection creation failed: " + e.getMessage()); + + LOGGER.severe(()->String.format("Connection creation error details: %s", errorDetails)); + + throw new SQLException( + "Failed to create independent connection via PluginService: " + e.getMessage(), + e + ); + } + } + + /** + * Validates that a connection can be created with the current PluginService. + * + * @return true if a connection can be created, false otherwise + */ + public boolean validateConnection() { + try (Connection conn = getConnection()) { + return conn != null && !conn.isClosed(); + } catch (SQLException e) { + LOGGER.finest(()->String.format("Connection validation failed", e)); + return false; + } + } + + /** + * Gets the PluginService used by this DataSource. + * + * @return the PluginService + */ + public PluginService getPluginService() { + return pluginService; + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface.isInstance(this)) { + return iface.cast(this); + } + throw new SQLException("Cannot unwrap to " + iface.getName()); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return iface.isInstance(this); + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + return logWriter; + } + + @Override + public void setLogWriter(PrintWriter out) throws SQLException { + this.logWriter = out; + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + this.loginTimeout = seconds; + } + + @Override + public int getLoginTimeout() throws SQLException { + return loginTimeout; + } + + @Override + public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException { + throw new SQLFeatureNotSupportedException("getParentLogger is not supported"); + } + + // Connection monitoring and metrics methods + + /** + * Gets the total number of connection requests made to this DataSource. + * + * @return the total connection request count + */ + public long getConnectionRequestCount() { + return connectionRequestCount.get(); + } + + /** + * Gets the number of successful connection creations. + * + * @return the successful connection count + */ + public long getSuccessfulConnectionCount() { + return successfulConnectionCount.get(); + } + + /** + * Gets the number of failed connection creation attempts. + * + * @return the failed connection count + */ + public long getFailedConnectionCount() { + return failedConnectionCount.get(); + } + + /** + * Gets the timestamp of the last successful connection creation. + * + * @return the timestamp in milliseconds, or 0 if no successful connections + */ + public long getLastSuccessfulConnectionTime() { + return lastSuccessfulConnectionTime; + } + + /** + * Gets the timestamp of the last failed connection attempt. + * + * @return the timestamp in milliseconds, or 0 if no failed connections + */ + public long getLastFailedConnectionTime() { + return lastFailedConnectionTime; + } + + /** + * Calculates the connection success rate as a percentage. + * + * @return the success rate (0.0 to 1.0), or 1.0 if no attempts have been made + */ + public double getConnectionSuccessRate() { + long total = connectionRequestCount.get(); + if (total == 0) return 1.0; + + return (double) successfulConnectionCount.get() / total; + } + + /** + * Checks if the DataSource is currently healthy based on recent connection attempts. + * + * @return true if the DataSource appears healthy, false otherwise + */ + public boolean isHealthy() { + // Consider healthy if success rate is above 80% or if we haven't had failures recently + double successRate = getConnectionSuccessRate(); + long timeSinceLastFailure = System.currentTimeMillis() - lastFailedConnectionTime; + + return successRate >= 0.8 || (lastFailedConnectionTime == 0) || (timeSinceLastFailure > 300000); // 5 minutes + } + + /** + * Gets a comprehensive status message about the DataSource health and metrics. + * + * @return a detailed status message + */ + public String getHealthStatus() { + StringBuilder sb = new StringBuilder(); + + sb.append("IndependentDataSource Status: "); + sb.append("Healthy=").append(isHealthy()).append(", "); + sb.append("Requests=").append(connectionRequestCount.get()).append(", "); + sb.append("Successful=").append(successfulConnectionCount.get()).append(", "); + sb.append("Failed=").append(failedConnectionCount.get()).append(", "); + sb.append("SuccessRate=").append(String.format("%.2f%%", getConnectionSuccessRate() * 100)); + + if (lastSuccessfulConnectionTime > 0) { + long timeSinceSuccess = System.currentTimeMillis() - lastSuccessfulConnectionTime; + sb.append(", LastSuccess=").append(timeSinceSuccess).append("ms ago"); + } + + if (lastFailedConnectionTime > 0) { + long timeSinceFailure = System.currentTimeMillis() - lastFailedConnectionTime; + sb.append(", LastFailure=").append(timeSinceFailure).append("ms ago"); + } + + return sb.toString(); + } + + /** + * Logs the current health status and metrics. + */ + public void logHealthStatus() { + String status = getHealthStatus(); + + if (isHealthy()) { + LOGGER.info(()->String.format("IndependentDataSource health check: %s", status)); + } else { + LOGGER.warning(()->String.format("IndependentDataSource health check - UNHEALTHY: %s", status)); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementExample.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementExample.java new file mode 100644 index 000000000..5142b9ed4 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementExample.java @@ -0,0 +1,204 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.key; + +import software.amazon.jdbc.plugin.encryption.metadata.MetadataManager; +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; +import java.util.logging.Logger; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.kms.KmsClient; + +import javax.sql.DataSource; +import java.time.Duration; +import java.util.List; + +/** + * Example demonstrating how to use the KeyManagementUtility for administrative tasks. + * This class shows typical workflows for setting up and managing encryption keys. + */ +public class KeyManagementExample { + + private static final Logger LOGGER = Logger.getLogger(KeyManagementExample.class.getName()); + + private final KeyManagementUtility keyManagementUtility; + + public KeyManagementExample(DataSource dataSource, KmsClient kmsClient) { + // Create encryption configuration + EncryptionConfig config = EncryptionConfig.builder() + .kmsRegion("us-east-1") + .defaultMasterKeyArn("arn:aws:kms:us-east-1:123456789012:key/default-key") + .cacheEnabled(true) + .cacheExpirationMinutes(30) + .maxRetries(3) + .retryBackoffBase(Duration.ofMillis(100)) + .build(); + + // Create managers + KeyManager keyManager = null; //new KeyManager(kmsClient, dataSource, config); + MetadataManager metadataManager = null; //new MetadataManager(dataSource, config); + + // Create utility + this.keyManagementUtility = new KeyManagementUtility( + keyManager, metadataManager, dataSource, kmsClient, config); + } + + /** + * Example: Setting up encryption for a new application. + * + * @throws KeyManagementException if key management operations fail + */ + public void setupNewApplication() throws KeyManagementException { + LOGGER.info("Setting up encryption for new application"); + + // 1. Create a master key for the application + String masterKeyArn = keyManagementUtility.createMasterKeyWithPermissions( + "JDBC Encryption Master Key for MyApp"); + + LOGGER.info(()->String.format("Created master key: %s", masterKeyArn)); + + // 2. Initialize encryption for sensitive columns + String userEmailKeyId = keyManagementUtility.initializeEncryptionForColumn( + "users", "email", masterKeyArn); + + String userSsnKeyId = keyManagementUtility.initializeEncryptionForColumn( + "users", "ssn", masterKeyArn); + + String orderCreditCardKeyId = keyManagementUtility.initializeEncryptionForColumn( + "orders", "credit_card_number", masterKeyArn); + + LOGGER.info(()->String.format("Initialized encryption for users.email with key: %s", userEmailKeyId)); + LOGGER.info(()->String.format("Initialized encryption for users.ssn with key: %s", userSsnKeyId)); + LOGGER.info(()->String.format("Initialized encryption for orders.credit_card_number with key: %s", orderCreditCardKeyId)); + } + + /** + * Example: Adding encryption to an existing column. + * + * @throws KeyManagementException if key management operations fail + */ + public void addEncryptionToExistingColumn() throws KeyManagementException { + LOGGER.info("Adding encryption to existing column"); + + String masterKeyArn = "arn:aws:kms:us-east-1:123456789012:key/existing-master-key"; + + // Validate the master key first + if (!keyManagementUtility.validateMasterKey(masterKeyArn)) { + throw new KeyManagementException("Master key is not valid or accessible: " + masterKeyArn); + } + + // Initialize encryption for the column + String keyId = keyManagementUtility.initializeEncryptionForColumn( + "customers", "phone_number", masterKeyArn, "AES-256-GCM"); + + LOGGER.info(()->String.format("Added encryption to customers.phone_number with key: %s", keyId)); + } + + /** + * Example: Rotating keys for security compliance. + * + * @throws KeyManagementException if key management operations fail + */ + public void performKeyRotation() throws KeyManagementException { + LOGGER.info("Performing key rotation for security compliance"); + + // Rotate key for a specific column + String newKeyId = keyManagementUtility.rotateDataKey("users", "ssn", null); + LOGGER.info(()->String.format("Rotated key for users.ssn, new key ID: %s", newKeyId)); + + // Rotate with a new master key + String newMasterKeyArn = keyManagementUtility.createMasterKeyWithPermissions( + "New Master Key for Enhanced Security"); + + String newKeyIdWithNewMaster = keyManagementUtility.rotateDataKey( + "orders", "credit_card_number", newMasterKeyArn); + + LOGGER.info(()->String.format("Rotated key for orders.credit_card_number with new master key, new key ID: %s", + newKeyIdWithNewMaster)); + } + + /** + * Example: Auditing and managing existing keys. + * + * @throws KeyManagementException if key management operations fail + */ + public void auditExistingKeys() throws KeyManagementException { + LOGGER.info("Auditing existing encryption keys"); + + // Find all columns using a specific key + String keyIdToAudit = "some-existing-key-id"; + List columnsUsingKey = keyManagementUtility.getColumnsUsingKey(keyIdToAudit); + + LOGGER.info(()->String.format("Key %s is used by %s columns: %s", + keyIdToAudit, columnsUsingKey.size(), columnsUsingKey)); + + // Validate all master keys are still accessible + String[] masterKeysToValidate = { + "arn:aws:kms:us-east-1:123456789012:key/key1", + "arn:aws:kms:us-east-1:123456789012:key/key2", + "arn:aws:kms:us-east-1:123456789012:key/key3" + }; + + for (String masterKeyArn : masterKeysToValidate) { + boolean isValid = keyManagementUtility.validateMasterKey(masterKeyArn); + LOGGER.info(()->String.format("Master key %s validation: %s", masterKeyArn, isValid ? "VALID" : "INVALID")); + } + } + + /** + * Example: Removing encryption from a column (for decommissioning). + * + * @throws KeyManagementException if key management operations fail + */ + public void removeEncryptionFromColumn() throws KeyManagementException { + LOGGER.info("Removing encryption from decommissioned column"); + + // Remove encryption configuration (keys remain for data recovery) + keyManagementUtility.removeEncryptionForColumn("old_table", "deprecated_column"); + + LOGGER.info("Removed encryption configuration for old_table.deprecated_column"); + } + + /** + * Main method demonstrating the complete workflow. + * + * @param args Command line arguments + */ + public static void main(String[] args) { + try { + // In a real application, you would configure these properly + DataSource dataSource = null; // Configure your DataSource + KmsClient kmsClient = KmsClient.builder() + .region(Region.US_EAST_1) + .build(); + + KeyManagementExample example = new KeyManagementExample(dataSource, kmsClient); + + // Run examples (commented out since we don't have real connections) + // example.setupNewApplication(); + // example.addEncryptionToExistingColumn(); + // example.performKeyRotation(); + // example.auditExistingKeys(); + // example.removeEncryptionFromColumn(); + + LOGGER.info("Key management examples completed successfully"); + + } catch (Exception e) { + LOGGER.severe(()->String.format("Error running key management examples", e)); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementException.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementException.java new file mode 100644 index 000000000..1419cf2c9 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementException.java @@ -0,0 +1,272 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.key; + +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +/** + * Exception thrown when key management operations fail. + * Extends SQLException to integrate with JDBC error handling. + * Provides enhanced error context information for better troubleshooting. + */ +public class KeyManagementException extends SQLException { + + private static final long serialVersionUID = 1L; + + // SQL State codes for different key management error types + public static final String KEY_CREATION_FAILED_STATE = "KEY01"; + public static final String KEY_RETRIEVAL_FAILED_STATE = "KEY02"; + public static final String KEY_DECRYPTION_FAILED_STATE = "KEY03"; + public static final String KEY_STORAGE_FAILED_STATE = "KEY04"; + public static final String KMS_CONNECTION_FAILED_STATE = "KEY05"; + public static final String INVALID_KEY_METADATA_STATE = "KEY06"; + + private final Map errorContext = new HashMap<>(); + + /** + * Constructs a KeyManagementException with the specified detail message. + * + * @param message the detail message + */ + public KeyManagementException(String message) { + super(message, KEY_RETRIEVAL_FAILED_STATE); + } + + /** + * Constructs a KeyManagementException with the specified detail message and cause. + * + * @param message the detail message + * @param cause the cause of this exception + */ + public KeyManagementException(String message, Throwable cause) { + super(message, KEY_RETRIEVAL_FAILED_STATE, cause); + } + + /** + * Constructs a KeyManagementException with the specified cause. + * + * @param cause the cause of this exception + */ + public KeyManagementException(Throwable cause) { + super(cause.getMessage(), KEY_RETRIEVAL_FAILED_STATE, cause); + } + + /** + * Constructs a KeyManagementException with the specified detail message, + * SQL state, and vendor code. + * + * @param message the detail message + * @param sqlState the SQL state + * @param vendorCode the vendor-specific error code + */ + public KeyManagementException(String message, String sqlState, int vendorCode) { + super(message, sqlState, vendorCode); + } + + /** + * Constructs a KeyManagementException with the specified detail message, + * SQL state, vendor code, and cause. + * + * @param message the detail message + * @param sqlState the SQL state + * @param vendorCode the vendor-specific error code + * @param cause the cause of this exception + */ + public KeyManagementException(String message, String sqlState, int vendorCode, Throwable cause) { + super(message, sqlState, vendorCode, cause); + } + + /** + * Constructs a KeyManagementException with the specified detail message, SQL state and cause. + * + * @param message the detail message + * @param sqlState the SQL state + * @param cause the cause of this exception + */ + public KeyManagementException(String message, String sqlState, Throwable cause) { + super(message, sqlState, cause); + } + + /** + * Adds context information to the exception. + * + * @param key the context key + * @param value the context value + * @return this exception for method chaining + */ + public KeyManagementException withContext(String key, Object value) { + errorContext.put(key, value); + return this; + } + + /** + * Adds key ID to the error context (sanitized). + * + * @param keyId the key ID + * @return this exception for method chaining + */ + public KeyManagementException withKeyId(String keyId) { + return withContext("keyId", sanitizeKeyId(keyId)); + } + + /** + * Adds master key ARN to the error context (sanitized). + * + * @param masterKeyArn the master key ARN + * @return this exception for method chaining + */ + public KeyManagementException withMasterKeyArn(String masterKeyArn) { + return withContext("masterKeyArn", sanitizeArn(masterKeyArn)); + } + + /** + * Adds operation type to the error context. + * + * @param operation the operation being performed + * @return this exception for method chaining + */ + public KeyManagementException withOperation(String operation) { + return withContext("operation", operation); + } + + /** + * Adds retry attempt information to the error context. + * + * @param attempt the current attempt number + * @param maxAttempts the maximum number of attempts + * @return this exception for method chaining + */ + public KeyManagementException withRetryInfo(int attempt, int maxAttempts) { + return withContext("retryAttempt", attempt).withContext("maxRetryAttempts", maxAttempts); + } + + /** + * Gets the error context map. + * + * @return a copy of the error context + */ + public Map getErrorContext() { + return new HashMap<>(errorContext); + } + + /** + * Gets a formatted error message including context information. + * + * @return formatted error message with context + */ + public String getDetailedMessage() { + if (errorContext.isEmpty()) { + return getMessage(); + } + + StringBuilder sb = new StringBuilder(getMessage()); + sb.append(" [Context: "); + + boolean first = true; + for (Map.Entry entry : errorContext.entrySet()) { + if (!first) { + sb.append(", "); + } + sb.append(entry.getKey()).append("=").append(entry.getValue()); + first = false; + } + + sb.append("]"); + return sb.toString(); + } + + /** + * Creates a KeyManagementException for key creation failures. + * + * @param message Error message + * @param cause Root cause + * @return KeyManagementException instance + */ + public static KeyManagementException keyCreationFailed(String message, Throwable cause) { + return new KeyManagementException(message, KEY_CREATION_FAILED_STATE, cause); + } + + /** + * Creates a KeyManagementException for key decryption failures. + * + * @param keyId Key ID + * @param masterKeyArn Master key ARN + * @param cause Root cause + * @return KeyManagementException instance + */ + public static KeyManagementException keyDecryptionFailed(String keyId, String masterKeyArn, Throwable cause) { + return new KeyManagementException("Failed to decrypt data key", KEY_DECRYPTION_FAILED_STATE, cause) + .withKeyId(keyId) + .withMasterKeyArn(masterKeyArn); + } + + /** + * Creates a KeyManagementException for key storage failures. + * + * @param message Error message + * @param cause Root cause + * @return KeyManagementException instance + */ + public static KeyManagementException keyStorageFailed(String message, Throwable cause) { + return new KeyManagementException(message, KEY_STORAGE_FAILED_STATE, cause); + } + + /** + * Creates a KeyManagementException for KMS connection failures. + * + * @param message Error message + * @param cause Root cause + * @return KeyManagementException instance + */ + public static KeyManagementException kmsConnectionFailed(String message, Throwable cause) { + return new KeyManagementException(message, KMS_CONNECTION_FAILED_STATE, cause); + } + + /** + * Creates a KeyManagementException for invalid key metadata. + * + * @param message Error message + * @return New KeyManagementException instance + */ + public static KeyManagementException invalidKeyMetadata(String message) { + return new KeyManagementException(message, INVALID_KEY_METADATA_STATE, null); + } + + // Sanitization methods to prevent sensitive data exposure + + private String sanitizeKeyId(String keyId) { + if (keyId == null) return null; + // Show only first and last 4 characters of key ID + if (keyId.length() > 8) { + return keyId.substring(0, 4) + "***" + keyId.substring(keyId.length() - 4); + } + return "***"; + } + + private String sanitizeArn(String arn) { + if (arn == null) return null; + // Keep only the key ID part of the ARN + int lastSlash = arn.lastIndexOf('/'); + if (lastSlash != -1 && lastSlash < arn.length() - 1) { + return "arn:aws:kms:***:***:key/" + arn.substring(lastSlash + 1); + } + return "arn:aws:kms:***:***:key/***"; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementUtility.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementUtility.java new file mode 100644 index 000000000..6fabec4a8 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManagementUtility.java @@ -0,0 +1,476 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.key; + +import software.amazon.jdbc.plugin.encryption.metadata.MetadataException; +import software.amazon.jdbc.plugin.encryption.metadata.MetadataManager; +import software.amazon.jdbc.plugin.encryption.model.ColumnEncryptionConfig; +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; +import software.amazon.jdbc.plugin.encryption.model.KeyMetadata; +import java.util.logging.Logger; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.*; + +import javax.sql.DataSource; +import java.sql.*; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * Utility class providing administrative functions for key management operations. + * This class offers high-level methods for creating master keys, setting up encryption + * for tables/columns, rotating keys, and managing the encryption lifecycle. + */ +public class KeyManagementUtility { + + private static final Logger LOGGER = Logger.getLogger(KeyManagementUtility.class.getName()); + + private final KeyManager keyManager; + private final MetadataManager metadataManager; + private final DataSource dataSource; + private final KmsClient kmsClient; + private final EncryptionConfig config; + + public KeyManagementUtility(KeyManager keyManager, MetadataManager metadataManager, + DataSource dataSource, KmsClient kmsClient, EncryptionConfig config) { + this.keyManager = Objects.requireNonNull(keyManager, "KeyManager cannot be null"); + this.metadataManager = Objects.requireNonNull(metadataManager, "MetadataManager cannot be null"); + this.dataSource = Objects.requireNonNull(dataSource, "DataSource cannot be null"); + this.kmsClient = Objects.requireNonNull(kmsClient, "KmsClient cannot be null"); + this.config = Objects.requireNonNull(config, "EncryptionConfig cannot be null"); + } + + private String getInsertEncryptionMetadataSql() { + String schema = config.getEncryptionMetadataSchema(); + return "INSERT INTO " + schema + ".encryption_metadata (table_name, column_name, encryption_algorithm, key_id, created_at, updated_at) " + + "VALUES (?, ?, ?, ?, ?, ?) " + + "ON CONFLICT (table_name, column_name) DO UPDATE SET " + + "encryption_algorithm = EXCLUDED.encryption_algorithm, " + + "key_id = EXCLUDED.key_id, " + + "updated_at = EXCLUDED.updated_at"; + } + + private String getUpdateEncryptionMetadataKeySql() { + return "UPDATE " + config.getEncryptionMetadataSchema() + ".encryption_metadata SET key_id = ?, updated_at = ? " + + "WHERE table_name = ? AND column_name = ?"; + } + + private String getSelectColumnsWithKeySql() { + return "SELECT table_name, column_name FROM " + config.getEncryptionMetadataSchema() + ".encryption_metadata WHERE key_id = ?"; + } + + private String getDeleteEncryptionMetadataSql() { + return "DELETE FROM " + config.getEncryptionMetadataSchema() + ".encryption_metadata WHERE table_name = ? AND column_name = ?"; + } + + /** + * Creates a new KMS master key with proper permissions for encryption operations. + * + * @param description Description for the master key + * @param keyPolicy Optional key policy JSON string. If null, uses default policy + * @return The ARN of the created master key + * @throws KeyManagementException if key creation fails + */ + public String createMasterKeyWithPermissions(String description, String keyPolicy) + throws KeyManagementException { + Objects.requireNonNull(description, "Description cannot be null"); + + LOGGER.info(()->String.format("Creating KMS master key with permissions: %s", description)); + + try { + CreateKeyRequest.Builder requestBuilder = CreateKeyRequest.builder() + .description(description) + .keyUsage(KeyUsageType.ENCRYPT_DECRYPT) + .keySpec(KeySpec.SYMMETRIC_DEFAULT); + + // Add key policy if provided + if (keyPolicy != null && !keyPolicy.trim().isEmpty()) { + requestBuilder.policy(keyPolicy); + LOGGER.finest(()->"Using custom key policy for master key creation"); + } + + CreateKeyResponse response = kmsClient.createKey(requestBuilder.build()); + String keyArn = response.keyMetadata().arn(); + + // Create an alias for easier management + String aliasName = "alias/jdbc-encryption-" + System.currentTimeMillis(); + CreateAliasRequest aliasRequest = CreateAliasRequest.builder() + .aliasName(aliasName) + .targetKeyId(keyArn) + .build(); + + kmsClient.createAlias(aliasRequest); + + LOGGER.info(()->String.format("Successfully created KMS master key: %s with alias: %s", keyArn, aliasName)); + return keyArn; + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to create KMS master key with permissions", e.getMessage())); + throw new KeyManagementException("Failed to create KMS master key: " + e.getMessage(), e); + } + } + + /** + * Creates a master key with default permissions suitable for JDBC encryption. + * + * @param description Description for the master key + * @return The ARN of the created master key + * @throws KeyManagementException if key creation fails + */ + public String createMasterKeyWithPermissions(String description) throws KeyManagementException { + return createMasterKeyWithPermissions(description, null); + } + + /** + * Generates and stores a data key for the specified table and column. + * This method creates the complete encryption setup for a column. + * + * @param tableName Name of the table + * @param columnName Name of the column + * @param masterKeyArn ARN of the master key to use + * @param algorithm Encryption algorithm (defaults to AES-256-GCM if null) + * @return The generated key ID + * @throws KeyManagementException if key generation or storage fails + */ + public String generateAndStoreDataKey(String tableName, String columnName, + String masterKeyArn, String algorithm) + throws KeyManagementException { + Objects.requireNonNull(tableName, "Table name cannot be null"); + Objects.requireNonNull(columnName, "Column name cannot be null"); + Objects.requireNonNull(masterKeyArn, "Master key ARN cannot be null"); + + if (algorithm == null || algorithm.trim().isEmpty()) { + algorithm = "AES-256-GCM"; + } + + LOGGER.info(()->String.format("Generating and storing data key for %s.%s using master key: %s", + tableName, columnName, masterKeyArn)); + + try { + + // Generate the data key using KMS + KeyManager.DataKeyResult dataKeyResult = keyManager.generateDataKey(masterKeyArn); + + try { + // Generate a unique key name + String keyName = "key-" + tableName + "-" + columnName + "-" + System.currentTimeMillis(); + + // Create key metadata + KeyMetadata keyMetadata = KeyMetadata.builder() + .keyId("dummy") // Not used anymore but required by builder + .keyName(keyName) + .masterKeyArn(masterKeyArn) + .encryptedDataKey(dataKeyResult.getEncryptedKey()) + .keySpec("AES_256") + .createdAt(Instant.now()) + .lastUsedAt(Instant.now()) + .build(); + + // Store key metadata in database and get the generated integer ID + int generatedKeyId = keyManager.storeKeyMetadata(tableName, columnName, keyMetadata); + + // Store encryption metadata using the generated integer key ID + storeEncryptionMetadata(tableName, columnName, algorithm, generatedKeyId); + + // Refresh metadata cache + metadataManager.refreshMetadata(); + + LOGGER.info(()->String.format("Successfully generated and stored data key for %s.%s with key ID: %s", + tableName, columnName, generatedKeyId)); + + return String.valueOf(generatedKeyId); + + } finally { + // Clear sensitive data from memory + dataKeyResult.clearPlaintextKey(); + } + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to generate and store data key for %s.%s", tableName, columnName, e.getMessage())); + throw new KeyManagementException("Failed to generate and store data key: " + e.getMessage(), e); + } + } + + /** + * Rotates the data key for an existing encrypted column. + * This creates a new data key while preserving the existing encryption metadata. + * + * @param tableName Name of the table + * @param columnName Name of the column + * @param newMasterKeyArn Optional new master key ARN. If null, uses existing master key + * @return The new key ID + * @throws KeyManagementException if key rotation fails + */ + public String rotateDataKey(String tableName, String columnName, String newMasterKeyArn) + throws KeyManagementException { + Objects.requireNonNull(tableName, "Table name cannot be null"); + Objects.requireNonNull(columnName, "Column name cannot be null"); + + LOGGER.info(()->String.format("Rotating data key for %s.%s", tableName, columnName)); + + try { + // Get current encryption configuration + ColumnEncryptionConfig currentConfig = metadataManager.getColumnConfig(tableName, columnName); + if (currentConfig == null) { + throw new KeyManagementException("No encryption configuration found for " + tableName + "." + columnName); + } + + // Use existing master key if new one not provided + String masterKeyArn = newMasterKeyArn != null ? newMasterKeyArn : + currentConfig.getKeyMetadata().getMasterKeyArn(); + + // Generate new data key + String newKeyId = keyManager.generateKeyId(); + KeyManager.DataKeyResult dataKeyResult = keyManager.generateDataKey(masterKeyArn); + + try { + // Create new key metadata + KeyMetadata newKeyMetadata = KeyMetadata.builder() + .keyId(newKeyId) + .masterKeyArn(masterKeyArn) + .encryptedDataKey(dataKeyResult.getEncryptedKey()) + .keySpec("AES_256") + .createdAt(Instant.now()) + .lastUsedAt(Instant.now()) + .build(); + + // Store new key metadata + keyManager.storeKeyMetadata(tableName, columnName, newKeyMetadata); + + // Update encryption metadata to use new key + updateEncryptionMetadataKey(tableName, columnName, newKeyId); + + // Refresh metadata cache + metadataManager.refreshMetadata(); + + LOGGER.info(()->String.format("Successfully rotated data key for %s.%s from %s to %s", + tableName, columnName, currentConfig.getKeyId(), newKeyId)); + + return newKeyId; + + } finally { + dataKeyResult.clearPlaintextKey(); + } + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to rotate data key for %s.%s", tableName, columnName, e.getMessage())); + throw new KeyManagementException("Failed to rotate data key: " + e.getMessage(), e); + } + } + + /** + * Initializes encryption for a new table and column combination. + * This is a convenience method that creates everything needed for encryption. + * + * @param tableName Name of the table + * @param columnName Name of the column + * @param masterKeyArn ARN of the master key to use + * @return The generated key ID + * @throws KeyManagementException if initialization fails + */ + public String initializeEncryptionForColumn(String tableName, String columnName, String masterKeyArn) + throws KeyManagementException { + return initializeEncryptionForColumn(tableName, columnName, masterKeyArn, "AES-256-GCM"); + } + + /** + * Initializes encryption for a new table and column combination with specified algorithm. + * + * @param tableName Name of the table + * @param columnName Name of the column + * @param masterKeyArn ARN of the master key to use + * @param algorithm Encryption algorithm to use + * @return The generated key ID + * @throws KeyManagementException if initialization fails + */ + public String initializeEncryptionForColumn(String tableName, String columnName, + String masterKeyArn, String algorithm) + throws KeyManagementException { + LOGGER.info(()->String.format("Initializing encryption for column %s.%s", tableName, columnName)); + + // Check if column is already encrypted + try { + if (metadataManager.isColumnEncrypted(tableName, columnName)) { + throw new KeyManagementException("Column " + tableName + "." + columnName + " is already encrypted"); + } + } catch (MetadataException e) { + throw new KeyManagementException("Failed to check existing encryption status", e); + } + + // Generate and store the data key + return generateAndStoreDataKey(tableName, columnName, masterKeyArn, algorithm); + } + + /** + * Removes encryption configuration for a table and column. + * This does not delete the actual key data for security reasons. + * + * @param tableName Name of the table + * @param columnName Name of the column + * @throws KeyManagementException if removal fails + */ + public void removeEncryptionForColumn(String tableName, String columnName) + throws KeyManagementException { + Objects.requireNonNull(tableName, "Table name cannot be null"); + Objects.requireNonNull(columnName, "Column name cannot be null"); + + LOGGER.info(()->String.format("Removing encryption configuration for %s.%s", tableName, columnName)); + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(getDeleteEncryptionMetadataSql())) { + + stmt.setString(1, tableName); + stmt.setString(2, columnName); + + int rowsAffected = stmt.executeUpdate(); + if (rowsAffected == 0) { + LOGGER.warning(()->String.format("No encryption configuration found for %s.%s", tableName, columnName)); + } else { + LOGGER.info(()->String.format("Successfully removed encryption configuration for %s.%s", tableName, columnName)); + } + + // Refresh metadata cache + metadataManager.refreshMetadata(); + + } catch (MetadataException e) { + LOGGER.severe(()->String.format("Failed to refresh metadata after removing encryption configuration", e)); + throw new KeyManagementException("Failed to refresh metadata: " + e.getMessage(), e); + } catch (SQLException e) { + LOGGER.severe(()->String.format("Failed to remove encryption configuration for %s.%s", tableName, columnName, e)); + throw new KeyManagementException("Failed to remove encryption configuration: " + e.getMessage(), e); + } + } + + /** + * Lists all columns that use a specific key ID. + * Useful for understanding the impact of key operations. + * + * @param keyId The key ID to search for + * @return List of table.column identifiers using the key + * @throws KeyManagementException if query fails + */ + public List getColumnsUsingKey(String keyId) throws KeyManagementException { + Objects.requireNonNull(keyId, "Key ID cannot be null"); + + LOGGER.finest(()->String.format("Finding columns using key ID: %s", keyId)); + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(getSelectColumnsWithKeySql())) { + + stmt.setString(1, keyId); + + try (ResultSet rs = stmt.executeQuery()) { + List columns = new ArrayList<>(); + while (rs.next()) { + String tableName = rs.getString("table_name"); + String columnName = rs.getString("column_name"); + columns.add(tableName + "." + columnName); + } + return columns; + } + + } catch (SQLException e) { + LOGGER.severe(()->String.format("Failed to find columns using key ID: %s", keyId, e.getMessage())); + throw new KeyManagementException("Failed to find columns using key: " + e.getMessage(), e); + } + } + + /** + * Validates that a master key exists and is accessible. + * + * @param masterKeyArn ARN of the master key to validate + * @return true if key is valid and accessible + * @throws KeyManagementException if validation fails + */ + public boolean validateMasterKey(String masterKeyArn) throws KeyManagementException { + Objects.requireNonNull(masterKeyArn, "Master key ARN cannot be null"); + + LOGGER.finest(()->String.format("Validating master key: %s", masterKeyArn)); + + try { + DescribeKeyRequest request = DescribeKeyRequest.builder() + .keyId(masterKeyArn) + .build(); + + DescribeKeyResponse response = kmsClient.describeKey(request); + software.amazon.awssdk.services.kms.model.KeyMetadata keyMetadata = response.keyMetadata(); + + boolean isValid = keyMetadata.enabled() && + keyMetadata.keyState() == KeyState.ENABLED && + keyMetadata.keyUsage() == KeyUsageType.ENCRYPT_DECRYPT; + + LOGGER.finest(()->String.format("Master key %s validation result: %s", masterKeyArn, isValid)); + return isValid; + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to validate master key: %s", masterKeyArn, e.getMessage())); + throw new KeyManagementException("Failed to validate master key: " + e.getMessage(), e); + } + } + + /** + * Stores encryption metadata in the database. + */ + private void storeEncryptionMetadata(String tableName, String columnName, + String algorithm, int keyId) throws SQLException { + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(getInsertEncryptionMetadataSql())) { + + Timestamp now = Timestamp.from(Instant.now()); + + stmt.setString(1, tableName); + stmt.setString(2, columnName); + stmt.setString(3, algorithm); + stmt.setInt(4, keyId); + stmt.setTimestamp(5, now); + stmt.setTimestamp(6, now); + + int rowsAffected = stmt.executeUpdate(); + if (rowsAffected == 0) { + throw new SQLException("Failed to store encryption metadata - no rows affected"); + } + + LOGGER.finest(()->String.format("Successfully stored encryption metadata for %s.%s", tableName, columnName)); + } + } + + /** + * Updates the key ID for existing encryption metadata. + */ + private void updateEncryptionMetadataKey(String tableName, String columnName, String newKeyId) + throws SQLException { + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(getUpdateEncryptionMetadataKeySql())) { + + stmt.setString(1, newKeyId); + stmt.setTimestamp(2, Timestamp.from(Instant.now())); + stmt.setString(3, tableName); + stmt.setString(4, columnName); + + int rowsAffected = stmt.executeUpdate(); + if (rowsAffected == 0) { + throw new SQLException("Failed to update encryption metadata key - no rows affected"); + } + + LOGGER.finest(()->String.format("Successfully updated encryption metadata key for %s.%s to %s", + tableName, columnName, newKeyId)); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManager.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManager.java new file mode 100644 index 000000000..bd7b0cd47 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/key/KeyManager.java @@ -0,0 +1,454 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.key; + +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.plugin.encryption.cache.DataKeyCache; +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; +import software.amazon.jdbc.plugin.encryption.model.KeyMetadata; +import java.util.logging.Logger; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.*; + +import java.sql.*; +import java.time.Instant; +import java.util.Base64; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Manages KMS operations and data key lifecycle for the encryption plugin. + * Handles key creation, data key generation/decryption, and database storage of key metadata. + */ +public class KeyManager { + + private static final Logger LOGGER = Logger.getLogger(KeyManager.class.getName()); + + private final KmsClient kmsClient; + private final PluginService pluginService; + private final EncryptionConfig config; + private final DataKeyCache dataKeyCache; + + public KeyManager(KmsClient kmsClient, PluginService pluginService, EncryptionConfig config) { + this.kmsClient = Objects.requireNonNull(kmsClient, "KmsClient cannot be null"); + this.pluginService = Objects.requireNonNull(pluginService, "DataSource cannot be null"); + this.config = Objects.requireNonNull(config, "EncryptionConfig cannot be null"); + this.dataKeyCache = new DataKeyCache(config); + } + + private String getInsertKeyMetadataSql() { + String schema = config.getEncryptionMetadataSchema(); + return "INSERT INTO " + schema + ".key_storage (name, master_key_arn, encrypted_data_key, key_spec, created_at, last_used_at) " + + "VALUES (?, ?, ?, ?, ?, ?) " + + "RETURNING id"; + } + + private String getSelectKeyMetadataSql() { + return "SELECT id, name, master_key_arn, encrypted_data_key, key_spec, created_at, last_used_at " + + "FROM " + config.getEncryptionMetadataSchema() + ".key_storage WHERE id = ?"; + } + + private String getUpdateLastUsedSql() { + return "UPDATE " + config.getEncryptionMetadataSchema() + ".key_storage SET last_used_at = ? WHERE key_id = ?"; + } + + /** + * Creates a new KMS master key with the specified description. + * + * @param description Description for the master key + * @return The ARN of the created master key + * @throws KeyManagementException if key creation fails + */ + public String createMasterKey(String description) throws KeyManagementException { + Objects.requireNonNull(description, "Description cannot be null"); + + LOGGER.info(()->String.format("Creating KMS master key with description: %s", description)); + + try { + CreateKeyRequest request = CreateKeyRequest.builder() + .description(description) + .keyUsage(KeyUsageType.ENCRYPT_DECRYPT) + .keySpec(KeySpec.SYMMETRIC_DEFAULT) + .build(); + + CreateKeyResponse response = executeWithRetry(() -> kmsClient.createKey(request)); + String keyArn = response.keyMetadata().arn(); + + LOGGER.info(()->String.format("Successfully created KMS master key: %s", keyArn)); + return keyArn; + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to create KMS master key", e)); + throw new KeyManagementException("Failed to create KMS master key: " + e.getMessage(), e); + } + } + + /** + * Generates a new data key using the specified master key. + * + * @param masterKeyArn ARN of the master key to use for data key generation + * @return DataKeyResult containing both plaintext and encrypted data keys + * @throws KeyManagementException if data key generation fails + */ + public DataKeyResult generateDataKey(String masterKeyArn) throws KeyManagementException { + Objects.requireNonNull(masterKeyArn, "Master key ARN cannot be null"); + + LOGGER.finest(()->String.format("Generating data key using master key: %s", masterKeyArn)); + + try { + GenerateDataKeyRequest request = GenerateDataKeyRequest.builder() + .keyId(masterKeyArn) + .keySpec(DataKeySpec.AES_256) + .build(); + + GenerateDataKeyResponse response = executeWithRetry(() -> kmsClient.generateDataKey(request)); + + byte[] plaintextKey = response.plaintext().asByteArray(); + String encryptedKey = Base64.getEncoder().encodeToString(response.ciphertextBlob().asByteArray()); + + LOGGER.finest(()->String.format("Successfully generated data key for master key: %s", masterKeyArn)); + return new DataKeyResult(plaintextKey, encryptedKey); + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to generate data key for master key: %s", masterKeyArn, e)); + throw new KeyManagementException("Failed to generate data key: " + e.getMessage(), e); + } + } + + /** + * Decrypts an encrypted data key using KMS with caching support. + * + * @param encryptedDataKey Base64-encoded encrypted data key + * @param masterKeyArn ARN of the master key used for encryption + * @return Decrypted data key as byte array + * @throws KeyManagementException if decryption fails + */ + public byte[] decryptDataKey(String encryptedDataKey, String masterKeyArn) throws KeyManagementException { + Objects.requireNonNull(encryptedDataKey, "Encrypted data key cannot be null"); + Objects.requireNonNull(masterKeyArn, "Master key ARN cannot be null"); + + // Create cache key from encrypted data key hash + String cacheKey = createCacheKey(encryptedDataKey); + + // Try cache first if enabled + if (config.isDataKeyCacheEnabled()) { + byte[] cachedKey = dataKeyCache.get(cacheKey); + if (cachedKey != null) { + LOGGER.finest(()->"Cache hit for data key decryption"); + return cachedKey; + } + } + + LOGGER.finest(()->String.format("Decrypting data key using master key: %s", masterKeyArn)); + + try { + byte[] encryptedKeyBytes = Base64.getDecoder().decode(encryptedDataKey); + + DecryptRequest request = DecryptRequest.builder() + .ciphertextBlob(SdkBytes.fromByteArray(encryptedKeyBytes)) + .keyId(masterKeyArn) + .build(); + + DecryptResponse response = executeWithRetry(() -> kmsClient.decrypt(request)); + byte[] plaintextKey = response.plaintext().asByteArray(); + + // Cache the decrypted key if caching is enabled + if (config.isDataKeyCacheEnabled()) { + dataKeyCache.put(cacheKey, plaintextKey); + } + + LOGGER.finest(()->String.format("Successfully decrypted data key for master key: %s", masterKeyArn)); + return plaintextKey; + + } catch (Exception e) { + LOGGER.severe(()->String.format("Failed to decrypt data key for master key: %s", masterKeyArn, e)); + throw new KeyManagementException("Failed to decrypt data key: " + e.getMessage(), e); + } + } + + /** + * Stores key metadata in the database for the specified table and column. + * + * @param tableName Name of the table + * @param columnName Name of the column + * @param keyMetadata Key metadata to store + * @return the generated integer ID + * @throws KeyManagementException if storage fails + */ + public int storeKeyMetadata(String tableName, String columnName, KeyMetadata keyMetadata) + throws KeyManagementException { + Objects.requireNonNull(tableName, "Table name cannot be null"); + Objects.requireNonNull(columnName, "Column name cannot be null"); + Objects.requireNonNull(keyMetadata, "Key metadata cannot be null"); + + if (!keyMetadata.isValid()) { + throw new KeyManagementException("Invalid key metadata provided"); + } + + LOGGER.finest(()->String.format("Storing key metadata for %s.%s", tableName, columnName)); + + try (Connection conn = pluginService.forceConnect(pluginService.getCurrentHostSpec(), pluginService.getProperties()); + PreparedStatement stmt = conn.prepareStatement(getInsertKeyMetadataSql())) { + + stmt.setString(1, keyMetadata.getKeyName()); + stmt.setString(2, keyMetadata.getMasterKeyArn()); + stmt.setString(3, keyMetadata.getEncryptedDataKey()); + stmt.setString(4, keyMetadata.getKeySpec()); + stmt.setTimestamp(5, Timestamp.from(keyMetadata.getCreatedAt())); + stmt.setTimestamp(6, Timestamp.from(keyMetadata.getLastUsedAt())); + + ResultSet rs = stmt.executeQuery(); + if (rs.next()) { + int generatedId = rs.getInt(1); + LOGGER.finest(()->String.format("Successfully stored key metadata for %s.%s with ID: %s", tableName, columnName, generatedId)); + return generatedId; + } else { + throw new KeyManagementException("Failed to get generated key ID"); + } + + } catch (SQLException e) { + LOGGER.severe(()->String.format("Database error storing key metadata for %s.%s %s", tableName, columnName, e.getMessage())); + throw new KeyManagementException("Failed to store key metadata: " + e.getMessage(), e); + } + } + + /** + * Retrieves key metadata from the database for the specified key ID. + * + * @param keyId Key ID to retrieve metadata for + * @return Optional containing key metadata if found + * @throws KeyManagementException if retrieval fails + */ + public Optional getKeyMetadata(String keyId) throws KeyManagementException { + Objects.requireNonNull(keyId, "Key ID cannot be null"); + + LOGGER.finest(()->String.format("Retrieving key metadata for key ID: %s", keyId)); + + try (Connection conn = pluginService.forceConnect(pluginService.getCurrentHostSpec(), pluginService.getProperties()); + PreparedStatement stmt = conn.prepareStatement(getSelectKeyMetadataSql())) { + + stmt.setString(1, keyId); + + try (ResultSet rs = stmt.executeQuery()) { + if (rs.next()) { + KeyMetadata metadata = KeyMetadata.builder() + .keyId(rs.getString("key_id")) + .masterKeyArn(rs.getString("master_key_arn")) + .encryptedDataKey(rs.getString("encrypted_data_key")) + .keySpec(rs.getString("key_spec")) + .createdAt(rs.getTimestamp("created_at").toInstant()) + .lastUsedAt(rs.getTimestamp("last_used_at").toInstant()) + .build(); + + LOGGER.finest(()->String.format("Successfully retrieved key metadata for key ID: %s", keyId)); + return Optional.of(metadata); + } else { + LOGGER.finest(()->String.format("No key metadata found for key ID: %s", keyId)); + return Optional.empty(); + } + } + + } catch (SQLException e) { + LOGGER.severe(()->String.format("Database error retrieving key metadata for key ID: %s", keyId, e)); + throw new KeyManagementException("Failed to retrieve key metadata: " + e.getMessage(), e); + } + } + + /** + * Updates the last used timestamp for the specified key. + * + * @param keyId Key ID to update + * @throws KeyManagementException if update fails + */ + public void updateLastUsed(String keyId) throws KeyManagementException { + Objects.requireNonNull(keyId, "Key ID cannot be null"); + + try (Connection conn = pluginService.forceConnect(pluginService.getCurrentHostSpec(), pluginService.getProperties()); + PreparedStatement stmt = conn.prepareStatement(getUpdateLastUsedSql())) { + + stmt.setTimestamp(1, Timestamp.from(Instant.now())); + stmt.setString(2, keyId); + + stmt.executeUpdate(); + + } catch (SQLException e) { + LOGGER.severe(()->String.format("Database error updating last used timestamp for key ID: %s %s", keyId, e.getMessage())); + throw new KeyManagementException("Failed to update last used timestamp: " + e.getMessage(), e); + } + } + + /** + * Generates a unique key ID for new keys to store in the database. + * + * @return Unique key ID + */ + public String generateKeyId() { + return UUID.randomUUID().toString(); + } + + /** + * Returns the data key cache for metrics and management. + * + * @return Data key cache instance + */ + public DataKeyCache getDataKeyCache() { + return dataKeyCache; + } + + /** + * Clears the data key cache. + */ + public void clearCache() { + dataKeyCache.clear(); + LOGGER.info(()->"Data key cache cleared"); + } + + /** + * Shuts down the key manager and cleans up resources. + */ + public void shutdown() { + LOGGER.info(()->"Shutting down KeyManager"); + dataKeyCache.shutdown(); + } + + /** + * Executes a KMS operation with retry logic and exponential backoff. + */ + private T executeWithRetry(KmsOperation operation) throws Exception { + Exception lastException = null; + int maxRetries = config.getMaxRetries(); + + for (int attempt = 0; attempt <= maxRetries; attempt++) { + try { + return operation.execute(); + } catch (Exception e) { + lastException = e; + + if (attempt == maxRetries) { + break; + } + + if (isRetryableException(e)) { + long backoffMs = calculateBackoff(attempt); + int finalAttempt = attempt; + LOGGER.warning(()->String.format("KMS operation failed (attempt %s/%s), retrying in %sms: %s", + finalAttempt + 1, maxRetries + 1, backoffMs, e.getMessage())); + + try { + Thread.sleep(backoffMs); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new KeyManagementException("Operation interrupted during retry", ie); + } + } else { + // Non-retryable exception, fail immediately + break; + } + } + } + + throw lastException; + } + + /** + * Determines if an exception is retryable. + */ + private boolean isRetryableException(Exception e) { + if (e instanceof KmsException) { + KmsException kmsException = (KmsException) e; + // Retry on throttling, service unavailable, and internal errors + boolean isServerError = kmsException.statusCode() >= 500; + boolean isThrottling = kmsException.statusCode() == 429; + + // Check error code if available + boolean isThrottlingError = false; + if (kmsException.awsErrorDetails() != null && kmsException.awsErrorDetails().errorCode() != null) { + isThrottlingError = "ThrottlingException".equals(kmsException.awsErrorDetails().errorCode()); + } + + return isServerError || isThrottling || isThrottlingError; + } + + // Retry on general network/connection issues + return e instanceof java.net.ConnectException || + e instanceof java.net.SocketTimeoutException || + e instanceof java.io.IOException; + } + + /** + * Calculates exponential backoff with jitter. + */ + private long calculateBackoff(int attempt) { + long baseMs = config.getRetryBackoffBase().toMillis(); + long exponentialBackoff = baseMs * (1L << attempt); + + // Add jitter (±25% of the calculated backoff) + long jitter = (long) (exponentialBackoff * 0.25 * (ThreadLocalRandom.current().nextDouble() - 0.5) * 2); + + return Math.max(baseMs, exponentialBackoff + jitter); + } + + /** + * Creates a cache key from an encrypted data key. + */ + private String createCacheKey(String encryptedDataKey) { + // Use a hash of the encrypted data key as cache key for security + return "datakey_" + Math.abs(encryptedDataKey.hashCode()); + } + + /** + * Functional interface for KMS operations that can be retried. + */ + @FunctionalInterface + private interface KmsOperation { + T execute() throws Exception; + } + + /** + * Result class for data key generation operations. + */ + public static class DataKeyResult { + private final byte[] plaintextKey; + private final String encryptedKey; + + public DataKeyResult(byte[] plaintextKey, String encryptedKey) { + this.plaintextKey = Objects.requireNonNull(plaintextKey, "Plaintext key cannot be null"); + this.encryptedKey = Objects.requireNonNull(encryptedKey, "Encrypted key cannot be null"); + } + + public byte[] getPlaintextKey() { + return plaintextKey.clone(); // Return copy for security + } + + public String getEncryptedKey() { + return encryptedKey; + } + + /** + * Clears the plaintext key from memory for security. + */ + public void clearPlaintextKey() { + if (plaintextKey != null) { + java.util.Arrays.fill(plaintextKey, (byte) 0); + } + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/logging/AuditLogger.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/logging/AuditLogger.java new file mode 100644 index 000000000..21f50aa37 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/logging/AuditLogger.java @@ -0,0 +1,468 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.logging; + +import java.util.logging.Logger; +import org.slf4j.MDC; + +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Audit LOGGER for KMS operations and encryption activities. + * Provides structured logging without exposing sensitive data. + */ +public class AuditLogger { + + private static final Logger auditLogger = Logger.getLogger(AuditLogger.class.getName()); + + // Thread-local context for audit information + private static final ThreadLocal> auditContext = + ThreadLocal.withInitial(ConcurrentHashMap::new); + + private final boolean auditEnabled; + + public AuditLogger(boolean auditEnabled) { + this.auditEnabled = auditEnabled; + } + + /** + * Sets audit context information for the current thread. + * + * @param key Context key + * @param value Context value + */ + public static void setContext(String key, String value) { + auditContext.get().put(key, value); + MDC.put(key, value); + } + + /** + * Clears audit context for the current thread. + */ + public static void clearContext() { + auditContext.get().clear(); + MDC.clear(); + } + + /** + * Logs KMS key creation operation. + * + * @param masterKeyArn Master key ARN + * @param description Key description + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logKeyCreation(String masterKeyArn, String description, boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "CREATE_MASTER_KEY"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + + if (success) { + auditLogger.info(()->String.format("KMS master key created successfully - ARN: %s, Description: %s", + sanitizeArn(masterKeyArn), sanitizeDescription(description))); + } else { + auditLogger.warning(()->String.format("KMS master key creation failed - Description: %s, Error: %s", + sanitizeDescription(description), sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs data key generation operation. + * + * @param masterKeyArn Master key ARN + * @param keyId Key ID + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logDataKeyGeneration(String masterKeyArn, String keyId, boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "GENERATE_DATA_KEY"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + + if (success) { + auditLogger.info(()->String.format("Data key generated successfully - Master Key: %s, Key ID: %s", + sanitizeArn(masterKeyArn), sanitizeKeyId(keyId))); + } else { + auditLogger.warning(()->String.format("()->String.format(Data key generation failed - Master Key: %s, Error: %s", + sanitizeArn(masterKeyArn), sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs data key decryption operation. + * + * @param masterKeyArn Master key ARN + * @param keyId Key ID + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logDataKeyDecryption(String masterKeyArn, String keyId, boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "DECRYPT_DATA_KEY"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + + if (success) { + auditLogger.info(()->String.format("Data key decrypted successfully - Master Key: %s, Key ID: %s", + sanitizeArn(masterKeyArn), sanitizeKeyId(keyId))); + } else { + auditLogger.warning(()->String.format("Data key decryption failed - Master Key: %s, Key ID: %s, Error: %s", + sanitizeArn(masterKeyArn), sanitizeKeyId(keyId), sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs encryption operation. + * + * @param tableName Table name + * @param columnName Column name + * @param keyId Key ID + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logEncryption(String tableName, String columnName, String keyId, boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "ENCRYPT_DATA"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + + if (success) { + auditLogger.info(()->String.format("Data encrypted successfully - Table: %s, Column: %s, Key ID: %s", + sanitizeTableName(tableName), sanitizeColumnName(columnName), sanitizeKeyId(keyId))); + } else { + auditLogger.warning(()->String.format("Data encryption failed - Table: %s, Column: %s, Key ID: %s, Error: %s", + sanitizeTableName(tableName), sanitizeColumnName(columnName), + sanitizeKeyId(keyId), sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs decryption operation. + * + * @param tableName Table name + * @param columnName Column name + * @param keyId Key ID + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logDecryption(String tableName, String columnName, String keyId, boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "DECRYPT_DATA"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + + if (success) { + auditLogger.info(()->String.format("Data decrypted successfully - Table: %s, Column: %s, Key ID: %s", + sanitizeTableName(tableName), sanitizeColumnName(columnName), sanitizeKeyId(keyId))); + } else { + auditLogger.warning(()->String.format("Data decryption failed - Table: %s, Column: %s, Key ID: %s, Error: %s", + sanitizeTableName(tableName), sanitizeColumnName(columnName), + sanitizeKeyId(keyId), sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs metadata operations. + * + * @param operation Operation type + * @param tableName Table name + * @param columnName Column name + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logMetadataOperation(String operation, String tableName, String columnName, + boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "METADATA_" + operation.toUpperCase()); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + + if (success) { + auditLogger.info(()->String.format("Metadata operation completed - Operation: %s, Table: %s, Column: %s", + operation, sanitizeTableName(tableName), sanitizeColumnName(columnName))); + } else { + auditLogger.warning(()->String.format("Metadata operation failed - Operation: %s, Table: %s, Column: %s, Error: %s", + operation, sanitizeTableName(tableName), sanitizeColumnName(columnName), + sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs configuration changes. + * + * @param configType Configuration type + * @param details Configuration details + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logConfigurationChange(String configType, String details, boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "CONFIG_CHANGE"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + + if (success) { + auditLogger.info(()->String.format("Configuration changed successfully - Type: %s, Details: %s", + configType, sanitizeConfigDetails(details))); + } else { + auditLogger.warning(()->String.format("Configuration change failed - Type: %s, Details: %s, Error: %s", + configType, sanitizeConfigDetails(details), sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs connection parameter extraction operations. + * + * @param strategy Extraction strategy + * @param connectionType Connection type + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + */ + public void logConnectionParameterExtraction(String strategy, String connectionType, + boolean success, String errorMessage) { + if (!auditEnabled) return; + + try { + setContext("operation", "CONNECTION_PARAMETER_EXTRACTION"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + setContext("strategy", strategy); + setContext("connectionType", connectionType); + + if (success) { + auditLogger.info(()->String.format("Connection parameter extraction successful - Strategy: %s, Type: %s", + strategy, connectionType)); + } else { + auditLogger.warning(()->String.format("Connection parameter extraction failed - Strategy: %s, Type: %s, Error: %s", + strategy, connectionType, sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs independent connection creation operations. + * + * @param jdbcUrl JDBC URL + * @param success Whether operation succeeded + * @param errorMessage Error message if failed + * @param usedFallback Whether fallback was used + */ + public void logIndependentConnectionCreation(String jdbcUrl, boolean success, String errorMessage, + boolean usedFallback) { + if (!auditEnabled) return; + + try { + setContext("operation", "INDEPENDENT_CONNECTION_CREATION"); + setContext("timestamp", Instant.now().toString()); + setContext("success", String.valueOf(success)); + setContext("usedFallback", String.valueOf(usedFallback)); + + String sanitizedUrl = sanitizeJdbcUrl(jdbcUrl); + + if (success) { + if (usedFallback) { + auditLogger.warning(()->String.format("Independent connection created using fallback - URL: %s", + sanitizedUrl)); + } else { + auditLogger.info(()->String.format("Independent connection created successfully - URL: %s", + sanitizedUrl)); + } + } else { + auditLogger.fine(()->String.format("Independent connection creation failed - URL: %s, Error: %s", + sanitizedUrl, sanitizeErrorMessage(errorMessage))); + } + } finally { + clearContext(); + } + } + + /** + * Logs connection sharing fallback activation. + * + * @param reason Reason for fallback + * @param originalFailure Original failure message + * @param isActive Whether fallback is active + */ + public void logConnectionSharingFallback(String reason, String originalFailure, boolean isActive) { + if (!auditEnabled) return; + + try { + setContext("operation", "CONNECTION_SHARING_FALLBACK"); + setContext("timestamp", Instant.now().toString()); + setContext("isActive", String.valueOf(isActive)); + + if (isActive) { + auditLogger.fine(()->String.format("CONNECTION SHARING FALLBACK ACTIVATED - Reason: %s, Original Failure: %s", + sanitizeErrorMessage(reason), sanitizeErrorMessage(originalFailure))); + auditLogger.fine(()->"WARNING: MetadataManager will share connections with client application!"); + auditLogger.fine(()->"This may cause connection closure issues when MetadataManager operations complete."); + } else { + auditLogger.info(()->String.format("Connection sharing fallback deactivated - Reason: %s", + sanitizeErrorMessage(reason))); + } + } finally { + clearContext(); + } + } + + /** + * Logs connection health monitoring events. + * + * @param dataSourceType Data source type + * @param isHealthy Whether connection is healthy + * @param successCount Number of successful connections + * @param failureCount Number of failed connections + * @param successRate Success rate as decimal + */ + public void logConnectionHealthCheck(String dataSourceType, boolean isHealthy, + long successCount, long failureCount, double successRate) { + if (!auditEnabled) return; + + try { + setContext("operation", "CONNECTION_HEALTH_CHECK"); + setContext("timestamp", Instant.now().toString()); + setContext("dataSourceType", dataSourceType); + setContext("isHealthy", String.valueOf(isHealthy)); + setContext("successCount", String.valueOf(successCount)); + setContext("failureCount", String.valueOf(failureCount)); + setContext("successRate", String.format("%.2f", successRate * 100)); + + if (isHealthy) { + auditLogger.info(()->String.format("Connection health check passed - Type: %s, Success Rate: {:.2f}%, " + + "Successful: %s, Failed: %s", + dataSourceType, successRate * 100, successCount, failureCount)); + } else { + auditLogger.warning(()->String.format("Connection health check failed - Type: %s, Success Rate: {:.2f}%, " + + "Successful: %s, Failed: %s", + dataSourceType, successRate * 100, successCount, failureCount)); + } + } finally { + clearContext(); + } + } + + // Sanitization methods to prevent sensitive data exposure + + private String sanitizeArn(String arn) { + if (arn == null) return "null"; + // Keep only the key ID part of the ARN for audit purposes + int lastSlash = arn.lastIndexOf('/'); + if (lastSlash != -1 && lastSlash < arn.length() - 1) { + return "arn:aws:kms:***:***:key/" + arn.substring(lastSlash + 1); + } + return "arn:aws:kms:***:***:key/***"; + } + + private String sanitizeKeyId(String keyId) { + if (keyId == null) return "null"; + // Show only first and last 4 characters of key ID + if (keyId.length() > 8) { + return keyId.substring(0, 4) + "***" + keyId.substring(keyId.length() - 4); + } + return "***"; + } + + private String sanitizeTableName(String tableName) { + if (tableName == null) return "null"; + // Table names are generally not sensitive, but limit length + return tableName.length() > 50 ? tableName.substring(0, 47) + "..." : tableName; + } + + private String sanitizeColumnName(String columnName) { + if (columnName == null) return "null"; + // Column names are generally not sensitive, but limit length + return columnName.length() > 50 ? columnName.substring(0, 47) + "..." : columnName; + } + + private String sanitizeDescription(String description) { + if (description == null) return "null"; + // Limit description length and remove potential sensitive patterns + String sanitized = description.replaceAll("(?i)(password|secret|key|token)=[^\\s]+", "$1=***"); + return sanitized.length() > 100 ? sanitized.substring(0, 97) + "..." : sanitized; + } + + private String sanitizeErrorMessage(String errorMessage) { + if (errorMessage == null) return "null"; + // Remove potential sensitive information from error messages + String sanitized = errorMessage + .replaceAll("(?i)(password|secret|key|token)=[^\\s]+", "$1=***") + .replaceAll("arn:aws:kms:[^:]+:[^:]+:key/[a-f0-9-]+", "arn:aws:kms:***:***:key/***"); + return sanitized.length() > 200 ? sanitized.substring(0, 197) + "..." : sanitized; + } + + private String sanitizeConfigDetails(String details) { + if (details == null) return "null"; + // Remove sensitive configuration values + String sanitized = details + .replaceAll("(?i)(password|secret|key|token|credential)=[^\\s,}]+", "$1=***") + .replaceAll("arn:aws:kms:[^:]+:[^:]+:key/[a-f0-9-]+", "arn:aws:kms:***:***:key/***"); + return sanitized.length() > 150 ? sanitized.substring(0, 147) + "..." : sanitized; + } + + private String sanitizeJdbcUrl(String jdbcUrl) { + if (jdbcUrl == null) return "null"; + + // Remove password parameters from URL + String sanitized = jdbcUrl.replaceAll("(?i)[?&]password=[^&]*", "?password=***") + .replaceAll("(?i)[?&]pwd=[^&]*", "?pwd=***") + .replaceAll("(?i)://[^:]+:[^@]+@", "://***:***@"); + + return sanitized; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/logging/ErrorContext.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/logging/ErrorContext.java new file mode 100644 index 000000000..dce652dfa --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/logging/ErrorContext.java @@ -0,0 +1,378 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.logging; + +import java.util.HashMap; +import java.util.Map; + +/** + * Utility class for building detailed error messages with context information. + * Helps provide clear error messages that include table/column information + * without exposing sensitive data. + */ +public class ErrorContext { + + private final Map context = new HashMap<>(); + + private ErrorContext(){} + + /** + * Creates a new error context builder. + * + * @return New ErrorContext instance + */ + public static ErrorContext builder() { + return new ErrorContext(); + } + + /** + * Adds table name to the error context. + * + * @param tableName Table name + * @return This ErrorContext instance for chaining + */ + public ErrorContext table(String tableName) { + context.put("table", tableName); + return this; + } + + /** + * Adds column name to the error context. + * + * @param columnName Column name + * @return This ErrorContext instance for chaining + */ + public ErrorContext column(String columnName) { + context.put("column", columnName); + return this; + } + + /** + * Adds operation type to the error context. + * + * @param operation Operation type + * @return This ErrorContext instance for chaining + */ + public ErrorContext operation(String operation) { + context.put("operation", operation); + return this; + } + + /** + * Adds key ID to the error context. + * + * @param keyId Key ID + * @return This ErrorContext instance for chaining + */ + public ErrorContext keyId(String keyId) { + context.put("keyId", sanitizeKeyId(keyId)); + return this; + } + + /** + * Adds master key ARN to the error context. + * + * @param masterKeyArn Master key ARN + * @return This ErrorContext instance for chaining + */ + public ErrorContext masterKeyArn(String masterKeyArn) { + context.put("masterKeyArn", sanitizeArn(masterKeyArn)); + return this; + } + + /** + * Adds algorithm to the error context. + * + * @param algorithm Algorithm name + * @return This ErrorContext instance for chaining + */ + public ErrorContext algorithm(String algorithm) { + context.put("algorithm", algorithm); + return this; + } + + /** + * Adds parameter index to the error context. + * + * @param parameterIndex Parameter index + * @return This ErrorContext instance for chaining + */ + public ErrorContext parameterIndex(int parameterIndex) { + context.put("parameterIndex", parameterIndex); + return this; + } + + /** + * Adds column index to the error context. + * + * @param columnIndex Column index + * @return This ErrorContext instance for chaining + */ + public ErrorContext columnIndex(int columnIndex) { + context.put("columnIndex", columnIndex); + return this; + } + + /** + * Adds SQL statement to the error context (sanitized). + * + * @param sql SQL statement + * @return This ErrorContext instance for chaining + */ + public ErrorContext sql(String sql) { + context.put("sql", sanitizeSql(sql)); + return this; + } + + /** + * Adds data type to the error context. + * + * @param dataType Data type + * @return This ErrorContext instance for chaining + */ + public ErrorContext dataType(String dataType) { + context.put("dataType", dataType); + return this; + } + + /** + * Adds retry attempt information to the error context. + * + * @param attempt Current attempt number + * @param maxAttempts Maximum number of attempts + * @return This ErrorContext instance for chaining + */ + public ErrorContext retryAttempt(int attempt, int maxAttempts) { + context.put("retryAttempt", attempt); + context.put("maxRetryAttempts", maxAttempts); + return this; + } + + /** + * Adds cache information to the error context. + * + * @param cacheType Type of cache + * @param cacheHit Whether cache was hit + * @return This ErrorContext instance for chaining + */ + public ErrorContext cacheInfo(String cacheType, boolean cacheHit) { + context.put("cacheType", cacheType); + context.put("cacheHit", cacheHit); + return this; + } + + /** + * Builds an error message with the provided base message and context. + * + * @param baseMessage Base error message + * @return Formatted error message with context + */ + public String buildMessage(String baseMessage) { + if (context.isEmpty()) { + return baseMessage; + } + + StringBuilder sb = new StringBuilder(baseMessage); + sb.append(" [Context: "); + + boolean first = true; + for (Map.Entry entry : context.entrySet()) { + if (!first) { + sb.append(", "); + } + sb.append(entry.getKey()).append("=").append(entry.getValue()); + first = false; + } + + sb.append("]"); + return sb.toString(); + } + + /** + * Builds an error message for encryption operations. + * + * @param baseMessage Base error message + * @return Formatted encryption error message + */ + public String buildEncryptionErrorMessage(String baseMessage) { + StringBuilder sb = new StringBuilder("Encryption failed"); + + if (baseMessage != null && !baseMessage.trim().isEmpty()) { + sb.append(": ").append(baseMessage); + } + + addContextualInfo(sb); + return sb.toString(); + } + + /** + * Builds an error message for decryption operations. + * + * @param baseMessage Base error message + * @return Formatted decryption error message + */ + public String buildDecryptionErrorMessage(String baseMessage) { + StringBuilder sb = new StringBuilder("Decryption failed"); + + if (baseMessage != null && !baseMessage.trim().isEmpty()) { + sb.append(": ").append(baseMessage); + } + + addContextualInfo(sb); + return sb.toString(); + } + + /** + * Builds an error message for key management operations. + * + * @param baseMessage Base error message + * @return Formatted key management error message + */ + public String buildKeyManagementErrorMessage(String baseMessage) { + StringBuilder sb = new StringBuilder("Key management operation failed"); + + if (baseMessage != null && !baseMessage.trim().isEmpty()) { + sb.append(": ").append(baseMessage); + } + + addContextualInfo(sb); + return sb.toString(); + } + + /** + * Builds an error message for metadata operations. + * + * @param baseMessage Base error message + * @return Formatted metadata error message + */ + public String buildMetadataErrorMessage(String baseMessage) { + StringBuilder sb = new StringBuilder("Metadata operation failed"); + + if (baseMessage != null && !baseMessage.trim().isEmpty()) { + sb.append(": ").append(baseMessage); + } + + addContextualInfo(sb); + return sb.toString(); + } + + /** + * Gets the context map for external use. + * + * @return Copy of the context map + */ + public Map getContext() { + return new HashMap<>(context); + } + + /** + * Adds contextual information to the error message. + */ + private void addContextualInfo(StringBuilder sb) { + // Add table.column information if available + String table = (String) context.get("table"); + String column = (String) context.get("column"); + + if (table != null && column != null) { + sb.append(" for column ").append(table).append(".").append(column); + } else if (table != null) { + sb.append(" for table ").append(table); + } else if (column != null) { + sb.append(" for column ").append(column); + } + + // Add operation information if available + String operation = (String) context.get("operation"); + if (operation != null) { + sb.append(" during ").append(operation); + } + + // Add parameter/column index information if available + Integer paramIndex = (Integer) context.get("parameterIndex"); + Integer colIndex = (Integer) context.get("columnIndex"); + + if (paramIndex != null) { + sb.append(" (parameter index: ").append(paramIndex).append(")"); + } else if (colIndex != null) { + sb.append(" (column index: ").append(colIndex).append(")"); + } + + // Add retry information if available + Integer retryAttempt = (Integer) context.get("retryAttempt"); + Integer maxRetries = (Integer) context.get("maxRetryAttempts"); + + if (retryAttempt != null && maxRetries != null) { + sb.append(" (retry ").append(retryAttempt).append("/").append(maxRetries).append(")"); + } + + // Add additional context in brackets + Map additionalContext = new HashMap<>(); + for (Map.Entry entry : context.entrySet()) { + String key = entry.getKey(); + if (!key.equals("table") && !key.equals("column") && !key.equals("operation") && + !key.equals("parameterIndex") && !key.equals("columnIndex") && + !key.equals("retryAttempt") && !key.equals("maxRetryAttempts")) { + additionalContext.put(key, entry.getValue()); + } + } + + if (!additionalContext.isEmpty()) { + sb.append(" ["); + boolean first = true; + for (Map.Entry entry : additionalContext.entrySet()) { + if (!first) { + sb.append(", "); + } + sb.append(entry.getKey()).append("=").append(entry.getValue()); + first = false; + } + sb.append("]"); + } + } + + // Sanitization methods + + private String sanitizeKeyId(String keyId) { + if (keyId == null) return null; + // Show only first and last 4 characters of key ID + if (keyId.length() > 8) { + return keyId.substring(0, 4) + "***" + keyId.substring(keyId.length() - 4); + } + return "***"; + } + + private String sanitizeArn(String arn) { + if (arn == null) return null; + // Keep only the key ID part of the ARN + int lastSlash = arn.lastIndexOf('/'); + if (lastSlash != -1 && lastSlash < arn.length() - 1) { + return "arn:aws:kms:***:***:key/" + arn.substring(lastSlash + 1); + } + return "arn:aws:kms:***:***:key/***"; + } + + private String sanitizeSql(String sql) { + if (sql == null) return null; + // Remove potential sensitive data from SQL and limit length + String sanitized = sql + .replaceAll("'[^']*'", "'***'") // Replace string literals + .replaceAll("\\b\\d+\\b", "***"); // Replace numeric literals + + return sanitized.length() > 100 ? sanitized.substring(0, 97) + "..." : sanitized; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/metadata/MetadataException.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/metadata/MetadataException.java new file mode 100644 index 000000000..d09ebc3bd --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/metadata/MetadataException.java @@ -0,0 +1,260 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.metadata; + +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +/** + * Exception thrown when metadata operations fail, such as loading encryption + * configuration from database or cache operations. + * Provides enhanced error context information for better troubleshooting. + */ +public class MetadataException extends SQLException { + + private static final long serialVersionUID = 1L; + + // SQL State codes for different metadata error types + public static final String METADATA_LOAD_FAILED_STATE = "META01"; + public static final String METADATA_CACHE_FAILED_STATE = "META02"; + public static final String METADATA_REFRESH_FAILED_STATE = "META03"; + public static final String METADATA_LOOKUP_FAILED_STATE = "META04"; + public static final String METADATA_VALIDATION_FAILED_STATE = "META05"; + + private final Map errorContext = new HashMap<>(); + + /** + * Constructs a MetadataException with the specified detail message. + * + * @param message the detail message + */ + public MetadataException(String message) { + super(message, METADATA_LOOKUP_FAILED_STATE); + } + + /** + * Constructs a MetadataException with the specified detail message and cause. + * + * @param message the detail message + * @param cause the cause of this exception + */ + public MetadataException(String message, Throwable cause) { + super(message, METADATA_LOOKUP_FAILED_STATE, cause); + } + + /** + * Constructs a MetadataException with the specified cause. + * + * @param cause the cause of this exception + */ + public MetadataException(Throwable cause) { + super(cause.getMessage(), METADATA_LOOKUP_FAILED_STATE, cause); + } + + /** + * Constructs a MetadataException with the specified detail message, cause, + * SQL state, and vendor code. + * + * @param message the detail message + * @param sqlState the SQL state + * @param vendorCode the vendor-specific error code + * @param cause the cause of this exception + */ + public MetadataException(String message, String sqlState, int vendorCode, Throwable cause) { + super(message, sqlState, vendorCode, cause); + } + + /** + * Constructs a MetadataException with the specified detail message, SQL state and cause. + * + * @param message the detail message + * @param sqlState the SQL state + * @param cause the cause of this exception + */ + public MetadataException(String message, String sqlState, Throwable cause) { + super(message, sqlState, cause); + } + + /** + * Adds context information to the exception. + * + * @param key the context key + * @param value the context value + * @return this exception for method chaining + */ + public MetadataException withContext(String key, Object value) { + errorContext.put(key, value); + return this; + } + + /** + * Adds table name to the error context. + * + * @param tableName the table name + * @return this exception for method chaining + */ + public MetadataException withTable(String tableName) { + return withContext("table", tableName); + } + + /** + * Adds column name to the error context. + * + * @param columnName the column name + * @return this exception for method chaining + */ + public MetadataException withColumn(String columnName) { + return withContext("column", columnName); + } + + /** + * Adds operation type to the error context. + * + * @param operation the operation being performed + * @return this exception for method chaining + */ + public MetadataException withOperation(String operation) { + return withContext("operation", operation); + } + + /** + * Adds cache information to the error context. + * + * @param cacheSize the current cache size + * @param cacheHit whether this was a cache hit or miss + * @return this exception for method chaining + */ + public MetadataException withCacheInfo(int cacheSize, boolean cacheHit) { + return withContext("cacheSize", cacheSize).withContext("cacheHit", cacheHit); + } + + /** + * Adds SQL query information to the error context (sanitized). + * + * @param sql the SQL query + * @return this exception for method chaining + */ + public MetadataException withSql(String sql) { + return withContext("sql", sanitizeSql(sql)); + } + + /** + * Gets the error context map. + * + * @return a copy of the error context + */ + public Map getErrorContext() { + return new HashMap<>(errorContext); + } + + /** + * Gets a formatted error message including context information. + * + * @return formatted error message with context + */ + public String getDetailedMessage() { + if (errorContext.isEmpty()) { + return getMessage(); + } + + StringBuilder sb = new StringBuilder(getMessage()); + sb.append(" [Context: "); + + boolean first = true; + for (Map.Entry entry : errorContext.entrySet()) { + if (!first) { + sb.append(", "); + } + sb.append(entry.getKey()).append("=").append(entry.getValue()); + first = false; + } + + sb.append("]"); + return sb.toString(); + } + + /** + * Creates a MetadataException for metadata loading failures. + * + * @param message Error message + * @param cause Root cause + * @return New MetadataException instance + */ + public static MetadataException loadFailed(String message, Throwable cause) { + return new MetadataException(message, METADATA_LOAD_FAILED_STATE, cause); + } + + /** + * Creates a MetadataException for cache operation failures. + * + * @param message Error message + * @param cause Root cause + * @return New MetadataException instance + */ + public static MetadataException cacheFailed(String message, Throwable cause) { + return new MetadataException(message, METADATA_CACHE_FAILED_STATE, cause); + } + + /** + * Creates a MetadataException for metadata refresh failures. + * + * @param message Error message + * @param cause Root cause + * @return New MetadataException instance + */ + public static MetadataException refreshFailed(String message, Throwable cause) { + return new MetadataException(message, METADATA_REFRESH_FAILED_STATE, cause); + } + + /** + * Creates a MetadataException for metadata lookup failures. + * + * @param tableName Table name + * @param columnName Column name + * @param cause Root cause + * @return New MetadataException instance + */ + public static MetadataException lookupFailed(String tableName, String columnName, Throwable cause) { + return new MetadataException("Failed to lookup metadata", METADATA_LOOKUP_FAILED_STATE, cause) + .withTable(tableName) + .withColumn(columnName); + } + + /** + * Creates a MetadataException for metadata validation failures. + * + * @param message Error message + * @return New MetadataException instance + */ + public static MetadataException validationFailed(String message) { + return new MetadataException(message, METADATA_VALIDATION_FAILED_STATE, null); + } + + // Sanitization methods + + private String sanitizeSql(String sql) { + if (sql == null) return null; + // Remove potential sensitive data from SQL and limit length + String sanitized = sql + .replaceAll("'[^']*'", "'***'") // Replace string literals + .replaceAll("\\b\\d+\\b", "***"); // Replace numeric literals + + return sanitized.length() > 100 ? sanitized.substring(0, 97) + "..." : sanitized; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/metadata/MetadataManager.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/metadata/MetadataManager.java new file mode 100644 index 000000000..ff0984cf6 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/metadata/MetadataManager.java @@ -0,0 +1,462 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.metadata; + +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.plugin.encryption.model.ColumnEncryptionConfig; +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; +import software.amazon.jdbc.plugin.encryption.model.KeyMetadata; +import java.util.logging.Logger; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * Manages encryption metadata by loading configuration from database tables, + * providing caching mechanisms, and offering lookup methods for column encryption settings. + */ +public class MetadataManager { + + private static final Logger LOGGER = Logger.getLogger(MetadataManager.class.getName()); + + private final PluginService pluginService; + private volatile EncryptionConfig config; + private final Map metadataCache; + private final ReadWriteLock cacheLock; + private volatile Instant lastRefreshTime; + private volatile ScheduledExecutorService refreshExecutor; + + public MetadataManager(PluginService pluginService, EncryptionConfig config) { + this.pluginService = pluginService; + this.config = config; + this.metadataCache = new ConcurrentHashMap<>(); + this.cacheLock = new ReentrantReadWriteLock(); + this.lastRefreshTime = Instant.EPOCH; + this.refreshExecutor = createRefreshExecutor(); + } + + private String getLoadEncryptionMetadataSql() { + String schema = config.getEncryptionMetadataSchema(); + return "SELECT em.table_name, em.column_name, em.encryption_algorithm, em.key_id, " + + " em.created_at, em.updated_at, " + + " ks.name, ks.master_key_arn, ks.encrypted_data_key, ks.hmac_key, ks.key_spec, " + + " ks.created_at as key_created_at, ks.last_used_at " + + "FROM " + schema + ".encryption_metadata em " + + "JOIN " + schema + ".key_storage ks ON em.key_id = ks.id " + + "ORDER BY em.table_name, em.column_name"; + } + + private String getCheckColumnEncryptedSql() { + return "SELECT 1 FROM " + config.getEncryptionMetadataSchema() + ".encryption_metadata " + + "WHERE table_name = ? AND column_name = ?"; + } + + private String getColumnConfigSql() { + String schema = config.getEncryptionMetadataSchema(); + return "SELECT em.table_name, em.column_name, em.encryption_algorithm, em.key_id, " + + " em.created_at, em.updated_at, " + + " ks.master_key_arn, ks.encrypted_data_key, ks.hmac_key, ks.key_spec, " + + " ks.created_at as key_created_at, ks.last_used_at " + + "FROM " + schema + ".encryption_metadata em " + + "JOIN " + schema + ".key_storage ks ON em.key_id = ks.id " + + "WHERE em.table_name = ? AND em.column_name = ?"; + } + + /** + * Loads encryption metadata from database tables and returns a map of column configurations. + * + * @return Map of column identifiers to ColumnEncryptionConfig objects + * @throws MetadataException if database operations fail + */ + public Map loadEncryptionMetadata() throws MetadataException { + LOGGER.finest(()->"Loading encryption metadata from database"); + + Map metadata = new ConcurrentHashMap<>(); + + try (Connection connection = pluginService.forceConnect(pluginService.getCurrentHostSpec(), pluginService.getProperties()); + PreparedStatement stmt = connection.prepareStatement(getLoadEncryptionMetadataSql()); + ResultSet rs = stmt.executeQuery()) { + + while (rs.next()) { + ColumnEncryptionConfig columnConfig = buildColumnConfigFromResultSet(rs); + String columnIdentifier = columnConfig.getColumnIdentifier(); + metadata.put(columnIdentifier, columnConfig); + + LOGGER.finest(()->String.format("Loaded encryption config for column: %s", columnIdentifier)); + } + + LOGGER.info(()->String.format("Successfully loaded %d encryption configurations", metadata.size())); + + } catch (SQLException e) { + String errorMsg = "Failed to load encryption metadata from database"; + LOGGER.severe(()->errorMsg + e.getMessage()); + throw new MetadataException(errorMsg, e); + } + + return metadata; + } + + /** + * Refreshes the metadata cache by reloading from the database. + * This method is thread-safe and can be called without application restart. + * + * @throws MetadataException if refresh operation fails + */ + public void refreshMetadata() throws MetadataException { + LOGGER.info("Refreshing encryption metadata cache"); + + cacheLock.writeLock().lock(); + try { + Map newMetadata = loadEncryptionMetadata(); + + // Clear existing cache and populate with new data + metadataCache.clear(); + metadataCache.putAll(newMetadata); + lastRefreshTime = Instant.now(); + + LOGGER.info(()->String.format("Metadata cache refreshed successfully with %s configurations", + metadataCache.size())); + + } finally { + cacheLock.writeLock().unlock(); + } + } + + /** + * Checks if a specific column is configured for encryption. + * Uses cache if available and valid, otherwise queries database directly. + * + * @param tableName the table name + * @param columnName the column name + * @return true if column is encrypted, false otherwise + * @throws MetadataException if database operations fail + */ + public boolean isColumnEncrypted(String tableName, String columnName) throws MetadataException { + if (tableName == null || columnName == null) { + return false; + } + + String columnIdentifier = tableName + "." + columnName; + + // Try cache first if caching is enabled + if (config.isCacheEnabled() && isCacheValid()) { + cacheLock.readLock().lock(); + try { + boolean result = metadataCache.containsKey(columnIdentifier); + LOGGER.finest(()->String.format("Cache lookup for column %s: %s", columnIdentifier, result)); + return result; + } finally { + cacheLock.readLock().unlock(); + } + } + + // Fallback to database query + return isColumnEncryptedFromDatabase(tableName, columnName); + } + + /** + * Retrieves the encryption configuration for a specific column. + * Uses cache if available and valid, otherwise queries database directly. + * + * @param tableName the table name + * @param columnName the column name + * @return ColumnEncryptionConfig if found, null otherwise + * @throws MetadataException if database operations fail + */ + public ColumnEncryptionConfig getColumnConfig(String tableName, String columnName) + throws MetadataException { + if (tableName == null || columnName == null) { + return null; + } + + String columnIdentifier = tableName + "." + columnName; + + // Try cache first if caching is enabled + if (config.isCacheEnabled() && isCacheValid()) { + cacheLock.readLock().lock(); + try { + ColumnEncryptionConfig result = metadataCache.get(columnIdentifier); + LOGGER.finest(()->String.format("Cache lookup for column config %s: %s", + columnIdentifier, result != null ? "found" : "not found")); + return result; + } finally { + cacheLock.readLock().unlock(); + } + } + + // Fallback to database query + return getColumnConfigFromDatabase(tableName, columnName); + } + + /** + * Initializes the metadata cache by loading all configurations from database. + * Should be called during plugin initialization. + * + * @throws MetadataException if initialization fails + */ + public void initialize() throws MetadataException { + LOGGER.info("Initializing MetadataManager"); + + if (config.isCacheEnabled()) { + refreshMetadata(); + } + + // Start automatic refresh if configured + startAutomaticRefresh(); + + LOGGER.info("MetadataManager initialized successfully"); + } + + /** + * Updates the configuration and adjusts refresh behavior accordingly. + * + * @param newConfig New encryption configuration + */ + public void updateConfig(EncryptionConfig newConfig) { + EncryptionConfig oldConfig = this.config; + this.config = newConfig; + + // Restart automatic refresh if interval changed + if (!oldConfig.getMetadataRefreshInterval().equals(newConfig.getMetadataRefreshInterval())) { + stopAutomaticRefresh(); + startAutomaticRefresh(); + } + + LOGGER.info("MetadataManager configuration updated"); + } + + /** + * Shuts down the metadata manager and cleans up resources. + */ + public void shutdown() { + LOGGER.info("Shutting down MetadataManager"); + + stopAutomaticRefresh(); + + // Clear cache + cacheLock.writeLock().lock(); + try { + metadataCache.clear(); + } finally { + cacheLock.writeLock().unlock(); + } + + LOGGER.info("MetadataManager shutdown completed"); + } + + /** + * Returns the timestamp of the last cache refresh. + * + * @return Instant of last refresh, or Instant.EPOCH if never refreshed + */ + public Instant getLastRefreshTime() { + return lastRefreshTime; + } + + /** + * Returns the current size of the metadata cache. + * + * @return number of cached configurations + */ + public int getCacheSize() { + cacheLock.readLock().lock(); + try { + return metadataCache.size(); + } finally { + cacheLock.readLock().unlock(); + } + } + + /** + * Checks if the cache is valid based on expiration time. + * + * @return true if cache is valid, false if expired or never initialized + */ + private boolean isCacheValid() { + if (lastRefreshTime.equals(Instant.EPOCH)) { + return false; + } + + Instant expirationTime = lastRefreshTime.plusSeconds(config.getCacheExpirationMinutes() * 60L); + return Instant.now().isBefore(expirationTime); + } + + /** + * Queries database directly to check if column is encrypted. + */ + private boolean isColumnEncryptedFromDatabase(String tableName, String columnName) + throws MetadataException { + LOGGER.finest(()->String.format("Checking encryption status for column %s.%s from database", tableName, columnName)); + + try (Connection connection = pluginService.forceConnect(pluginService.getCurrentHostSpec(), pluginService.getProperties()); + PreparedStatement stmt = connection.prepareStatement(getCheckColumnEncryptedSql())) { + + stmt.setString(1, tableName); + stmt.setString(2, columnName); + + try (ResultSet rs = stmt.executeQuery()) { + boolean result = rs.next(); + LOGGER.finest(()->String.format("Database lookup for column %s.%s: %s", tableName, columnName, result)); + return result; + } + + } catch (SQLException e) { + String errorMsg = String.format("Failed to check encryption status for column %s.%s", + tableName, columnName); + LOGGER.severe(()->errorMsg + e); + throw new MetadataException(errorMsg, e); + } + } + + /** + * Queries database directly to get column configuration. + */ + private ColumnEncryptionConfig getColumnConfigFromDatabase(String tableName, String columnName) + throws MetadataException { + LOGGER.finest(()->String.format("Loading encryption config for column %s.%s from database", tableName, columnName)); + + try (Connection connection = pluginService.forceConnect(pluginService.getCurrentHostSpec(), pluginService.getProperties()); + PreparedStatement stmt = connection.prepareStatement(getColumnConfigSql())) { + + stmt.setString(1, tableName); + stmt.setString(2, columnName); + + try (ResultSet rs = stmt.executeQuery()) { + if (rs.next()) { + ColumnEncryptionConfig result = buildColumnConfigFromResultSet(rs); + LOGGER.finest(()->String.format("Database lookup for column config %s.%s: found", tableName, columnName)); + return result; + } else { + LOGGER.finest(()->String.format("Database lookup for column config %s.%s: not found", tableName, columnName)); + return null; + } + } + + } catch (SQLException e) { + String errorMsg = String.format("Failed to load encryption config for column %s.%s", + tableName, columnName); + LOGGER.severe(()->errorMsg + " " + e.getMessage()); + throw new MetadataException(errorMsg, e); + } + } + + /** + * Builds a ColumnEncryptionConfig from a ResultSet row. + */ + private ColumnEncryptionConfig buildColumnConfigFromResultSet(ResultSet rs) throws SQLException { + // Build KeyMetadata + KeyMetadata keyMetadata = KeyMetadata.builder() + .keyId(rs.getString("key_id")) + .keyName(rs.getString("name")) + .masterKeyArn(rs.getString("master_key_arn")) + .encryptedDataKey(rs.getString("encrypted_data_key")) + .hmacKey(rs.getBytes("hmac_key")) + .keySpec(rs.getString("key_spec")) + .createdAt(convertTimestampToInstant(rs.getTimestamp("key_created_at"))) + .lastUsedAt(convertTimestampToInstant(rs.getTimestamp("last_used_at"))) + .build(); + + // Build ColumnEncryptionConfig + return ColumnEncryptionConfig.builder() + .tableName(rs.getString("table_name")) + .columnName(rs.getString("column_name")) + .algorithm(rs.getString("encryption_algorithm")) + .keyId(rs.getString("key_id")) + .keyMetadata(keyMetadata) + .createdAt(convertTimestampToInstant(rs.getTimestamp("created_at"))) + .updatedAt(convertTimestampToInstant(rs.getTimestamp("updated_at"))) + .build(); + } + + /** + * Converts SQL Timestamp to Instant, handling null values. + */ + private Instant convertTimestampToInstant(Timestamp timestamp) { + return timestamp != null ? timestamp.toInstant() : Instant.now(); + } + + /** + * Creates a new refresh executor. + */ + private ScheduledExecutorService createRefreshExecutor() { + return Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "MetadataManager-Refresh"); + t.setDaemon(true); + return t; + }); + } + + /** + * Stops automatic metadata refresh. + */ + private void stopAutomaticRefresh() { + if (refreshExecutor != null && !refreshExecutor.isShutdown()) { + LOGGER.finest(()->String.format("Stopping automatic metadata refresh")); + refreshExecutor.shutdown(); + try { + if (!refreshExecutor.awaitTermination(2, TimeUnit.SECONDS)) { + refreshExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + refreshExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + + /** + * Starts automatic metadata refresh based on configuration. + */ + private void startAutomaticRefresh() { + Duration refreshInterval = config.getMetadataRefreshInterval(); + + if (refreshInterval.isZero() || refreshInterval.isNegative()) { + LOGGER.info(()->String.format("Automatic metadata refresh disabled (interval: %s)", refreshInterval)); + return; + } + + // Create new executor if current one is shut down + if (refreshExecutor == null || refreshExecutor.isShutdown()) { + refreshExecutor = createRefreshExecutor(); + } + + long intervalMs = refreshInterval.toMillis(); + refreshExecutor.scheduleAtFixedRate(() -> { + try { + LOGGER.finest(()->"Performing automatic metadata refresh"); + refreshMetadata(); + } catch (Exception e) { + LOGGER.warning(()->String.format("Automatic metadata refresh failed", e.getMessage())); + } + }, intervalMs, intervalMs, TimeUnit.MILLISECONDS); + + LOGGER.info(()->String.format("Started automatic metadata refresh every %sms", intervalMs)); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/ColumnEncryptionConfig.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/ColumnEncryptionConfig.java new file mode 100644 index 000000000..d9a656e26 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/ColumnEncryptionConfig.java @@ -0,0 +1,165 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.model; + +import java.time.Instant; +import java.util.Objects; + +/** + * Configuration class that represents encryption settings for a specific database column. + * Contains table/column mapping information and associated encryption metadata. + */ +public class ColumnEncryptionConfig { + + private final String tableName; + private final String columnName; + private final String algorithm; + private final String keyId; + private final KeyMetadata keyMetadata; + private final Instant createdAt; + private final Instant updatedAt; + + private ColumnEncryptionConfig(Builder builder) { + this.tableName = Objects.requireNonNull(builder.tableName, "tableName cannot be null"); + this.columnName = Objects.requireNonNull(builder.columnName, "columnName cannot be null"); + this.algorithm = Objects.requireNonNull(builder.algorithm, "algorithm cannot be null"); + this.keyId = Objects.requireNonNull(builder.keyId, "keyId cannot be null"); + this.keyMetadata = builder.keyMetadata; + this.createdAt = builder.createdAt != null ? builder.createdAt : Instant.now(); + this.updatedAt = builder.updatedAt != null ? builder.updatedAt : Instant.now(); + } + + public String getTableName() { + return tableName; + } + + public String getColumnName() { + return columnName; + } + + public String getAlgorithm() { + return algorithm; + } + + public String getKeyId() { + return keyId; + } + + public KeyMetadata getKeyMetadata() { + return keyMetadata; + } + + public Instant getCreatedAt() { + return createdAt; + } + + public Instant getUpdatedAt() { + return updatedAt; + } + + /** + * Returns a unique identifier for this column configuration. + * Format: "tableName.columnName" + * + * @return Column identifier string + */ + public String getColumnIdentifier() { + return tableName + "." + columnName; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ColumnEncryptionConfig that = (ColumnEncryptionConfig) o; + return Objects.equals(tableName, that.tableName) && + Objects.equals(columnName, that.columnName) && + Objects.equals(algorithm, that.algorithm) && + Objects.equals(keyId, that.keyId); + } + + @Override + public int hashCode() { + return Objects.hash(tableName, columnName, algorithm, keyId); + } + + @Override + public String toString() { + return "ColumnEncryptionConfig{" + + "tableName='" + tableName + '\'' + + ", columnName='" + columnName + '\'' + + ", algorithm='" + algorithm + '\'' + + ", keyId='" + keyId + '\'' + + ", createdAt=" + createdAt + + ", updatedAt=" + updatedAt + + '}'; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String tableName; + private String columnName; + private String algorithm = "AES-256-GCM"; // Default algorithm + private String keyId; + private KeyMetadata keyMetadata; + private Instant createdAt; + private Instant updatedAt; + + public Builder tableName(String tableName) { + this.tableName = tableName; + return this; + } + + public Builder columnName(String columnName) { + this.columnName = columnName; + return this; + } + + public Builder algorithm(String algorithm) { + this.algorithm = algorithm; + return this; + } + + public Builder keyId(String keyId) { + this.keyId = keyId; + return this; + } + + public Builder keyMetadata(KeyMetadata keyMetadata) { + this.keyMetadata = keyMetadata; + return this; + } + + public Builder createdAt(Instant createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder updatedAt(Instant updatedAt) { + this.updatedAt = updatedAt; + return this; + } + + public ColumnEncryptionConfig build() { + return new ColumnEncryptionConfig(this); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/ConnectionParameters.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/ConnectionParameters.java new file mode 100644 index 000000000..db0d7a4df --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/ConnectionParameters.java @@ -0,0 +1,288 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.model; + +import java.util.Objects; +import java.util.Properties; + +/** + * Immutable data class that holds connection parameters extracted from a database connection. + * These parameters can be used to create independent connections to the same database. + */ +public class ConnectionParameters { + private final String jdbcUrl; + private final String username; + private final String password; + private final Properties connectionProperties; + private final String driverClassName; + private final String catalog; + private final String schema; + + private ConnectionParameters(Builder builder) { + this.jdbcUrl = builder.jdbcUrl; + this.username = builder.username; + this.password = builder.password; + this.connectionProperties = new Properties(); + if (builder.connectionProperties != null) { + this.connectionProperties.putAll(builder.connectionProperties); + } + this.driverClassName = builder.driverClassName; + this.catalog = builder.catalog; + this.schema = builder.schema; + } + + /** + * Gets the JDBC URL for the database connection. + * + * @return the JDBC URL, never null + */ + public String getJdbcUrl() { + return jdbcUrl; + } + + /** + * Gets the username for database authentication. + * + * @return the username, may be null if using other authentication methods + */ + public String getUsername() { + return username; + } + + /** + * Gets the password for database authentication. + * + * @return the password, may be null if using other authentication methods + */ + public String getPassword() { + return password; + } + + /** + * Gets additional connection properties. + * + * @return a copy of the connection properties, never null + */ + public Properties getConnectionProperties() { + return new Properties(connectionProperties); + } + + /** + * Gets the JDBC driver class name. + * + * @return the driver class name, may be null if not specified + */ + public String getDriverClassName() { + return driverClassName; + } + + /** + * Gets the database catalog name. + * + * @return the catalog name, may be null + */ + public String getCatalog() { + return catalog; + } + + /** + * Gets the database schema name. + * + * @return the schema name, may be null + */ + public String getSchema() { + return schema; + } + + /** + * Checks if this connection uses username/password authentication. + * + * @return true if both username and password are present, false otherwise + */ + public boolean hasCredentials() { + return username != null && password != null; + } + + /** + * Creates a new Builder instance for constructing ConnectionParameters. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a new Builder instance initialized with values from this instance. + * + * @return a new Builder instance with copied values + */ + public Builder toBuilder() { + return new Builder() + .jdbcUrl(this.jdbcUrl) + .username(this.username) + .password(this.password) + .connectionProperties(this.connectionProperties) + .driverClassName(this.driverClassName) + .catalog(this.catalog) + .schema(this.schema); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConnectionParameters that = (ConnectionParameters) o; + return Objects.equals(jdbcUrl, that.jdbcUrl) && + Objects.equals(username, that.username) && + Objects.equals(password, that.password) && + Objects.equals(connectionProperties, that.connectionProperties) && + Objects.equals(driverClassName, that.driverClassName) && + Objects.equals(catalog, that.catalog) && + Objects.equals(schema, that.schema); + } + + @Override + public int hashCode() { + return Objects.hash(jdbcUrl, username, password, connectionProperties, + driverClassName, catalog, schema); + } + + @Override + public String toString() { + return "ConnectionParameters{" + + "jdbcUrl='" + jdbcUrl + '\'' + + ", username='" + username + '\'' + + ", password='[REDACTED]'" + + ", connectionProperties=" + connectionProperties + + ", driverClassName='" + driverClassName + '\'' + + ", catalog='" + catalog + '\'' + + ", schema='" + schema + '\'' + + '}'; + } + + /** + * Builder class for constructing ConnectionParameters instances. + */ + public static class Builder { + private String jdbcUrl; + private String username; + private String password; + private Properties connectionProperties; + private String driverClassName; + private String catalog; + private String schema; + + private Builder() { + } + + /** + * Sets the JDBC URL. + * + * @param jdbcUrl the JDBC URL, must not be null or empty + * @return this Builder instance for method chaining + * @throws IllegalArgumentException if jdbcUrl is null or empty + */ + public Builder jdbcUrl(String jdbcUrl) { + if (jdbcUrl == null || jdbcUrl.trim().isEmpty()) { + throw new IllegalArgumentException("JDBC URL cannot be null or empty"); + } + this.jdbcUrl = jdbcUrl.trim(); + return this; + } + + /** + * Sets the username for authentication. + * + * @param username the username, may be null + * @return this Builder instance for method chaining + */ + public Builder username(String username) { + this.username = username; + return this; + } + + /** + * Sets the password for authentication. + * + * @param password the password, may be null + * @return this Builder instance for method chaining + */ + public Builder password(String password) { + this.password = password; + return this; + } + + /** + * Sets the connection properties. + * + * @param connectionProperties the connection properties, may be null + * @return this Builder instance for method chaining + */ + public Builder connectionProperties(Properties connectionProperties) { + this.connectionProperties = connectionProperties; + return this; + } + + /** + * Sets the JDBC driver class name. + * + * @param driverClassName the driver class name, may be null + * @return this Builder instance for method chaining + */ + public Builder driverClassName(String driverClassName) { + this.driverClassName = driverClassName; + return this; + } + + /** + * Sets the database catalog name. + * + * @param catalog the catalog name, may be null + * @return this Builder instance for method chaining + */ + public Builder catalog(String catalog) { + this.catalog = catalog; + return this; + } + + /** + * Sets the database schema name. + * + * @param schema the schema name, may be null + * @return this Builder instance for method chaining + */ + public Builder schema(String schema) { + this.schema = schema; + return this; + } + + /** + * Builds a new ConnectionParameters instance. + * + * @return a new ConnectionParameters instance + * @throws IllegalStateException if required fields are not set + */ + public ConnectionParameters build() { + if (jdbcUrl == null || jdbcUrl.trim().isEmpty()) { + throw new IllegalStateException("JDBC URL is required"); + } + return new ConnectionParameters(this); + } + } +} \ No newline at end of file diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/EncryptionConfig.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/EncryptionConfig.java new file mode 100644 index 000000000..6500c7444 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/EncryptionConfig.java @@ -0,0 +1,388 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.model; + +import java.time.Duration; +import java.util.Objects; +import java.util.Properties; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.PropertyDefinition; + +/** + * Configuration class for the encryption plugin containing KMS settings, + * caching options, retry policies, and other operational parameters. + */ +public class EncryptionConfig { + + // Property definitions using AwsWrapperProperty + public static final AwsWrapperProperty KMS_REGION = new AwsWrapperProperty( + "kms.region", null, "AWS KMS region for encryption operations"); + + public static final AwsWrapperProperty KMS_MASTER_KEY_ARN = new AwsWrapperProperty( + "kms.MasterKeyArn", null, "Master key ARN for encryption"); + + public static final AwsWrapperProperty KEY_ROTATION_DAYS = new AwsWrapperProperty( + "key.rotationDays", "30", "Number of days for key rotation"); + + public static final AwsWrapperProperty METADATA_CACHE_ENABLED = new AwsWrapperProperty( + "metadataCache.enabled", "true", "Enable/disable metadata caching"); + + public static final AwsWrapperProperty METADATA_CACHE_EXPIRATION_MINUTES = new AwsWrapperProperty( + "metadataCache.expirationMinutes", "60", "Metadata cache expiration time in minutes"); + + public static final AwsWrapperProperty KEY_MANAGEMENT_MAX_RETRIES = new AwsWrapperProperty( + "keyManagement.maxRetries", "3", "Maximum number of retries for key management operations"); + + public static final AwsWrapperProperty KEY_MANAGEMENT_RETRY_BACKOFF_BASE_MS = new AwsWrapperProperty( + "keyManagement.retryBackoffBaseMs", "100", "Base backoff time in milliseconds for key management retries"); + + public static final AwsWrapperProperty AUDIT_LOGGING_ENABLED = new AwsWrapperProperty( + "audit.loggingEnabled", "false", "Enable/disable audit logging"); + + public static final AwsWrapperProperty KMS_CONNECTION_TIMEOUT_MS = new AwsWrapperProperty( + "kms.connectionTimeoutMs", "5000", "KMS connection timeout in milliseconds"); + + public static final AwsWrapperProperty DATA_KEY_CACHE_ENABLED = new AwsWrapperProperty( + "dataKeyCache.enabled", "true", "Enable/disable data key caching"); + + public static final AwsWrapperProperty DATA_KEY_CACHE_MAX_SIZE = new AwsWrapperProperty( + "dataKeyCache.maxSize", "1000", "Maximum size of data key cache"); + + public static final AwsWrapperProperty DATA_KEY_CACHE_EXPIRATION_MS = new AwsWrapperProperty( + "dataKeyCache.expirationMs", "3600000", "Data key cache expiration in milliseconds"); + + public static final AwsWrapperProperty METADATA_CACHE_REFRESH_INTERVAL_MS = new AwsWrapperProperty( + "metadataCache.refreshIntervalMs", "300000", "Metadata cache refresh interval in milliseconds"); + + public static final AwsWrapperProperty ENCRYPTION_METADATA_SCHEMA = new AwsWrapperProperty( + "encryption.metadataSchema", "aws", "Schema name for encryption metadata tables"); + + static { + PropertyDefinition.registerPluginProperties(EncryptionConfig.class); + } + + private final String kmsRegion; + private final String defaultMasterKeyArn; + private final int keyRotationDays; + private final boolean cacheEnabled; + private final int cacheExpirationMinutes; + private final int maxRetries; + private final Duration retryBackoffBase; + private final boolean auditLoggingEnabled; + private final Duration kmsConnectionTimeout; + private final boolean dataKeyCacheEnabled; + private final int dataKeyCacheMaxSize; + private final Duration dataKeyCacheExpiration; + private final Duration metadataRefreshInterval; + private final String encryptionMetadataSchema; + + private EncryptionConfig(Builder builder) { + this.kmsRegion = Objects.requireNonNull(builder.kmsRegion, "kmsRegion cannot be null"); + this.defaultMasterKeyArn = builder.defaultMasterKeyArn; + this.keyRotationDays = builder.keyRotationDays; + this.cacheEnabled = builder.cacheEnabled; + this.cacheExpirationMinutes = builder.cacheExpirationMinutes; + this.maxRetries = builder.maxRetries; + this.retryBackoffBase = builder.retryBackoffBase; + this.auditLoggingEnabled = builder.auditLoggingEnabled; + this.kmsConnectionTimeout = builder.kmsConnectionTimeout; + this.dataKeyCacheEnabled = builder.dataKeyCacheEnabled; + this.dataKeyCacheMaxSize = builder.dataKeyCacheMaxSize; + this.dataKeyCacheExpiration = builder.dataKeyCacheExpiration; + this.metadataRefreshInterval = builder.metadataRefreshInterval; + this.encryptionMetadataSchema = Objects.requireNonNull(builder.encryptionMetadataSchema, "encryptionMetadataSchema cannot be null"); + } + + public String getKmsRegion() { + return kmsRegion; + } + + public String getDefaultMasterKeyArn() { + return defaultMasterKeyArn; + } + + public int getKeyRotationDays() { + return keyRotationDays; + } + + public boolean isCacheEnabled() { + return cacheEnabled; + } + + public int getCacheExpirationMinutes() { + return cacheExpirationMinutes; + } + + public int getMaxRetries() { + return maxRetries; + } + + public Duration getRetryBackoffBase() { + return retryBackoffBase; + } + + public boolean isAuditLoggingEnabled() { + return auditLoggingEnabled; + } + + public Duration getKmsConnectionTimeout() { + return kmsConnectionTimeout; + } + + public boolean isDataKeyCacheEnabled() { + return dataKeyCacheEnabled; + } + + public int getDataKeyCacheMaxSize() { + return dataKeyCacheMaxSize; + } + + public Duration getDataKeyCacheExpiration() { + return dataKeyCacheExpiration; + } + + public Duration getMetadataRefreshInterval() { + return metadataRefreshInterval; + } + + public String getEncryptionMetadataSchema() { + return encryptionMetadataSchema; + } + + /** + * Validates the configuration settings. + * + * @throws IllegalArgumentException if configuration is invalid + */ + public void validate() { + if (kmsRegion == null || kmsRegion.trim().isEmpty()) { + throw new IllegalArgumentException("KMS region cannot be null or empty"); + } + + if (keyRotationDays < 0) { + throw new IllegalArgumentException("Key rotation days cannot be negative"); + } + + if (cacheExpirationMinutes < 0) { + throw new IllegalArgumentException("Cache expiration minutes cannot be negative"); + } + + if (maxRetries < 0) { + throw new IllegalArgumentException("Max retries cannot be negative"); + } + + if (retryBackoffBase.isNegative()) { + throw new IllegalArgumentException("Retry backoff base cannot be negative"); + } + + if (kmsConnectionTimeout.isNegative()) { + throw new IllegalArgumentException("KMS connection timeout cannot be negative"); + } + + if (dataKeyCacheMaxSize <= 0) { + throw new IllegalArgumentException("Data key cache max size must be positive"); + } + + if (dataKeyCacheExpiration.isNegative()) { + throw new IllegalArgumentException("Data key cache expiration cannot be negative"); + } + + if (metadataRefreshInterval.isNegative()) { + throw new IllegalArgumentException("Metrics reporting interval cannot be negative"); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + EncryptionConfig that = (EncryptionConfig) o; + return keyRotationDays == that.keyRotationDays && + cacheEnabled == that.cacheEnabled && + cacheExpirationMinutes == that.cacheExpirationMinutes && + maxRetries == that.maxRetries && + auditLoggingEnabled == that.auditLoggingEnabled && + dataKeyCacheEnabled == that.dataKeyCacheEnabled && + dataKeyCacheMaxSize == that.dataKeyCacheMaxSize && + Objects.equals(kmsRegion, that.kmsRegion) && + Objects.equals(defaultMasterKeyArn, that.defaultMasterKeyArn) && + Objects.equals(retryBackoffBase, that.retryBackoffBase) && + Objects.equals(kmsConnectionTimeout, that.kmsConnectionTimeout) && + Objects.equals(dataKeyCacheExpiration, that.dataKeyCacheExpiration) && + Objects.equals(metadataRefreshInterval, that.metadataRefreshInterval); + } + + @Override + public int hashCode() { + return Objects.hash(kmsRegion, defaultMasterKeyArn, keyRotationDays, cacheEnabled, + cacheExpirationMinutes, maxRetries, retryBackoffBase, auditLoggingEnabled, + kmsConnectionTimeout, dataKeyCacheEnabled, dataKeyCacheMaxSize, + dataKeyCacheExpiration, metadataRefreshInterval); + } + + @Override + public String toString() { + return "EncryptionConfig{" + + "kmsRegion='" + kmsRegion + '\'' + + ", defaultMasterKeyArn='" + defaultMasterKeyArn + '\'' + + ", keyRotationDays=" + keyRotationDays + + ", cacheEnabled=" + cacheEnabled + + ", cacheExpirationMinutes=" + cacheExpirationMinutes + + ", maxRetries=" + maxRetries + + ", retryBackoffBase=" + retryBackoffBase + + ", auditLoggingEnabled=" + auditLoggingEnabled + + ", kmsConnectionTimeout=" + kmsConnectionTimeout + + ", dataKeyCacheEnabled=" + dataKeyCacheEnabled + + ", dataKeyCacheMaxSize=" + dataKeyCacheMaxSize + + ", dataKeyCacheExpiration=" + dataKeyCacheExpiration + + ", metadataRefreshInterval=" + metadataRefreshInterval + + '}'; + } + + /** + * Creates an EncryptionConfig from Properties. + * + * @param properties Properties containing configuration values + * @return EncryptionConfig instance + */ + public static EncryptionConfig fromProperties(Properties properties) { + Builder builder = builder(); + + String region = KMS_REGION.getString(properties); + if (region != null) { + builder.kmsRegion(region); + } + + String masterKeyArn = KMS_MASTER_KEY_ARN.getString(properties); + if (masterKeyArn != null) { + builder.defaultMasterKeyArn(masterKeyArn); + } + + builder.keyRotationDays(KEY_ROTATION_DAYS.getInteger(properties)); + builder.cacheEnabled(METADATA_CACHE_ENABLED.getBoolean(properties)); + builder.cacheExpirationMinutes(METADATA_CACHE_EXPIRATION_MINUTES.getInteger(properties)); + builder.maxRetries(KEY_MANAGEMENT_MAX_RETRIES.getInteger(properties)); + builder.retryBackoffBase(Duration.ofMillis(KEY_MANAGEMENT_RETRY_BACKOFF_BASE_MS.getLong(properties))); + builder.auditLoggingEnabled(AUDIT_LOGGING_ENABLED.getBoolean(properties)); + builder.kmsConnectionTimeout(Duration.ofMillis(KMS_CONNECTION_TIMEOUT_MS.getLong(properties))); + builder.dataKeyCacheEnabled(DATA_KEY_CACHE_ENABLED.getBoolean(properties)); + builder.dataKeyCacheMaxSize(DATA_KEY_CACHE_MAX_SIZE.getInteger(properties)); + builder.dataKeyCacheExpiration(Duration.ofMillis(DATA_KEY_CACHE_EXPIRATION_MS.getLong(properties))); + builder.metadataRefreshInterval(Duration.ofMillis(METADATA_CACHE_REFRESH_INTERVAL_MS.getLong(properties))); + builder.encryptionMetadataSchema(ENCRYPTION_METADATA_SCHEMA.getString(properties)); + + return builder.build(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String kmsRegion; + private String defaultMasterKeyArn; + private int keyRotationDays = 90; // Default 90 days + private boolean cacheEnabled = true; + private int cacheExpirationMinutes = 60; // Default 1 hour + private int maxRetries = 5; + private Duration retryBackoffBase = Duration.ofMillis(100); + private boolean auditLoggingEnabled = false; + private Duration kmsConnectionTimeout = Duration.ofSeconds(30); + private boolean dataKeyCacheEnabled = true; + private int dataKeyCacheMaxSize = 1000; + private Duration dataKeyCacheExpiration = Duration.ofMinutes(30); + private Duration metadataRefreshInterval = Duration.ofMinutes(5); + private String encryptionMetadataSchema = "encrypt"; // Default schema name + + public Builder kmsRegion(String kmsRegion) { + this.kmsRegion = kmsRegion; + return this; + } + + public Builder defaultMasterKeyArn(String defaultMasterKeyArn) { + this.defaultMasterKeyArn = defaultMasterKeyArn; + return this; + } + + public Builder keyRotationDays(int keyRotationDays) { + this.keyRotationDays = keyRotationDays; + return this; + } + + public Builder cacheEnabled(boolean cacheEnabled) { + this.cacheEnabled = cacheEnabled; + return this; + } + + public Builder cacheExpirationMinutes(int cacheExpirationMinutes) { + this.cacheExpirationMinutes = cacheExpirationMinutes; + return this; + } + + public Builder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder retryBackoffBase(Duration retryBackoffBase) { + this.retryBackoffBase = retryBackoffBase; + return this; + } + + public Builder auditLoggingEnabled(boolean auditLoggingEnabled) { + this.auditLoggingEnabled = auditLoggingEnabled; + return this; + } + + public Builder kmsConnectionTimeout(Duration kmsConnectionTimeout) { + this.kmsConnectionTimeout = kmsConnectionTimeout; + return this; + } + + public Builder dataKeyCacheEnabled(boolean dataKeyCacheEnabled) { + this.dataKeyCacheEnabled = dataKeyCacheEnabled; + return this; + } + + public Builder dataKeyCacheMaxSize(int dataKeyCacheMaxSize) { + this.dataKeyCacheMaxSize = dataKeyCacheMaxSize; + return this; + } + + public Builder dataKeyCacheExpiration(Duration dataKeyCacheExpiration) { + this.dataKeyCacheExpiration = dataKeyCacheExpiration; + return this; + } + + public Builder metadataRefreshInterval(Duration metadataRefreshInterval) { + this.metadataRefreshInterval = metadataRefreshInterval; + return this; + } + + public Builder encryptionMetadataSchema(String encryptionMetadataSchema) { + this.encryptionMetadataSchema = encryptionMetadataSchema; + return this; + } + + public EncryptionConfig build() { + return new EncryptionConfig(this); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/KeyMetadata.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/KeyMetadata.java new file mode 100644 index 000000000..3f56c358b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/model/KeyMetadata.java @@ -0,0 +1,195 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.model; + +import java.time.Instant; +import java.util.Objects; + +/** + * Metadata class for storing KMS key information including master key ARN, + * encrypted data key, and usage tracking information. + */ +public class KeyMetadata { + + private final String keyId; + private final String keyName; + private final String masterKeyArn; + private final String encryptedDataKey; + private final byte[] hmacKey; + private final String keySpec; + private final Instant createdAt; + private final Instant lastUsedAt; + + private KeyMetadata(Builder builder) { + this.keyId = Objects.requireNonNull(builder.keyId, "keyId cannot be null"); + this.keyName = Objects.requireNonNull(builder.keyName, "keyName cannot be null"); + this.masterKeyArn = Objects.requireNonNull(builder.masterKeyArn, "masterKeyArn cannot be null"); + this.encryptedDataKey = Objects.requireNonNull(builder.encryptedDataKey, "encryptedDataKey cannot be null"); + this.hmacKey = builder.hmacKey; + this.keySpec = Objects.requireNonNull(builder.keySpec, "keySpec cannot be null"); + this.createdAt = builder.createdAt != null ? builder.createdAt : Instant.now(); + this.lastUsedAt = builder.lastUsedAt != null ? builder.lastUsedAt : Instant.now(); + } + + public String getKeyId() { + return keyId; + } + + public String getKeyName() { + return keyName; + } + + public String getMasterKeyArn() { + return masterKeyArn; + } + + public String getEncryptedDataKey() { + return encryptedDataKey; + } + + public byte[] getHmacKey() { + return hmacKey; + } + + public String getKeySpec() { + return keySpec; + } + + public Instant getCreatedAt() { + return createdAt; + } + + public Instant getLastUsedAt() { + return lastUsedAt; + } + + /** + * Creates a new KeyMetadata instance with updated lastUsedAt timestamp. + * + * @return New KeyMetadata with current timestamp + */ + public KeyMetadata withUpdatedLastUsed() { + return builder() + .keyId(this.keyId) + .masterKeyArn(this.masterKeyArn) + .encryptedDataKey(this.encryptedDataKey) + .keySpec(this.keySpec) + .createdAt(this.createdAt) + .lastUsedAt(Instant.now()) + .build(); + } + + /** + * Checks if the key metadata is valid for encryption operations. + * + * @return True if metadata is valid, false otherwise + */ + public boolean isValid() { + return keyId != null && !keyId.trim().isEmpty() && + masterKeyArn != null && !masterKeyArn.trim().isEmpty() && + encryptedDataKey != null && !encryptedDataKey.trim().isEmpty() && + keySpec != null && !keySpec.trim().isEmpty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KeyMetadata that = (KeyMetadata) o; + return Objects.equals(keyId, that.keyId) && + Objects.equals(masterKeyArn, that.masterKeyArn) && + Objects.equals(encryptedDataKey, that.encryptedDataKey) && + Objects.equals(keySpec, that.keySpec); + } + + @Override + public int hashCode() { + return Objects.hash(keyId, masterKeyArn, encryptedDataKey, keySpec); + } + + @Override + public String toString() { + return "KeyMetadata{" + + "keyId='" + keyId + '\'' + + ", masterKeyArn='" + masterKeyArn + '\'' + + ", keySpec='" + keySpec + '\'' + + ", createdAt=" + createdAt + + ", lastUsedAt=" + lastUsedAt + + ", encryptedDataKey='[REDACTED]'" + // Don't expose encrypted key in logs + '}'; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String keyId; + private String keyName; + private String masterKeyArn; + private String encryptedDataKey; + private byte[] hmacKey; + private String keySpec = "AES_256"; // Default key spec + private Instant createdAt; + private Instant lastUsedAt; + + public Builder keyId(String keyId) { + this.keyId = keyId; + return this; + } + + public Builder keyName(String keyName) { + this.keyName = keyName; + return this; + } + + public Builder masterKeyArn(String masterKeyArn) { + this.masterKeyArn = masterKeyArn; + return this; + } + + public Builder encryptedDataKey(String encryptedDataKey) { + this.encryptedDataKey = encryptedDataKey; + return this; + } + + public Builder hmacKey(byte[] hmacKey) { + this.hmacKey = hmacKey; + return this; + } + + public Builder keySpec(String keySpec) { + this.keySpec = keySpec; + return this; + } + + public Builder createdAt(Instant createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder lastUsedAt(Instant lastUsedAt) { + this.lastUsedAt = lastUsedAt; + return this; + } + + public KeyMetadata build() { + return new KeyMetadata(this); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/BENCHMARK_RESULTS.md b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/BENCHMARK_RESULTS.md new file mode 100644 index 000000000..6918bde0d --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/BENCHMARK_RESULTS.md @@ -0,0 +1,59 @@ +# PostgreSQL Java SQL Parser - Performance Benchmarks + +## Benchmark Results + +JMH (Java Microbenchmark Harness) performance results for the PostgreSQL Java SQL Parser: + +| Benchmark | Average Time (μs) | Operations/sec | Description | +|-----------|------------------|----------------|-------------| +| parseSimpleSelect | 0.180 ± 0.001 | ~5.6M | `SELECT * FROM users` | +| parseDelete | 0.371 ± 0.024 | ~2.7M | `DELETE FROM users WHERE age < 18` | +| parseSelectWithWhere | 0.555 ± 0.040 | ~1.8M | `SELECT id, name FROM users WHERE age > 25` | +| parseSelectWithOrderBy | 0.576 ± 0.058 | ~1.7M | `SELECT * FROM products ORDER BY price DESC` | +| parseScientificNotation | 0.585 ± 0.052 | ~1.7M | `INSERT INTO measurements VALUES (42, 3.14159, 2.5e10)` | +| parseInsertWithPlaceholders | 0.625 ± 0.016 | ~1.6M | `INSERT INTO users (name, age, email) VALUES (?, ?, ?)` | +| parseUpdateWithPlaceholders | 0.696 ± 0.020 | ~1.4M | `UPDATE users SET name = ?, age = ? WHERE id = ?` | +| parseUpdate | 0.746 ± 0.536 | ~1.3M | `UPDATE users SET name = 'Jane', age = 25 WHERE id = 1` | +| parseInsert | 0.922 ± 0.037 | ~1.1M | `INSERT INTO users (name, age, email) VALUES ('John', 30, 'john@example.com')` | +| parseCreateTable | 1.231 ± 0.145 | ~810K | `CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR NOT NULL, price DECIMAL)` | +| parseComplexExpression | 1.366 ± 0.180 | ~730K | Complex WHERE with AND/OR conditions | +| parseComplexSelect | 1.808 ± 0.275 | ~550K | Multi-table SELECT with JOIN conditions | + +## Performance Analysis + +### Key Findings: + +1. **Excellent Performance**: The parser achieves sub-microsecond parsing for simple statements +2. **Scalability**: Performance scales reasonably with query complexity +3. **JDBC Placeholders**: Placeholder parsing is actually faster than literal parsing (fewer tokens to process) +4. **Consistent Results**: Low error margins indicate stable performance + +### Performance Characteristics: + +- **Simple SELECT**: ~180 nanoseconds (5.6M ops/sec) +- **Complex queries**: 1-2 microseconds (500K-1M ops/sec) +- **Memory efficient**: No significant GC pressure during benchmarks + +### Use Case Performance: + +- **High-frequency JDBC operations**: Excellent (sub-microsecond) +- **Query analysis tools**: Very good (1-2 microseconds for complex queries) +- **Real-time SQL processing**: Suitable for high-throughput applications + +## Test Environment + +- **JVM**: OpenJDK 21.0.7 64-Bit Server VM +- **JMH Version**: 1.37 +- **Benchmark Mode**: Average time per operation +- **Warmup**: 3 iterations, 1 second each +- **Measurement**: 5 iterations, 1 second each +- **Threads**: Single-threaded + +## Comparison Context + +For reference, typical database operations: +- Network round-trip to database: ~1-10ms (1,000-10,000μs) +- Simple database query execution: ~100μs-1ms +- **This parser**: 0.18-1.8μs + +The parser overhead is negligible compared to actual database operations, making it suitable for production use in JDBC drivers, query analyzers, and SQL processing tools. diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParser.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParser.java new file mode 100644 index 000000000..ece069e72 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParser.java @@ -0,0 +1,272 @@ +package software.amazon.jdbc.plugin.encryption.parser; + +import software.amazon.jdbc.plugin.encryption.parser.ast.*; +import java.util.List; + +/** + * Main PostgreSQL SQL Parser + * Combines lexer and parser to parse SQL statements + */ +public class PostgreSqlParser { + + /** + * Parse a SQL string and return the AST + */ + public Statement parse(String sql) { + // Tokenize the input + SqlLexer lexer = new SqlLexer(sql); + List tokens = lexer.tokenize(); + + // Parse the tokens + SqlParser parser = new SqlParser(tokens); + return parser.parse(); + } + + /** + * Parse and pretty print the AST + */ + public String parseAndFormat(String sql) { + Statement stmt = parse(sql); + return formatStatement(stmt); + } + + private String formatStatement(Statement stmt) { + if (stmt instanceof SelectStatement) { + return formatSelectStatement((SelectStatement) stmt); + } else if (stmt instanceof InsertStatement) { + return formatInsertStatement((InsertStatement) stmt); + } else if (stmt instanceof UpdateStatement) { + return formatUpdateStatement((UpdateStatement) stmt); + } else if (stmt instanceof DeleteStatement) { + return formatDeleteStatement((DeleteStatement) stmt); + } else if (stmt instanceof CreateTableStatement) { + return formatCreateTableStatement((CreateTableStatement) stmt); + } + return stmt.toString(); + } + + private String formatSelectStatement(SelectStatement stmt) { + StringBuilder sb = new StringBuilder(); + sb.append("SELECT "); + + // Format select list + for (int i = 0; i < stmt.getSelectList().size(); i++) { + if (i > 0) sb.append(", "); + SelectItem item = stmt.getSelectList().get(i); + sb.append(formatExpression(item.getExpression())); + if (item.getAlias() != null) { + sb.append(" AS ").append(item.getAlias()); + } + } + + // Format FROM clause + if (stmt.getFromList() != null && !stmt.getFromList().isEmpty()) { + sb.append("\nFROM "); + for (int i = 0; i < stmt.getFromList().size(); i++) { + if (i > 0) sb.append(", "); + TableReference table = stmt.getFromList().get(i); + sb.append(table.getTableName().getName()); + if (table.getAlias() != null) { + sb.append(" AS ").append(table.getAlias()); + } + } + } + + // Format WHERE clause + if (stmt.getWhereClause() != null) { + sb.append("\nWHERE ").append(formatExpression(stmt.getWhereClause())); + } + + // Format GROUP BY clause + if (stmt.getGroupByList() != null && !stmt.getGroupByList().isEmpty()) { + sb.append("\nGROUP BY "); + for (int i = 0; i < stmt.getGroupByList().size(); i++) { + if (i > 0) sb.append(", "); + sb.append(formatExpression(stmt.getGroupByList().get(i))); + } + } + + // Format HAVING clause + if (stmt.getHavingClause() != null) { + sb.append("\nHAVING ").append(formatExpression(stmt.getHavingClause())); + } + + // Format ORDER BY clause + if (stmt.getOrderByList() != null && !stmt.getOrderByList().isEmpty()) { + sb.append("\nORDER BY "); + for (int i = 0; i < stmt.getOrderByList().size(); i++) { + if (i > 0) sb.append(", "); + OrderByItem item = stmt.getOrderByList().get(i); + sb.append(formatExpression(item.getExpression())); + sb.append(" ").append(item.getDirection()); + } + } + + // Format LIMIT clause + if (stmt.getLimit() != null) { + sb.append("\nLIMIT ").append(stmt.getLimit()); + } + + return sb.toString(); + } + + private String formatInsertStatement(InsertStatement stmt) { + StringBuilder sb = new StringBuilder(); + sb.append("INSERT INTO ").append(stmt.getTable().getTableName().getName()); + + if (stmt.getColumns() != null && !stmt.getColumns().isEmpty()) { + sb.append(" ("); + for (int i = 0; i < stmt.getColumns().size(); i++) { + if (i > 0) sb.append(", "); + sb.append(stmt.getColumns().get(i).getName()); + } + sb.append(")"); + } + + sb.append("\nVALUES "); + for (int i = 0; i < stmt.getValues().size(); i++) { + if (i > 0) sb.append(", "); + sb.append("("); + List values = stmt.getValues().get(i); + for (int j = 0; j < values.size(); j++) { + if (j > 0) sb.append(", "); + sb.append(formatExpression(values.get(j))); + } + sb.append(")"); + } + + return sb.toString(); + } + + private String formatUpdateStatement(UpdateStatement stmt) { + StringBuilder sb = new StringBuilder(); + sb.append("UPDATE ").append(stmt.getTable().getTableName().getName()); + sb.append("\nSET "); + + for (int i = 0; i < stmt.getAssignments().size(); i++) { + if (i > 0) sb.append(", "); + Assignment assignment = stmt.getAssignments().get(i); + sb.append(assignment.getColumn().getName()); + sb.append(" = "); + sb.append(formatExpression(assignment.getValue())); + } + + if (stmt.getWhereClause() != null) { + sb.append("\nWHERE ").append(formatExpression(stmt.getWhereClause())); + } + + return sb.toString(); + } + + private String formatDeleteStatement(DeleteStatement stmt) { + StringBuilder sb = new StringBuilder(); + sb.append("DELETE FROM ").append(stmt.getTable().getTableName().getName()); + + if (stmt.getWhereClause() != null) { + sb.append("\nWHERE ").append(formatExpression(stmt.getWhereClause())); + } + + return sb.toString(); + } + + private String formatCreateTableStatement(CreateTableStatement stmt) { + StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE ").append(stmt.getTableName().getName()).append(" (\n"); + + for (int i = 0; i < stmt.getColumns().size(); i++) { + if (i > 0) sb.append(",\n"); + ColumnDefinition col = stmt.getColumns().get(i); + sb.append(" ").append(col.getColumnName().getName()); + sb.append(" ").append(col.getDataType()); + + if (col.isNotNull()) { + sb.append(" NOT NULL"); + } + if (col.isPrimaryKey()) { + sb.append(" PRIMARY KEY"); + } + } + + sb.append("\n)"); + return sb.toString(); + } + + private String formatExpression(Expression expr) { + if (expr instanceof Identifier) { + return ((Identifier) expr).getName(); + } else if (expr instanceof StringLiteral) { + return "'" + ((StringLiteral) expr).getValue() + "'"; + } else if (expr instanceof NumericLiteral) { + return ((NumericLiteral) expr).getValue(); + } else if (expr instanceof BinaryExpression) { + BinaryExpression binExpr = (BinaryExpression) expr; + return formatExpression(binExpr.getLeft()) + + " " + formatOperator(binExpr.getOperator()) + + " " + formatExpression(binExpr.getRight()); + } + return expr.toString(); + } + + private String formatOperator(BinaryExpression.Operator op) { + switch (op) { + case EQUALS: return "="; + case NOT_EQUALS: return "<>"; + case LESS_THAN: return "<"; + case GREATER_THAN: return ">"; + case LESS_EQUALS: return "<="; + case GREATER_EQUALS: return ">="; + case PLUS: return "+"; + case MINUS: return "-"; + case MULTIPLY: return "*"; + case DIVIDE: return "/"; + case MODULO: return "%"; + case AND: return "AND"; + case OR: return "OR"; + case LIKE: return "LIKE"; + case IN: return "IN"; + case BETWEEN: return "BETWEEN"; + default: return op.toString(); + } + } + + /** + * Main method for testing + */ + public static void main(String[] args) { + PostgreSqlParser parser = new PostgreSqlParser(); + + // Test SELECT statement + String selectSql = "SELECT id, name, age FROM users WHERE age > 18 ORDER BY name"; + System.out.println("Original SQL: " + selectSql); + System.out.println("Parsed AST:"); + System.out.println(parser.parseAndFormat(selectSql)); + System.out.println(); + + // Test INSERT statement + String insertSql = "INSERT INTO users (name, age) VALUES ('John', 25), ('Jane', 30)"; + System.out.println("Original SQL: " + insertSql); + System.out.println("Parsed AST:"); + System.out.println(parser.parseAndFormat(insertSql)); + System.out.println(); + + // Test UPDATE statement + String updateSql = "UPDATE users SET age = 26 WHERE name = 'John'"; + System.out.println("Original SQL: " + updateSql); + System.out.println("Parsed AST:"); + System.out.println(parser.parseAndFormat(updateSql)); + System.out.println(); + + // Test DELETE statement + String deleteSql = "DELETE FROM users WHERE age < 18"; + System.out.println("Original SQL: " + deleteSql); + System.out.println("Parsed AST:"); + System.out.println(parser.parseAndFormat(deleteSql)); + System.out.println(); + + // Test CREATE TABLE statement + String createSql = "CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR NOT NULL, age INTEGER)"; + System.out.println("Original SQL: " + createSql); + System.out.println("Parsed AST:"); + System.out.println(parser.parseAndFormat(createSql)); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SQLAnalyzer.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SQLAnalyzer.java new file mode 100644 index 000000000..d0ec91d11 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SQLAnalyzer.java @@ -0,0 +1,285 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.encryption.parser; + +import software.amazon.jdbc.plugin.encryption.parser.ast.*; + +import java.util.*; + +public class SQLAnalyzer { + + private final PostgreSqlParser parser = new PostgreSqlParser(); + + public static class ColumnInfo { + public String tableName; + public String columnName; + + public ColumnInfo(String tableName, String columnName) { + this.tableName = tableName; + this.columnName = columnName; + } + + @Override + public String toString() { + return tableName + "." + columnName; + } + } + + public static class QueryAnalysis { + public String queryType; + public List columns = new ArrayList<>(); + public List whereColumns = new ArrayList<>(); // Separate WHERE clause columns + public Set tables = new HashSet<>(); + public boolean hasParameters = false; + + @Override + public String toString() { + return String.format("QueryAnalysis{queryType='%s', tables=%s, columns=%s, whereColumns=%s, hasParameters=%s}", + queryType, tables, columns, whereColumns, hasParameters); + } + } + + private boolean containsParameters(Expression expression) { + if (expression == null) return false; + + if (expression instanceof Placeholder) { + return true; + } else if (expression instanceof BinaryExpression) { + BinaryExpression binaryExpr = (BinaryExpression) expression; + return containsParameters(binaryExpr.getLeft()) || containsParameters(binaryExpr.getRight()); + } + return false; + } + + private boolean statementHasParameters(Statement statement) { + if (statement instanceof SelectStatement) { + SelectStatement select = (SelectStatement) statement; + return select.getWhereClause() != null && containsParameters(select.getWhereClause()); + } else if (statement instanceof InsertStatement) { + return true; // INSERT with VALUES typically has parameters + } else if (statement instanceof UpdateStatement) { + UpdateStatement update = (UpdateStatement) statement; + return update.getWhereClause() != null && containsParameters(update.getWhereClause()); + } else if (statement instanceof DeleteStatement) { + DeleteStatement delete = (DeleteStatement) statement; + return delete.getWhereClause() != null && containsParameters(delete.getWhereClause()); + } + return false; + } + + public QueryAnalysis analyze(String sql) { + QueryAnalysis analysis = new QueryAnalysis(); + + try { + Statement statement = parser.parse(sql); + analysis.hasParameters = statementHasParameters(statement); + + if (statement instanceof SelectStatement) { + analysis.queryType = "SELECT"; + extractFromSelect((SelectStatement) statement, analysis); + } else if (statement instanceof InsertStatement) { + analysis.queryType = "INSERT"; + extractFromInsert((InsertStatement) statement, analysis); + } else if (statement instanceof UpdateStatement) { + analysis.queryType = "UPDATE"; + extractFromUpdate((UpdateStatement) statement, analysis); + } else if (statement instanceof DeleteStatement) { + analysis.queryType = "DELETE"; + extractFromDelete((DeleteStatement) statement, analysis); + } else if (statement instanceof CreateTableStatement) { + analysis.queryType = "CREATE"; + extractFromCreateTable((CreateTableStatement) statement, analysis); + } else { + analysis.queryType = "UNKNOWN"; + } + + } catch (SqlParser.ParseException e) { + // Fallback to string parsing if parser fails + String trimmedSql = sql.trim().toUpperCase(); + if (trimmedSql.startsWith("SELECT")) { + analysis.queryType = "SELECT"; + } else if (trimmedSql.startsWith("INSERT")) { + analysis.queryType = "INSERT"; + } else if (trimmedSql.startsWith("UPDATE")) { + analysis.queryType = "UPDATE"; + } else if (trimmedSql.startsWith("DELETE")) { + analysis.queryType = "DELETE"; + } else if (trimmedSql.startsWith("CREATE")) { + analysis.queryType = "CREATE"; + } else if (trimmedSql.startsWith("DROP")) { + analysis.queryType = "DROP"; + } else { + analysis.queryType = "UNKNOWN"; + } + } + + return analysis; + } + + private String extractTableName(String fullName) { + if (fullName.contains(".")) { + return fullName.substring(fullName.lastIndexOf(".") + 1); + } + return fullName; + } + + private void extractFromSelect(SelectStatement select, QueryAnalysis analysis) { + // Extract tables and build alias map + Map aliasToTable = new HashMap<>(); + if (select.getFromList() != null) { + for (TableReference table : select.getFromList()) { + String tableName = extractTableName(table.getTableName().getName()); + analysis.tables.add(tableName); + + // Map alias to table name + if (table.getAlias() != null) { + aliasToTable.put(table.getAlias(), tableName); + } + } + } + + // Extract columns from SELECT clause (skip * and literals) + for (SelectItem selectItem : select.getSelectList()) { + if (selectItem.getExpression() instanceof Identifier) { + Identifier column = (Identifier) selectItem.getExpression(); + // Skip * wildcard + if (!"*".equals(column.getName())) { + String fullName = column.getName(); + String tableName; + String columnName; + + // Parse qualified column name (e.g., "u.name" or "name") + if (fullName.contains(".")) { + String[] parts = fullName.split("\\.", 2); + String tableOrAlias = parts[0]; + columnName = parts[1]; + // Resolve alias to actual table name + tableName = aliasToTable.getOrDefault(tableOrAlias, tableOrAlias); + } else { + tableName = analysis.tables.isEmpty() ? "unknown" : analysis.tables.iterator().next(); + columnName = fullName; + } + + analysis.columns.add(new ColumnInfo(tableName, columnName)); + } + } + } + + // Extract columns from WHERE clause only if WHERE contains parameters + if (select.getWhereClause() != null && containsParameters(select.getWhereClause())) { + extractWhereColumnsFromExpression(select.getWhereClause(), analysis, aliasToTable); + } + } + + private void extractColumnsFromExpression(Expression expression, QueryAnalysis analysis) { + if (expression instanceof Identifier) { + Identifier column = (Identifier) expression; + String tableName = analysis.tables.isEmpty() ? "unknown" : analysis.tables.iterator().next(); + analysis.columns.add(new ColumnInfo(tableName, column.getName())); + } else if (expression instanceof BinaryExpression) { + BinaryExpression binaryExpr = (BinaryExpression) expression; + extractColumnsFromExpression(binaryExpr.getLeft(), analysis); + extractColumnsFromExpression(binaryExpr.getRight(), analysis); + } else if (expression instanceof SubqueryExpression) { + SubqueryExpression subquery = (SubqueryExpression) expression; + // Extract tables from the subquery + extractFromSelect(subquery.getSelectStatement(), analysis); + } + } + + private void extractWhereColumnsFromExpression(Expression expression, QueryAnalysis analysis, Map aliasToTable) { + if (expression instanceof Identifier) { + Identifier column = (Identifier) expression; + String fullName = column.getName(); + String tableName; + String columnName; + + // Parse qualified column name (e.g., "u.id" or "id") + if (fullName.contains(".")) { + String[] parts = fullName.split("\\.", 2); + String tableOrAlias = parts[0]; + columnName = parts[1]; + // Resolve alias to actual table name + tableName = aliasToTable.getOrDefault(tableOrAlias, tableOrAlias); + } else { + tableName = analysis.tables.isEmpty() ? "unknown" : analysis.tables.iterator().next(); + columnName = fullName; + } + + analysis.whereColumns.add(new ColumnInfo(tableName, columnName)); + } else if (expression instanceof BinaryExpression) { + BinaryExpression binaryExpr = (BinaryExpression) expression; + extractWhereColumnsFromExpression(binaryExpr.getLeft(), analysis, aliasToTable); + extractWhereColumnsFromExpression(binaryExpr.getRight(), analysis, aliasToTable); + } else if (expression instanceof SubqueryExpression) { + SubqueryExpression subquery = (SubqueryExpression) expression; + // Extract tables from the subquery + extractFromSelect(subquery.getSelectStatement(), analysis); + } + } + + private void extractFromInsert(InsertStatement insert, QueryAnalysis analysis) { + // Extract table (handle schema.table format) + String tableName = extractTableName(insert.getTable().getTableName().getName()); + analysis.tables.add(tableName); + + // Extract columns (only if they exist) + if (insert.getColumns() != null) { + for (Identifier column : insert.getColumns()) { + analysis.columns.add(new ColumnInfo(tableName, column.getName())); + } + } + } + + private void extractFromUpdate(UpdateStatement update, QueryAnalysis analysis) { + // Extract table + String tableName = extractTableName(update.getTable().getTableName().getName()); + analysis.tables.add(tableName); + + // Extract columns from assignments + for (Assignment assignment : update.getAssignments()) { + analysis.columns.add(new ColumnInfo(tableName, assignment.getColumn().getName())); + } + + // Extract columns from WHERE clause only if WHERE contains parameters + if (update.getWhereClause() != null && containsParameters(update.getWhereClause())) { + extractWhereColumnsFromExpression(update.getWhereClause(), analysis, new HashMap<>()); + } + } + + private void extractFromDelete(DeleteStatement delete, QueryAnalysis analysis) { + // Extract table + String tableName = extractTableName(delete.getTable().getTableName().getName()); + analysis.tables.add(tableName); + + // Extract columns from WHERE clause only if WHERE contains parameters + if (delete.getWhereClause() != null && containsParameters(delete.getWhereClause())) { + extractWhereColumnsFromExpression(delete.getWhereClause(), analysis, new HashMap<>()); + } + } + + private void extractFromCreateTable(CreateTableStatement create, QueryAnalysis analysis) { + // Extract table + String tableName = extractTableName(create.getTableName().getName()); + analysis.tables.add(tableName); + + // Extract columns + for (ColumnDefinition column : create.getColumns()) { + analysis.columns.add(new ColumnInfo(tableName, column.getColumnName().getName())); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SqlLexer.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SqlLexer.java new file mode 100644 index 000000000..f4c4c149c --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SqlLexer.java @@ -0,0 +1,343 @@ +package software.amazon.jdbc.plugin.encryption.parser; + +import java.util.*; + +/** + * SQL Lexer based on PostgreSQL's scan.l + */ +public class SqlLexer { + private final String input; + private int position; + private int line; + private int column; + + // Keywords map + private static final Map KEYWORDS = new HashMap<>(); + static { + KEYWORDS.put("SELECT", Token.Type.SELECT); + KEYWORDS.put("FROM", Token.Type.FROM); + KEYWORDS.put("WHERE", Token.Type.WHERE); + KEYWORDS.put("INSERT", Token.Type.INSERT); + KEYWORDS.put("INTO", Token.Type.INTO); + KEYWORDS.put("UPDATE", Token.Type.UPDATE); + KEYWORDS.put("DELETE", Token.Type.DELETE); + KEYWORDS.put("CREATE", Token.Type.CREATE); + KEYWORDS.put("DROP", Token.Type.DROP); + KEYWORDS.put("ALTER", Token.Type.ALTER); + KEYWORDS.put("TABLE", Token.Type.TABLE); + KEYWORDS.put("INDEX", Token.Type.INDEX); + KEYWORDS.put("DATABASE", Token.Type.DATABASE); + KEYWORDS.put("SCHEMA", Token.Type.SCHEMA); + KEYWORDS.put("VIEW", Token.Type.VIEW); + KEYWORDS.put("FUNCTION", Token.Type.FUNCTION); + KEYWORDS.put("PROCEDURE", Token.Type.PROCEDURE); + KEYWORDS.put("AND", Token.Type.AND); + KEYWORDS.put("OR", Token.Type.OR); + KEYWORDS.put("NOT", Token.Type.NOT); + KEYWORDS.put("NULL", Token.Type.NULL); + KEYWORDS.put("TRUE", Token.Type.TRUE); + KEYWORDS.put("FALSE", Token.Type.FALSE); + KEYWORDS.put("AS", Token.Type.AS); + KEYWORDS.put("ON", Token.Type.ON); + KEYWORDS.put("IN", Token.Type.IN); + KEYWORDS.put("EXISTS", Token.Type.EXISTS); + KEYWORDS.put("BETWEEN", Token.Type.BETWEEN); + KEYWORDS.put("LIKE", Token.Type.LIKE); + KEYWORDS.put("IS", Token.Type.IS); + KEYWORDS.put("ISNULL", Token.Type.ISNULL); + KEYWORDS.put("NOTNULL", Token.Type.NOTNULL); + KEYWORDS.put("ORDER", Token.Type.ORDER); + KEYWORDS.put("BY", Token.Type.BY); + KEYWORDS.put("GROUP", Token.Type.GROUP); + KEYWORDS.put("HAVING", Token.Type.HAVING); + KEYWORDS.put("LIMIT", Token.Type.LIMIT); + KEYWORDS.put("OFFSET", Token.Type.OFFSET); + KEYWORDS.put("INNER", Token.Type.INNER); + KEYWORDS.put("LEFT", Token.Type.LEFT); + KEYWORDS.put("RIGHT", Token.Type.RIGHT); + KEYWORDS.put("FULL", Token.Type.FULL); + KEYWORDS.put("OUTER", Token.Type.OUTER); + KEYWORDS.put("JOIN", Token.Type.JOIN); + KEYWORDS.put("CROSS", Token.Type.CROSS); + KEYWORDS.put("UNION", Token.Type.UNION); + KEYWORDS.put("INTERSECT", Token.Type.INTERSECT); + KEYWORDS.put("EXCEPT", Token.Type.EXCEPT); + KEYWORDS.put("ALL", Token.Type.ALL); + KEYWORDS.put("DISTINCT", Token.Type.DISTINCT); + KEYWORDS.put("VALUES", Token.Type.VALUES); + KEYWORDS.put("SET", Token.Type.SET); + KEYWORDS.put("PRIMARY", Token.Type.PRIMARY); + KEYWORDS.put("KEY", Token.Type.KEY); + KEYWORDS.put("FOREIGN", Token.Type.FOREIGN); + KEYWORDS.put("REFERENCES", Token.Type.REFERENCES); + KEYWORDS.put("CASE", Token.Type.CASE); + KEYWORDS.put("WHEN", Token.Type.WHEN); + KEYWORDS.put("THEN", Token.Type.THEN); + KEYWORDS.put("ELSE", Token.Type.ELSE); + KEYWORDS.put("END", Token.Type.END); + KEYWORDS.put("CAST", Token.Type.CAST); + KEYWORDS.put("RETURNING", Token.Type.RETURNING); + KEYWORDS.put("WITH", Token.Type.WITH); + KEYWORDS.put("RECURSIVE", Token.Type.RECURSIVE); + KEYWORDS.put("WINDOW", Token.Type.WINDOW); + KEYWORDS.put("OVER", Token.Type.OVER); + KEYWORDS.put("PARTITION", Token.Type.PARTITION); + KEYWORDS.put("ROWS", Token.Type.ROWS); + KEYWORDS.put("RANGE", Token.Type.RANGE); + KEYWORDS.put("NULLS", Token.Type.NULLS); + KEYWORDS.put("FIRST", Token.Type.FIRST); + KEYWORDS.put("LAST", Token.Type.LAST); + KEYWORDS.put("ASC", Token.Type.ASC); + KEYWORDS.put("DESC", Token.Type.DESC); + } + + public SqlLexer(String input) { + this.input = input; + this.position = 0; + this.line = 1; + this.column = 1; + } + + public List tokenize() { + List tokens = new ArrayList<>(); + Token token; + + while ((token = nextToken()).getType() != Token.Type.EOF) { + if (token.getType() != Token.Type.WHITESPACE && token.getType() != Token.Type.COMMENT) { + tokens.add(token); + } + } + tokens.add(token); // Add EOF token + + return tokens; + } + + public Token nextToken() { + skipWhitespace(); + + if (position >= input.length()) { + return new Token(Token.Type.EOF, "", line, column); + } + + char ch = input.charAt(position); + int startLine = line; + int startColumn = column; + + // Single character tokens + switch (ch) { + case ';': advance(); return new Token(Token.Type.SEMICOLON, ";", startLine, startColumn); + case ',': advance(); return new Token(Token.Type.COMMA, ",", startLine, startColumn); + case '.': + // Check if this is a decimal number (. followed by digit) + if (position + 1 < input.length() && Character.isDigit(input.charAt(position + 1))) { + return readNumericLiteral(); + } + advance(); + return new Token(Token.Type.DOT, ".", startLine, startColumn); + case '(': advance(); return new Token(Token.Type.LPAREN, "(", startLine, startColumn); + case ')': advance(); return new Token(Token.Type.RPAREN, ")", startLine, startColumn); + case '+': advance(); return new Token(Token.Type.PLUS, "+", startLine, startColumn); + case '-': + if (peek() == '-') { + return readLineComment(); + } + advance(); + return new Token(Token.Type.MINUS, "-", startLine, startColumn); + case '*': advance(); return new Token(Token.Type.MULTIPLY, "*", startLine, startColumn); + case '/': + if (peek() == '*') { + return readBlockComment(); + } + advance(); + return new Token(Token.Type.DIVIDE, "/", startLine, startColumn); + case '%': advance(); return new Token(Token.Type.MODULO, "%", startLine, startColumn); + case '=': advance(); return new Token(Token.Type.EQUALS, "=", startLine, startColumn); + case '<': + if (peek() == '=') { + advance(); advance(); + return new Token(Token.Type.LESS_EQUALS, "<=", startLine, startColumn); + } else if (peek() == '>') { + advance(); advance(); + return new Token(Token.Type.NOT_EQUALS, "<>", startLine, startColumn); + } + advance(); + return new Token(Token.Type.LESS_THAN, "<", startLine, startColumn); + case '>': + if (peek() == '=') { + advance(); advance(); + return new Token(Token.Type.GREATER_EQUALS, ">=", startLine, startColumn); + } + advance(); + return new Token(Token.Type.GREATER_THAN, ">", startLine, startColumn); + case '?': advance(); return new Token(Token.Type.PLACEHOLDER, "?", startLine, startColumn); + case '!': + if (peek() == '=') { + advance(); advance(); + return new Token(Token.Type.NOT_EQUALS, "!=", startLine, startColumn); + } + break; + } + + // String literals + if (ch == '\'') { + return readStringLiteral(); + } + + // Numeric literals + if (Character.isDigit(ch)) { + return readNumericLiteral(); + } + + // Identifiers and keywords + if (Character.isLetter(ch) || ch == '_') { + return readIdentifier(); + } + + // Unknown character + advance(); + return new Token(Token.Type.IDENT, String.valueOf(ch), startLine, startColumn); + } + + private void skipWhitespace() { + while (position < input.length() && Character.isWhitespace(input.charAt(position))) { + if (input.charAt(position) == '\n') { + line++; + column = 1; + } else { + column++; + } + position++; + } + } + + private char advance() { + if (position >= input.length()) return '\0'; + char ch = input.charAt(position++); + if (ch == '\n') { + line++; + column = 1; + } else { + column++; + } + return ch; + } + + private char peek() { + if (position + 1 >= input.length()) return '\0'; + return input.charAt(position + 1); + } + + private Token readStringLiteral() { + int startLine = line; + int startColumn = column; + StringBuilder sb = new StringBuilder(); + + advance(); // Skip opening quote + + while (position < input.length()) { + char ch = input.charAt(position); + if (ch == '\'') { + if (peek() == '\'') { + // Escaped quote + advance(); advance(); + sb.append('\''); + } else { + // End of string + advance(); + break; + } + } else { + sb.append(advance()); + } + } + + return new Token(Token.Type.SCONST, sb.toString(), startLine, startColumn); + } + + private Token readNumericLiteral() { + int startLine = line; + int startColumn = column; + StringBuilder sb = new StringBuilder(); + boolean hasDecimal = false; + boolean hasExponent = false; + + // Handle starting with dot + if (position < input.length() && input.charAt(position) == '.') { + hasDecimal = true; + sb.append(advance()); + } + + while (position < input.length()) { + char ch = input.charAt(position); + if (Character.isDigit(ch)) { + sb.append(advance()); + } else if (ch == '.' && !hasDecimal && !hasExponent) { + hasDecimal = true; + sb.append(advance()); + } else if ((ch == 'e' || ch == 'E') && !hasExponent) { + hasExponent = true; + sb.append(advance()); + // Handle optional + or - after e/E + if (position < input.length() && (input.charAt(position) == '+' || input.charAt(position) == '-')) { + sb.append(advance()); + } + } else { + break; + } + } + + Token.Type type = (hasDecimal || hasExponent) ? Token.Type.FCONST : Token.Type.ICONST; + return new Token(type, sb.toString(), startLine, startColumn); + } + + private Token readIdentifier() { + int startLine = line; + int startColumn = column; + StringBuilder sb = new StringBuilder(); + + while (position < input.length()) { + char ch = input.charAt(position); + if (Character.isLetterOrDigit(ch) || ch == '_') { + sb.append(advance()); + } else { + break; + } + } + + String value = sb.toString(); + String upperValue = value.toUpperCase(); + Token.Type type = KEYWORDS.getOrDefault(upperValue, Token.Type.IDENT); + + return new Token(type, value, startLine, startColumn); + } + + private Token readLineComment() { + int startLine = line; + int startColumn = column; + StringBuilder sb = new StringBuilder(); + + while (position < input.length() && input.charAt(position) != '\n') { + sb.append(advance()); + } + + return new Token(Token.Type.COMMENT, sb.toString(), startLine, startColumn); + } + + private Token readBlockComment() { + int startLine = line; + int startColumn = column; + StringBuilder sb = new StringBuilder(); + + advance(); advance(); // Skip /* + + while (position < input.length() - 1) { + if (input.charAt(position) == '*' && input.charAt(position + 1) == '/') { + advance(); advance(); // Skip */ + break; + } + sb.append(advance()); + } + + return new Token(Token.Type.COMMENT, sb.toString(), startLine, startColumn); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SqlParser.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SqlParser.java new file mode 100644 index 000000000..83f80b7fa --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/SqlParser.java @@ -0,0 +1,673 @@ +package software.amazon.jdbc.plugin.encryption.parser; + +import software.amazon.jdbc.plugin.encryption.parser.ast.*; +import java.util.*; + +/** + * SQL Parser based on PostgreSQL's gram.y + * Implements a recursive descent parser for basic SQL statements + */ +public class SqlParser { + private final List tokens; + private int position; + + public SqlParser(List tokens) { + this.tokens = tokens; + this.position = 0; + } + + public Statement parse() { + return parseStatement(); + } + + private Statement parseStatement() { + Token token = peek(); + if (token.getType() == Token.Type.EOF) { + return null; + } + + switch (token.getType()) { + case SELECT: + return parseSelectStatement(); + case INSERT: + return parseInsertStatement(); + case UPDATE: + return parseUpdateStatement(); + case DELETE: + return parseDeleteStatement(); + case CREATE: + return parseCreateStatement(); + default: + throw new ParseException("Unexpected token: " + token); + } + } + + private SelectStatement parseSelectStatement() { + consume(Token.Type.SELECT); + + // Parse SELECT list + List selectList = parseSelectList(); + + // Parse FROM clause + List fromClause = null; + if (peek().getType() == Token.Type.FROM) { + consume(Token.Type.FROM); + fromClause = parseFromClause(); + } + + // Parse WHERE clause + Expression whereClause = null; + if (peek().getType() == Token.Type.WHERE) { + consume(Token.Type.WHERE); + whereClause = parseExpression(); + } + + // Parse GROUP BY clause + List groupByClause = null; + if (peek().getType() == Token.Type.GROUP) { + consume(Token.Type.GROUP); + consume(Token.Type.BY); + groupByClause = parseExpressionList(); + } + + // Parse HAVING clause + Expression havingClause = null; + if (peek().getType() == Token.Type.HAVING) { + consume(Token.Type.HAVING); + havingClause = parseExpression(); + } + + // Parse ORDER BY clause + List orderByClause = null; + if (peek().getType() == Token.Type.ORDER) { + consume(Token.Type.ORDER); + consume(Token.Type.BY); + orderByClause = parseOrderByList(); + } + + // Parse LIMIT clause + Expression limitClause = null; + if (peek().getType() == Token.Type.LIMIT) { + consume(Token.Type.LIMIT); + limitClause = parseExpression(); + } + + Integer limitValue = null; + if (limitClause instanceof NumericLiteral) { + limitValue = Integer.parseInt(((NumericLiteral) limitClause).getValue()); + } + + return new SelectStatement(selectList, fromClause, whereClause, + groupByClause, havingClause, orderByClause, limitValue); + } + + private InsertStatement parseInsertStatement() { + consume(Token.Type.INSERT); + consume(Token.Type.INTO); + + TableReference table = parseTableReference(); + + // Parse column list (optional) + List columns = null; + if (peek().getType() == Token.Type.LPAREN) { + consume(Token.Type.LPAREN); + columns = parseIdentifierList(); + consume(Token.Type.RPAREN); + } + + // Parse VALUES clause or SELECT statement + List> values = null; + if (peek().getType() == Token.Type.VALUES) { + consume(Token.Type.VALUES); + values = parseValuesList(); + } else if (peek().getType() == Token.Type.SELECT) { + // For INSERT ... SELECT, we'll just parse it as a simple INSERT + // and let the analyzer handle the SELECT part separately + values = new java.util.ArrayList<>(); + } + + return new InsertStatement(table, columns, values); + } + + private UpdateStatement parseUpdateStatement() { + consume(Token.Type.UPDATE); + + TableReference table = parseTableReference(); + + consume(Token.Type.SET); + List assignments = parseAssignmentList(); + + Expression whereClause = null; + if (peek().getType() == Token.Type.WHERE) { + consume(Token.Type.WHERE); + whereClause = parseExpression(); + } + + return new UpdateStatement(table, assignments, whereClause); + } + + private DeleteStatement parseDeleteStatement() { + consume(Token.Type.DELETE); + consume(Token.Type.FROM); + + TableReference table = parseTableReference(); + + Expression whereClause = null; + if (peek().getType() == Token.Type.WHERE) { + consume(Token.Type.WHERE); + whereClause = parseExpression(); + } + + return new DeleteStatement(table, whereClause); + } + + private Statement parseCreateStatement() { + consume(Token.Type.CREATE); + + if (peek().getType() == Token.Type.TABLE) { + return parseCreateTableStatement(); + } + + throw new ParseException("Unsupported CREATE statement"); + } + + private CreateTableStatement parseCreateTableStatement() { + consume(Token.Type.TABLE); + + Identifier tableName = parseIdentifier(); + + consume(Token.Type.LPAREN); + List columns = parseColumnDefinitionList(); + consume(Token.Type.RPAREN); + + return new CreateTableStatement(tableName, columns); + } + + private List parseSelectList() { + List items = new ArrayList<>(); + + do { + Expression expr = parseExpression(); + String alias = null; + + if (peek().getType() == Token.Type.AS) { + consume(Token.Type.AS); + alias = consume(Token.Type.IDENT).getValue(); + } else if (peek().getType() == Token.Type.IDENT) { + alias = consume(Token.Type.IDENT).getValue(); + } + + items.add(new SelectItem(expr, alias)); + + if (peek().getType() == Token.Type.COMMA) { + consume(Token.Type.COMMA); + } else { + break; + } + } while (true); + + return items; + } + + private List parseFromClause() { + List tables = new ArrayList<>(); + + // Parse first table + tables.add(parseTableReference()); + + // Parse JOINs or comma-separated tables + while (true) { + Token.Type nextType = peek().getType(); + if (nextType == Token.Type.COMMA) { + consume(Token.Type.COMMA); + tables.add(parseTableReference()); + } else if (nextType == Token.Type.JOIN || nextType == Token.Type.INNER || + nextType == Token.Type.LEFT || nextType == Token.Type.RIGHT || + nextType == Token.Type.CROSS) { + // Handle JOIN - consume JOIN keywords and add the joined table + if (nextType == Token.Type.INNER || nextType == Token.Type.LEFT || + nextType == Token.Type.RIGHT || nextType == Token.Type.CROSS) { + consume(nextType); // consume INNER/LEFT/RIGHT/CROSS + // Optional OUTER keyword after LEFT/RIGHT/FULL + if (peek().getType() == Token.Type.OUTER) { + consume(Token.Type.OUTER); + } + } + if (peek().getType() == Token.Type.JOIN) { + consume(Token.Type.JOIN); + } + tables.add(parseTableReference()); + + // Skip ON clause for now (not needed for CROSS JOIN) + if (peek().getType() == Token.Type.ON) { + consume(Token.Type.ON); + parseExpression(); // consume but ignore the join condition + } + } else { + break; + } + } + + return tables; + } + + private TableReference parseTableReference() { + Identifier tableName = parseIdentifier(); + String alias = null; + + if (peek().getType() == Token.Type.AS) { + consume(Token.Type.AS); + alias = consume(Token.Type.IDENT).getValue(); + } else if (peek().getType() == Token.Type.IDENT) { + alias = consume(Token.Type.IDENT).getValue(); + } + + return new TableReference(tableName, alias); + } + + private Expression parseExpression() { + return parseOrExpression(); + } + + private Expression parseOrExpression() { + Expression left = parseAndExpression(); + + while (peek().getType() == Token.Type.OR) { + consume(Token.Type.OR); + Expression right = parseAndExpression(); + left = new BinaryExpression(left, BinaryExpression.Operator.OR, right); + } + + return left; + } + + private Expression parseAndExpression() { + Expression left = parseEqualityExpression(); + + while (peek().getType() == Token.Type.AND) { + consume(Token.Type.AND); + Expression right = parseEqualityExpression(); + left = new BinaryExpression(left, BinaryExpression.Operator.AND, right); + } + + return left; + } + + private Expression parseEqualityExpression() { + Expression left = parseRelationalExpression(); + + while (true) { + Token.Type type = peek().getType(); + BinaryExpression.Operator op = null; + + switch (type) { + case EQUALS: op = BinaryExpression.Operator.EQUALS; break; + case NOT_EQUALS: op = BinaryExpression.Operator.NOT_EQUALS; break; + default: return left; + } + + consume(type); + Expression right = parseRelationalExpression(); + left = new BinaryExpression(left, op, right); + } + } + + private Expression parseRelationalExpression() { + Expression left = parseAdditiveExpression(); + + while (true) { + Token.Type type = peek().getType(); + BinaryExpression.Operator op = null; + + switch (type) { + case LESS_THAN: op = BinaryExpression.Operator.LESS_THAN; break; + case GREATER_THAN: op = BinaryExpression.Operator.GREATER_THAN; break; + case LESS_EQUALS: op = BinaryExpression.Operator.LESS_EQUALS; break; + case GREATER_EQUALS: op = BinaryExpression.Operator.GREATER_EQUALS; break; + case LIKE: op = BinaryExpression.Operator.LIKE; break; + case IN: op = BinaryExpression.Operator.IN; break; + default: return left; + } + + consume(type); + Expression right = parseAdditiveExpression(); + left = new BinaryExpression(left, op, right); + } + } + + private Expression parseAdditiveExpression() { + Expression left = parseMultiplicativeExpression(); + + while (true) { + Token.Type type = peek().getType(); + BinaryExpression.Operator op = null; + + switch (type) { + case PLUS: op = BinaryExpression.Operator.PLUS; break; + case MINUS: op = BinaryExpression.Operator.MINUS; break; + default: return left; + } + + consume(type); + Expression right = parseMultiplicativeExpression(); + left = new BinaryExpression(left, op, right); + } + } + + private Expression parseMultiplicativeExpression() { + Expression left = parsePrimaryExpression(); + + while (true) { + Token.Type type = peek().getType(); + BinaryExpression.Operator op = null; + + switch (type) { + case MULTIPLY: op = BinaryExpression.Operator.MULTIPLY; break; + case DIVIDE: op = BinaryExpression.Operator.DIVIDE; break; + case MODULO: op = BinaryExpression.Operator.MODULO; break; + default: return left; + } + + consume(type); + Expression right = parsePrimaryExpression(); + left = new BinaryExpression(left, op, right); + } + } + + private Expression parsePrimaryExpression() { + Token token = peek(); + + switch (token.getType()) { + case MULTIPLY: + consume(Token.Type.MULTIPLY); + return new Identifier("*"); + case IDENT: + Token identToken = peek(); + // Check if this is a function call + if (tokens.size() > position + 1 && tokens.get(position + 1).getType() == Token.Type.LPAREN) { + consume(Token.Type.IDENT); + consume(Token.Type.LPAREN); + // Skip function arguments for now + int parenCount = 1; + while (parenCount > 0 && peek().getType() != Token.Type.EOF) { + Token t = consume(); + if (t.getType() == Token.Type.LPAREN) parenCount++; + else if (t.getType() == Token.Type.RPAREN) parenCount--; + } + return new Identifier(identToken.getValue() + "()"); + } else { + return parseIdentifier(); + } + case SCONST: + consume(Token.Type.SCONST); + return new StringLiteral(token.getValue()); + case ICONST: + consume(Token.Type.ICONST); + return new NumericLiteral(token.getValue(), true); + case FCONST: + consume(Token.Type.FCONST); + return new NumericLiteral(token.getValue(), false); + case PLACEHOLDER: + consume(Token.Type.PLACEHOLDER); + return new Placeholder(); + case TRUE: + consume(Token.Type.TRUE); + return new BooleanLiteral(true); + case FALSE: + consume(Token.Type.FALSE); + return new BooleanLiteral(false); + case CASE: + return parseCaseExpression(); + case CAST: + return parseCastExpression(); + case LPAREN: + consume(Token.Type.LPAREN); + // Check if this is a subquery + if (peek().getType() == Token.Type.SELECT) { + SelectStatement subquery = parseSelectStatement(); + consume(Token.Type.RPAREN); + return new SubqueryExpression(subquery); + } else { + Expression expr = parseExpression(); + consume(Token.Type.RPAREN); + return expr; + } + default: + throw new ParseException("Unexpected token in expression: " + token); + } + } + + private Identifier parseIdentifier() { + Token token = consume(Token.Type.IDENT); + String name = token.getValue(); + + // Check for qualified name (table.column) + if (peek().getType() == Token.Type.DOT) { + consume(Token.Type.DOT); + Token columnToken = consume(Token.Type.IDENT); + name = name + "." + columnToken.getValue(); + } + + return new Identifier(name); + } + + private List parseExpressionList() { + List expressions = new ArrayList<>(); + + do { + expressions.add(parseExpression()); + + if (peek().getType() == Token.Type.COMMA) { + consume(Token.Type.COMMA); + } else { + break; + } + } while (true); + + return expressions; + } + + private List parseIdentifierList() { + List identifiers = new ArrayList<>(); + + do { + identifiers.add(parseIdentifier()); + + if (peek().getType() == Token.Type.COMMA) { + consume(Token.Type.COMMA); + } else { + break; + } + } while (true); + + return identifiers; + } + + private List parseOrderByList() { + List items = new ArrayList<>(); + + do { + Expression expr = parseExpression(); + OrderByItem.Direction direction = OrderByItem.Direction.ASC; + + // Handle ASC/DESC + Token token = peek(); + if (token.getType() == Token.Type.ASC) { + consume(Token.Type.ASC); + direction = OrderByItem.Direction.ASC; + } else if (token.getType() == Token.Type.DESC) { + consume(Token.Type.DESC); + direction = OrderByItem.Direction.DESC; + } else if (token.getType() == Token.Type.IDENT) { + String dir = token.getValue().toUpperCase(); + if ("ASC".equals(dir)) { + consume(Token.Type.IDENT); + direction = OrderByItem.Direction.ASC; + } else if ("DESC".equals(dir)) { + consume(Token.Type.IDENT); + direction = OrderByItem.Direction.DESC; + } + } + + // Handle NULLS FIRST/LAST + if (peek().getType() == Token.Type.NULLS) { + consume(Token.Type.NULLS); + if (peek().getType() == Token.Type.FIRST) { + consume(Token.Type.FIRST); + } else if (peek().getType() == Token.Type.LAST) { + consume(Token.Type.LAST); + } + } + + items.add(new OrderByItem(expr, direction)); + + if (peek().getType() == Token.Type.COMMA) { + consume(Token.Type.COMMA); + } else { + break; + } + } while (true); + + return items; + } + + private List> parseValuesList() { + List> valuesList = new ArrayList<>(); + + do { + consume(Token.Type.LPAREN); + List values = parseExpressionList(); + consume(Token.Type.RPAREN); + valuesList.add(values); + + if (peek().getType() == Token.Type.COMMA) { + consume(Token.Type.COMMA); + } else { + break; + } + } while (true); + + return valuesList; + } + + private List parseAssignmentList() { + List assignments = new ArrayList<>(); + + do { + Identifier column = parseIdentifier(); + consume(Token.Type.EQUALS); + Expression value = parseExpression(); + assignments.add(new Assignment(column, value)); + + if (peek().getType() == Token.Type.COMMA) { + consume(Token.Type.COMMA); + } else { + break; + } + } while (true); + + return assignments; + } + + private List parseColumnDefinitionList() { + List columns = new ArrayList<>(); + + do { + Identifier name = parseIdentifier(); + String dataType = consume(Token.Type.IDENT).getValue(); + boolean notNull = false; + boolean primaryKey = false; + + // Parse constraints (simplified) + while (peek().getType() == Token.Type.NOT || peek().getType() == Token.Type.PRIMARY) { + if (peek().getType() == Token.Type.NOT) { + consume(Token.Type.NOT); + consume(Token.Type.NULL); + notNull = true; + } else if (peek().getType() == Token.Type.PRIMARY) { + consume(Token.Type.PRIMARY); + consume(Token.Type.KEY); + primaryKey = true; + } + } + + columns.add(new ColumnDefinition(name, dataType, notNull, primaryKey)); + + if (peek().getType() == Token.Type.COMMA) { + consume(Token.Type.COMMA); + } else { + break; + } + } while (true); + + return columns; + } + + private Token peek() { + if (position >= tokens.size()) { + return new Token(Token.Type.EOF, "", 0, 0); + } + return tokens.get(position); + } + + private Token consume(Token.Type expectedType) { + Token token = peek(); + if (token.getType() != expectedType) { + throw new ParseException("Expected " + expectedType + " but got " + token.getType()); + } + position++; + return token; + } + + private Expression parseCaseExpression() { + consume(Token.Type.CASE); + + // Skip WHEN/THEN/ELSE/END for now - just consume tokens until END + int depth = 1; + while (depth > 0 && peek().getType() != Token.Type.EOF) { + if (peek().getType() == Token.Type.CASE) { + depth++; + } else if (peek().getType() == Token.Type.END) { + depth--; + if (depth == 0) { + consume(Token.Type.END); + break; + } + } + consume(); + } + + return new Identifier("CASE"); + } + + private Expression parseCastExpression() { + consume(Token.Type.CAST); + consume(Token.Type.LPAREN); + + // Parse the expression being cast + parseExpression(); + + // Skip AS and type + if (peek().getType() == Token.Type.AS) { + consume(Token.Type.AS); + consume(Token.Type.IDENT); // type name + } + + consume(Token.Type.RPAREN); + + return new Identifier("CAST"); + } + + private Token consume() { + if (position >= tokens.size()) { + throw new ParseException("Unexpected end of input"); + } + return tokens.get(position++); + } + + public static class ParseException extends RuntimeException { + public ParseException(String message) { + super(message); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/Token.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/Token.java new file mode 100644 index 000000000..9d3b67fab --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/Token.java @@ -0,0 +1,58 @@ +package software.amazon.jdbc.plugin.encryption.parser; + +/** + * Represents a SQL token with type and value + */ +public class Token { + public enum Type { + // Literals + IDENT, SCONST, ICONST, FCONST, PLACEHOLDER, + + // Keywords + SELECT, FROM, WHERE, INSERT, INTO, UPDATE, DELETE, CREATE, DROP, ALTER, + TABLE, INDEX, DATABASE, SCHEMA, VIEW, FUNCTION, PROCEDURE, + AND, OR, NOT, NULL, TRUE, FALSE, + AS, ON, IN, EXISTS, BETWEEN, LIKE, IS, ISNULL, NOTNULL, + ORDER, BY, GROUP, HAVING, LIMIT, OFFSET, + INNER, LEFT, RIGHT, FULL, OUTER, JOIN, CROSS, + UNION, INTERSECT, EXCEPT, ALL, DISTINCT, + VALUES, SET, PRIMARY, KEY, FOREIGN, REFERENCES, + CASE, WHEN, THEN, ELSE, END, + CAST, RETURNING, WITH, RECURSIVE, + WINDOW, OVER, PARTITION, ROWS, RANGE, + NULLS, FIRST, LAST, ASC, DESC, + + // Operators + EQUALS, NOT_EQUALS, LESS_THAN, GREATER_THAN, LESS_EQUALS, GREATER_EQUALS, + PLUS, MINUS, MULTIPLY, DIVIDE, MODULO, + CONCAT, // || + + // Punctuation + SEMICOLON, COMMA, DOT, LPAREN, RPAREN, + + // Special + EOF, WHITESPACE, COMMENT + } + + private final Type type; + private final String value; + private final int line; + private final int column; + + public Token(Type type, String value, int line, int column) { + this.type = type; + this.value = value; + this.line = line; + this.column = column; + } + + public Type getType() { return type; } + public String getValue() { return value; } + public int getLine() { return line; } + public int getColumn() { return column; } + + @Override + public String toString() { + return String.format("Token{%s, '%s', %d:%d}", type, value, line, column); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Assignment.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Assignment.java new file mode 100644 index 000000000..98445d52b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Assignment.java @@ -0,0 +1,23 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Assignment in UPDATE statement + */ +public class Assignment extends AstNode { + private final Identifier column; + private final Expression value; + + public Assignment(Identifier column, Expression value) { + this.column = column; + this.value = value; + } + + public Identifier getColumn() { return column; } + public Expression getValue() { return value; } + + @Override + public T accept(AstVisitor visitor) { + // Assignment doesn't have a visitor method, so we delegate to the value + return value.accept(visitor); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/AstNode.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/AstNode.java new file mode 100644 index 000000000..64a520408 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/AstNode.java @@ -0,0 +1,11 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Base class for all AST nodes + */ +public abstract class AstNode { + /** + * Accept method for visitor pattern + */ + public abstract T accept(AstVisitor visitor); +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/AstVisitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/AstVisitor.java new file mode 100644 index 000000000..778f1e69f --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/AstVisitor.java @@ -0,0 +1,19 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Visitor interface for AST traversal + */ +public interface AstVisitor { + T visit(SelectStatement node); + T visit(InsertStatement node); + T visit(UpdateStatement node); + T visit(DeleteStatement node); + T visit(CreateTableStatement node); + T visit(BinaryExpression node); + T visit(Identifier node); + T visit(StringLiteral node); + T visit(NumericLiteral node); + T visit(Placeholder node); + T visit(SubqueryExpression node); + T visit(BooleanLiteral node); +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/BinaryExpression.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/BinaryExpression.java new file mode 100644 index 000000000..704ec93b5 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/BinaryExpression.java @@ -0,0 +1,31 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Binary expression (e.g., a = b, x + y) + */ +public class BinaryExpression extends Expression { + public enum Operator { + EQUALS, NOT_EQUALS, LESS_THAN, GREATER_THAN, LESS_EQUALS, GREATER_EQUALS, + PLUS, MINUS, MULTIPLY, DIVIDE, MODULO, + AND, OR, LIKE, IN, BETWEEN + } + + private final Expression left; + private final Operator operator; + private final Expression right; + + public BinaryExpression(Expression left, Operator operator, Expression right) { + this.left = left; + this.operator = operator; + this.right = right; + } + + public Expression getLeft() { return left; } + public Operator getOperator() { return operator; } + public Expression getRight() { return right; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/BooleanLiteral.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/BooleanLiteral.java new file mode 100644 index 000000000..d298586e5 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/BooleanLiteral.java @@ -0,0 +1,26 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Represents a boolean literal (TRUE/FALSE) in SQL + */ +public class BooleanLiteral extends Expression { + private final boolean value; + + public BooleanLiteral(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } + + @Override + public String toString() { + return String.valueOf(value).toUpperCase(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/ColumnDefinition.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/ColumnDefinition.java new file mode 100644 index 000000000..3b408ba02 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/ColumnDefinition.java @@ -0,0 +1,29 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Column definition in CREATE TABLE + */ +public class ColumnDefinition extends AstNode { + private final Identifier columnName; + private final String dataType; + private final boolean notNull; + private final boolean primaryKey; + + public ColumnDefinition(Identifier columnName, String dataType, boolean notNull, boolean primaryKey) { + this.columnName = columnName; + this.dataType = dataType; + this.notNull = notNull; + this.primaryKey = primaryKey; + } + + public Identifier getColumnName() { return columnName; } + public String getDataType() { return dataType; } + public boolean isNotNull() { return notNull; } + public boolean isPrimaryKey() { return primaryKey; } + + @Override + public T accept(AstVisitor visitor) { + // ColumnDefinition doesn't have a visitor method, so we delegate to the column name + return columnName.accept(visitor); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/CreateTableStatement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/CreateTableStatement.java new file mode 100644 index 000000000..cf573c66d --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/CreateTableStatement.java @@ -0,0 +1,24 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +import java.util.List; + +/** + * CREATE TABLE statement + */ +public class CreateTableStatement extends Statement { + private final Identifier tableName; + private final List columns; + + public CreateTableStatement(Identifier tableName, List columns) { + this.tableName = tableName; + this.columns = columns; + } + + public Identifier getTableName() { return tableName; } + public List getColumns() { return columns; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/DeleteStatement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/DeleteStatement.java new file mode 100644 index 000000000..ab7d5dfe2 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/DeleteStatement.java @@ -0,0 +1,22 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * DELETE statement + */ +public class DeleteStatement extends Statement { + private final TableReference table; + private final Expression whereClause; + + public DeleteStatement(TableReference table, Expression whereClause) { + this.table = table; + this.whereClause = whereClause; + } + + public TableReference getTable() { return table; } + public Expression getWhereClause() { return whereClause; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Expression.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Expression.java new file mode 100644 index 000000000..e7226a52f --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Expression.java @@ -0,0 +1,7 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Base class for expressions + */ +public abstract class Expression extends AstNode { +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Identifier.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Identifier.java new file mode 100644 index 000000000..2ea87d739 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Identifier.java @@ -0,0 +1,29 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Identifier (table name, column name, etc.) + */ +public class Identifier extends Expression { + private final String name; + private final String schema; + + public Identifier(String name) { + this(null, name); + } + + public Identifier(String schema, String name) { + this.schema = schema; + this.name = name; + } + + public String getName() { return name; } + public String getSchema() { return schema; } + public String getFullName() { + return schema != null ? schema + "." + name : name; + } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/InsertStatement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/InsertStatement.java new file mode 100644 index 000000000..ae014765b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/InsertStatement.java @@ -0,0 +1,27 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +import java.util.List; + +/** + * INSERT statement + */ +public class InsertStatement extends Statement { + private final TableReference table; + private final List columns; + private final List> values; + + public InsertStatement(TableReference table, List columns, List> values) { + this.table = table; + this.columns = columns; + this.values = values; + } + + public TableReference getTable() { return table; } + public List getColumns() { return columns; } + public List> getValues() { return values; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/NumericLiteral.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/NumericLiteral.java new file mode 100644 index 000000000..99c12194d --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/NumericLiteral.java @@ -0,0 +1,22 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Numeric literal + */ +public class NumericLiteral extends Expression { + private final String value; + private final boolean isInteger; + + public NumericLiteral(String value, boolean isInteger) { + this.value = value; + this.isInteger = isInteger; + } + + public String getValue() { return value; } + public boolean isInteger() { return isInteger; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/OrderByItem.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/OrderByItem.java new file mode 100644 index 000000000..fe84bca0d --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/OrderByItem.java @@ -0,0 +1,25 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * ORDER BY item + */ +public class OrderByItem extends AstNode { + public enum Direction { ASC, DESC } + + private final Expression expression; + private final Direction direction; + + public OrderByItem(Expression expression, Direction direction) { + this.expression = expression; + this.direction = direction; + } + + public Expression getExpression() { return expression; } + public Direction getDirection() { return direction; } + + @Override + public T accept(AstVisitor visitor) { + // OrderByItem doesn't have a visitor method, so we delegate to the expression + return expression.accept(visitor); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Placeholder.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Placeholder.java new file mode 100644 index 000000000..86187892e --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Placeholder.java @@ -0,0 +1,20 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * JDBC placeholder (?) + */ +public class Placeholder extends Expression { + + public Placeholder() { + } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } + + @Override + public String toString() { + return "?"; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SelectItem.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SelectItem.java new file mode 100644 index 000000000..7977fb066 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SelectItem.java @@ -0,0 +1,23 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * SELECT item (column or expression in SELECT clause) + */ +public class SelectItem extends AstNode { + private final Expression expression; + private final String alias; + + public SelectItem(Expression expression, String alias) { + this.expression = expression; + this.alias = alias; + } + + public Expression getExpression() { return expression; } + public String getAlias() { return alias; } + + @Override + public T accept(AstVisitor visitor) { + // SelectItem doesn't have a visitor method, so we delegate to the expression + return expression.accept(visitor); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SelectStatement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SelectStatement.java new file mode 100644 index 000000000..0b514bd00 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SelectStatement.java @@ -0,0 +1,43 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +import java.util.List; + +/** + * SELECT statement + */ +public class SelectStatement extends Statement { + private final List selectList; + private final List fromList; + private final Expression whereClause; + private final List groupByList; + private final Expression havingClause; + private final List orderByList; + private final Integer limit; + + public SelectStatement(List selectList, List fromList, + Expression whereClause, List groupByList, + Expression havingClause, List orderByList, Integer limit) { + this.selectList = selectList; + this.fromList = fromList; + this.whereClause = whereClause; + this.groupByList = groupByList; + this.havingClause = havingClause; + this.orderByList = orderByList; + this.limit = limit; + } + + public List getSelectList() { return selectList; } + public List getFromList() { return fromList; } + public List getFromClause() { return fromList; } // convenience method + public Expression getWhereClause() { return whereClause; } + public List getGroupByList() { return groupByList; } + public Expression getHavingClause() { return havingClause; } + public List getOrderByList() { return orderByList; } + public List getOrderBy() { return orderByList; } // convenience method + public Integer getLimit() { return limit; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Statement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Statement.java new file mode 100644 index 000000000..f65ea7cad --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/Statement.java @@ -0,0 +1,7 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Base class for SQL statements + */ +public abstract class Statement extends AstNode { +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/StringLiteral.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/StringLiteral.java new file mode 100644 index 000000000..abf603cfd --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/StringLiteral.java @@ -0,0 +1,19 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * String literal + */ +public class StringLiteral extends Expression { + private final String value; + + public StringLiteral(String value) { + this.value = value; + } + + public String getValue() { return value; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SubqueryExpression.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SubqueryExpression.java new file mode 100644 index 000000000..e7244a8c3 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/SubqueryExpression.java @@ -0,0 +1,26 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Represents a subquery expression in SQL + */ +public class SubqueryExpression extends Expression { + private final SelectStatement selectStatement; + + public SubqueryExpression(SelectStatement selectStatement) { + this.selectStatement = selectStatement; + } + + public SelectStatement getSelectStatement() { + return selectStatement; + } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } + + @Override + public String toString() { + return "(" + selectStatement.toString() + ")"; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/TableReference.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/TableReference.java new file mode 100644 index 000000000..88d618d7b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/TableReference.java @@ -0,0 +1,23 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +/** + * Table reference + */ +public class TableReference extends AstNode { + private final Identifier tableName; + private final String alias; + + public TableReference(Identifier tableName, String alias) { + this.tableName = tableName; + this.alias = alias; + } + + public Identifier getTableName() { return tableName; } + public String getAlias() { return alias; } + + @Override + public T accept(AstVisitor visitor) { + // TableReference doesn't have a visitor method, so we delegate to the table name + return tableName.accept(visitor); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/UpdateStatement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/UpdateStatement.java new file mode 100644 index 000000000..c6c451e0e --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/parser/ast/UpdateStatement.java @@ -0,0 +1,27 @@ +package software.amazon.jdbc.plugin.encryption.parser.ast; + +import java.util.List; + +/** + * UPDATE statement + */ +public class UpdateStatement extends Statement { + private final TableReference table; + private final List assignments; + private final Expression whereClause; + + public UpdateStatement(TableReference table, List assignments, Expression whereClause) { + this.table = table; + this.assignments = assignments; + this.whereClause = whereClause; + } + + public TableReference getTable() { return table; } + public List getAssignments() { return assignments; } + public Expression getWhereClause() { return whereClause; } + + @Override + public T accept(AstVisitor visitor) { + return visitor.visit(this); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/schema/EncryptedDataTypeInstaller.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/schema/EncryptedDataTypeInstaller.java new file mode 100644 index 000000000..305d5bca8 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/schema/EncryptedDataTypeInstaller.java @@ -0,0 +1,71 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.encryption.schema; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +public class EncryptedDataTypeInstaller { + + private static final Logger LOGGER = Logger.getLogger(EncryptedDataTypeInstaller.class.getName()); + private static final String SQL_RESOURCE_PATH = "/sql/encrypted_data_type.sql"; + + public static void installEncryptedDataType(Connection connection) throws SQLException { + LOGGER.info("Installing encrypted_data custom type"); + + try (Statement stmt = connection.createStatement()) { + stmt.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto"); + LOGGER.fine("pgcrypto extension enabled"); + + // Use DOMAIN-based implementation + String sql = loadSqlScript(); + stmt.execute(sql); + + LOGGER.info("encrypted_data type installed successfully (DOMAIN approach)"); + } + } + + public static boolean isEncryptedDataTypeInstalled(Connection connection) throws SQLException { + String checkSql = "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'encrypted_data')"; + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(checkSql)) { + return rs.next() && rs.getBoolean(1); + } + } + + private static String loadSqlScript() { + try (InputStream is = EncryptedDataTypeInstaller.class.getResourceAsStream(SQL_RESOURCE_PATH)) { + if (is == null) { + throw new IllegalStateException("SQL script not found: " + SQL_RESOURCE_PATH); + } + + try (BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) { + return reader.lines().collect(Collectors.joining("\n")); + } + } catch (Exception e) { + throw new IllegalStateException("Failed to load SQL script: " + SQL_RESOURCE_PATH, e); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/schema/SchemaValidator.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/schema/SchemaValidator.java new file mode 100644 index 000000000..f15c4cc35 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/schema/SchemaValidator.java @@ -0,0 +1,309 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.schema; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Objects; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Validates that the required database schema for encryption metadata exists + * and has the correct structure. + */ +public class SchemaValidator { + + private final String metadataSchema; + + public SchemaValidator(String metadataSchema) { + this.metadataSchema = Objects.requireNonNull(metadataSchema, "Metadata schema cannot be null"); + } + + private String getEncryptionMetadataTable() { + return metadataSchema + ".encryption_metadata"; + } + + private String getKeyStorageTable() { + return metadataSchema + ".key_storage"; + } + + private static final Set REQUIRED_ENCRYPTION_METADATA_COLUMNS = new HashSet<>(Arrays.asList( + "id", "table_name", "column_name", "encryption_algorithm", "key_id", "created_at", "updated_at" + )); + + private static final Set REQUIRED_KEY_STORAGE_COLUMNS = new HashSet<>(Arrays.asList( + "id", "name", "master_key_arn", "encrypted_data_key", "key_spec", "created_at", "last_used_at" + )); + + /** + * Validates that all required tables and columns exist in the database. + * + * @param connection Database connection to validate against + * @return ValidationResult containing validation status and any issues found + * @throws SQLException if database access fails + */ + public ValidationResult validateSchema(Connection connection) throws SQLException { + List issues = new ArrayList<>(); + + // Validate encryption_metadata table + String encryptionMetadataTable = getEncryptionMetadataTable(); + if (!tableExists(connection, encryptionMetadataTable)) { + issues.add("Table '" + encryptionMetadataTable + "' does not exist"); + } else { + issues.addAll(validateTableColumns(connection, encryptionMetadataTable, REQUIRED_ENCRYPTION_METADATA_COLUMNS)); + issues.addAll(validateEncryptionMetadataConstraints(connection)); + } + + // Validate key_storage table + String keyStorageTable = getKeyStorageTable(); + if (!tableExists(connection, keyStorageTable)) { + issues.add("Table '" + keyStorageTable + "' does not exist"); + } else { + issues.addAll(validateTableColumns(connection, keyStorageTable, REQUIRED_KEY_STORAGE_COLUMNS)); + issues.addAll(validateKeyStorageConstraints(connection)); + } + + // Validate foreign key relationship + if (issues.isEmpty()) { + issues.addAll(validateForeignKeyConstraints(connection)); + } + + return new ValidationResult(issues.isEmpty(), issues); + } + + /** + * Checks if a table exists in the database. + */ + private boolean tableExists(Connection connection, String tableName) throws SQLException { + DatabaseMetaData metaData = connection.getMetaData(); + + // Get current schema + String currentSchema = getCurrentSchema(connection); + + // Only check in the current schema to avoid cross-contamination + try (ResultSet rs = metaData.getTables(null, currentSchema, tableName, new String[]{"TABLE"})) { + return rs.next(); + } + } + + /** + * Gets the current schema name from the connection. + */ + private String getCurrentSchema(Connection connection) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT current_schema()")) { + if (rs.next()) { + return rs.getString(1); + } + } + return null; + } + + /** + * Validates that all required columns exist in a table. + */ + private List validateTableColumns(Connection connection, String tableName, Set requiredColumns) throws SQLException { + List issues = new ArrayList<>(); + Set existingColumns = new HashSet<>(); + + DatabaseMetaData metaData = connection.getMetaData(); + String currentSchema = getCurrentSchema(connection); + + // Try with current schema first + try (ResultSet rs = metaData.getColumns(null, currentSchema, tableName, null)) { + while (rs.next()) { + existingColumns.add(rs.getString("COLUMN_NAME").toLowerCase()); + } + } + + // If no columns found, try without schema + if (existingColumns.isEmpty()) { + try (ResultSet rs = metaData.getColumns(null, null, tableName, null)) { + while (rs.next()) { + existingColumns.add(rs.getString("COLUMN_NAME").toLowerCase()); + } + } + } + + for (String requiredColumn : requiredColumns) { + if (!existingColumns.contains(requiredColumn.toLowerCase())) { + issues.add(String.format("Table '%s' is missing required column '%s'", tableName, requiredColumn)); + } + } + + return issues; + } + /** + * Validates constraints specific to encryption_metadata table. + */ + private List validateEncryptionMetadataConstraints(Connection connection) throws SQLException { + List issues = new ArrayList<>(); + + // Check for unique constraint on table_name, column_name + String encryptionMetadataTable = getEncryptionMetadataTable(); + if (!hasUniqueConstraint(connection, encryptionMetadataTable, Arrays.asList("table_name", "column_name"))) { + issues.add("Table '" + encryptionMetadataTable + "' is missing unique constraint on (table_name, column_name)"); + } + + return issues; + } + + /** + * Validates constraints specific to key_storage table. + */ + private List validateKeyStorageConstraints(Connection connection) throws SQLException { + List issues = new ArrayList<>(); + + // Check for primary key on id + String keyStorageTable = getKeyStorageTable(); + if (!hasPrimaryKey(connection, keyStorageTable, "id")) { + issues.add("Table '" + keyStorageTable + "' is missing primary key on 'id'"); + } + + return issues; + } + + /** + * Validates foreign key constraints between tables. + */ + private List validateForeignKeyConstraints(Connection connection) throws SQLException { + List issues = new ArrayList<>(); + + // Check for foreign key from encryption_metadata.key_id to key_storage.id + String encryptionMetadataTable = getEncryptionMetadataTable(); + String keyStorageTable = getKeyStorageTable(); + if (!hasForeignKey(connection, encryptionMetadataTable, "key_id", keyStorageTable, "id")) { + issues.add("Missing foreign key constraint from " + encryptionMetadataTable + ".key_id to " + keyStorageTable + ".id"); + } + + return issues; + } + + /** + * Checks if a unique constraint exists on the specified columns. + */ + private boolean hasUniqueConstraint(Connection connection, String tableName, List columnNames) throws SQLException { + DatabaseMetaData metaData = connection.getMetaData(); + String currentSchema = getCurrentSchema(connection); + try (ResultSet rs = metaData.getIndexInfo(null, currentSchema, tableName, true, false)) { + Set indexColumns = new HashSet<>(); + String currentIndexName = null; + + while (rs.next()) { + String indexName = rs.getString("INDEX_NAME"); + String columnName = rs.getString("COLUMN_NAME"); + + if (currentIndexName == null || !currentIndexName.equals(indexName)) { + // Check previous index + if (indexColumns.size() == columnNames.size() && + indexColumns.containsAll(columnNames.stream().map(String::toLowerCase).collect(java.util.stream.Collectors.toList()))) { + return true; + } + // Start new index + currentIndexName = indexName; + indexColumns.clear(); + } + + if (columnName != null) { + indexColumns.add(columnName.toLowerCase()); + } + } + + // Check last index + return indexColumns.size() == columnNames.size() && + indexColumns.containsAll(columnNames.stream().map(String::toLowerCase).collect(java.util.stream.Collectors.toList())); + } + } + + /** + * Checks if a primary key exists on the specified column. + */ + private boolean hasPrimaryKey(Connection connection, String tableName, String columnName) throws SQLException { + DatabaseMetaData metaData = connection.getMetaData(); + String currentSchema = getCurrentSchema(connection); + try (ResultSet rs = metaData.getPrimaryKeys(null, currentSchema, tableName)) { + while (rs.next()) { + if (columnName.equalsIgnoreCase(rs.getString("COLUMN_NAME"))) { + return true; + } + } + } + return false; + } + + /** + * Checks if a foreign key exists between the specified tables and columns. + */ + private boolean hasForeignKey(Connection connection, String fromTable, String fromColumn, + String toTable, String toColumn) throws SQLException { + DatabaseMetaData metaData = connection.getMetaData(); + String currentSchema = getCurrentSchema(connection); + try (ResultSet rs = metaData.getImportedKeys(null, currentSchema, fromTable)) { + while (rs.next()) { + String fkColumnName = rs.getString("FKCOLUMN_NAME"); + String pkTableName = rs.getString("PKTABLE_NAME"); + String pkColumnName = rs.getString("PKCOLUMN_NAME"); + + if (fromColumn.equalsIgnoreCase(fkColumnName) && + toTable.equalsIgnoreCase(pkTableName) && + toColumn.equalsIgnoreCase(pkColumnName)) { + return true; + } + } + } + return false; + } + + /** + * Result of schema validation containing status and any issues found. + */ + public static class ValidationResult { + private final boolean valid; + private final List issues; + + public ValidationResult(boolean valid, List issues) { + this.valid = valid; + this.issues = new ArrayList<>(issues); + } + + public boolean isValid() { + return valid; + } + + public List getIssues() { + return new ArrayList<>(issues); + } + + @Override + public String toString() { + if (valid) { + return "Schema validation passed"; + } else { + return "Schema validation failed: " + String.join(", ", issues); + } + } + } +} \ No newline at end of file diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/service/EncryptionException.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/service/EncryptionException.java new file mode 100644 index 000000000..08ac1fc81 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/service/EncryptionException.java @@ -0,0 +1,236 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.service; + +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +/** + * Exception thrown when encryption or decryption operations fail. + * Extends SQLException to integrate with JDBC error handling. + * Provides enhanced error context information for better troubleshooting. + */ +public class EncryptionException extends SQLException { + + private static final long serialVersionUID = 1L; + + // SQL State codes for different encryption error types + public static final String ENCRYPTION_FAILED_STATE = "ENC01"; + public static final String DECRYPTION_FAILED_STATE = "ENC02"; + public static final String INVALID_ALGORITHM_STATE = "ENC03"; + public static final String INVALID_KEY_STATE = "ENC04"; + public static final String TYPE_CONVERSION_FAILED_STATE = "ENC05"; + + private final Map errorContext = new HashMap<>(); + + /** + * Constructs an EncryptionException with the specified detail message. + * + * @param message the detail message + */ + public EncryptionException(String message) { + super(message, ENCRYPTION_FAILED_STATE); + } + + /** + * Constructs an EncryptionException with the specified detail message and cause. + * + * @param message the detail message + * @param cause the cause of this exception + */ + public EncryptionException(String message, Throwable cause) { + super(message, ENCRYPTION_FAILED_STATE, cause); + } + + /** + * Constructs an EncryptionException with the specified detail message, SQL state and cause. + * + * @param message the detail message + * @param sqlState the SQL state + * @param cause the cause of this exception + */ + public EncryptionException(String message, String sqlState, Throwable cause) { + super(message, sqlState, cause); + } + + /** + * Constructs an EncryptionException with the specified cause. + * + * @param cause the cause of this exception + */ + public EncryptionException(Throwable cause) { + super(cause.getMessage(), ENCRYPTION_FAILED_STATE, cause); + } + + /** + * Adds context information to the exception. + * + * @param key the context key + * @param value the context value + * @return this exception for method chaining + */ + public EncryptionException withContext(String key, Object value) { + errorContext.put(key, value); + return this; + } + + /** + * Adds table name to the error context. + * + * @param tableName the table name + * @return this exception for method chaining + */ + public EncryptionException withTable(String tableName) { + return withContext("table", tableName); + } + + /** + * Adds column name to the error context. + * + * @param columnName the column name + * @return this exception for method chaining + */ + public EncryptionException withColumn(String columnName) { + return withContext("column", columnName); + } + + /** + * Adds algorithm to the error context. + * + * @param algorithm the encryption algorithm + * @return this exception for method chaining + */ + public EncryptionException withAlgorithm(String algorithm) { + return withContext("algorithm", algorithm); + } + + /** + * Adds data type to the error context. + * + * @param dataType the data type being processed + * @return this exception for method chaining + */ + public EncryptionException withDataType(String dataType) { + return withContext("dataType", dataType); + } + + /** + * Adds operation type to the error context. + * + * @param operation the operation being performed + * @return this exception for method chaining + */ + public EncryptionException withOperation(String operation) { + return withContext("operation", operation); + } + + /** + * Gets the error context map. + * + * @return a copy of the error context + */ + public Map getErrorContext() { + return new HashMap<>(errorContext); + } + + /** + * Gets a formatted error message including context information. + * + * @return formatted error message with context + */ + public String getDetailedMessage() { + if (errorContext.isEmpty()) { + return getMessage(); + } + + StringBuilder sb = new StringBuilder(getMessage()); + sb.append(" [Context: "); + + boolean first = true; + for (Map.Entry entry : errorContext.entrySet()) { + if (!first) { + sb.append(", "); + } + sb.append(entry.getKey()).append("=").append(entry.getValue()); + first = false; + } + + sb.append("]"); + return sb.toString(); + } + + /** + * Creates an EncryptionException for encryption failures. + * + * @param message Error message + * @param cause Root cause + * @return New EncryptionException instance + */ + public static EncryptionException encryptionFailed(String message, Throwable cause) { + return new EncryptionException(message, ENCRYPTION_FAILED_STATE, cause); + } + + /** + * Creates an EncryptionException for decryption failures. + * + * @param message Error message + * @param cause Root cause + * @return New EncryptionException instance + */ + public static EncryptionException decryptionFailed(String message, Throwable cause) { + return new EncryptionException(message, DECRYPTION_FAILED_STATE, cause); + } + + /** + * Creates an EncryptionException for invalid algorithm errors. + * + * @param algorithm Invalid algorithm name + * @return New EncryptionException instance + */ + public static EncryptionException invalidAlgorithm(String algorithm) { + return new EncryptionException("Unsupported encryption algorithm: " + algorithm, INVALID_ALGORITHM_STATE, null) + .withAlgorithm(algorithm); + } + + /** + * Creates an EncryptionException for invalid key errors. + * + * @param message Error message + * @return New EncryptionException instance + */ + public static EncryptionException invalidKey(String message) { + return new EncryptionException(message, INVALID_KEY_STATE, null); + } + + /** + * Creates an EncryptionException for type conversion errors. + * + * @param fromType Source type + * @param toType Target type + * @param cause Root cause + * @return New EncryptionException instance + */ + public static EncryptionException typeConversionFailed(String fromType, String toType, Throwable cause) { + return new EncryptionException( + String.format("Cannot convert %s to %s", fromType, toType), + TYPE_CONVERSION_FAILED_STATE, + cause + ).withContext("fromType", fromType).withContext("toType", toType); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/service/EncryptionService.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/service/EncryptionService.java new file mode 100644 index 000000000..6be815cce --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/service/EncryptionService.java @@ -0,0 +1,703 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.service; + +import java.util.logging.Logger; + +import javax.crypto.Cipher; +import javax.crypto.Mac; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.SecureRandom; +import java.sql.Date; +import java.sql.Time; +import java.util.Base64; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.Arrays; + +/** + * Service for encrypting and decrypting data using AES-256-GCM algorithm. + * Supports multiple data types and provides secure memory handling. + */ +public class EncryptionService { + + private static final Logger LOGGER = Logger.getLogger(EncryptionService.class.getName()); + + // Algorithm constants + private static final String DEFAULT_ALGORITHM = "AES-256-GCM"; + private static final String AES_GCM_TRANSFORMATION = "AES/GCM/NoPadding"; + private static final int GCM_IV_LENGTH = 12; // 96 bits + private static final int GCM_TAG_LENGTH = 16; // 128 bits + private static final String HMAC_ALGORITHM = "HmacSHA256"; + private static final int HMAC_TAG_LENGTH = 32; // 256 bits + + // Supported algorithms + private static final String[] SUPPORTED_ALGORITHMS = { + "AES-256-GCM", + "AES-128-GCM" + }; + + private final SecureRandom secureRandom; + + /** + * Creates a new EncryptionService instance. + */ + public EncryptionService() { + this.secureRandom = new SecureRandom(); + } + + /** + * Encrypts a value using the specified data key and algorithm. + * + * @param value the value to encrypt + * @param dataKey the encryption key + * @param hmacKey the HMAC verification key + * @param algorithm the encryption algorithm to use + * @return the encrypted data as byte array with HMAC prepended + * @throws EncryptionException if encryption fails + */ + public byte[] encrypt(Object value, byte[] dataKey, byte[] hmacKey, String algorithm) throws EncryptionException { + if (value == null) { + return null; + } + + validateAlgorithm(algorithm); + validateDataKey(dataKey, algorithm); + + try { + // Convert value to bytes based on type + byte[] plaintext = serializeValue(value); + + // Generate random IV + byte[] iv = new byte[GCM_IV_LENGTH]; + secureRandom.nextBytes(iv); + + // Create cipher + Cipher cipher = Cipher.getInstance(AES_GCM_TRANSFORMATION); + SecretKeySpec keySpec = new SecretKeySpec(dataKey, "AES"); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv); + cipher.init(Cipher.ENCRYPT_MODE, keySpec, gcmSpec); + + // Encrypt the data + byte[] ciphertext = cipher.doFinal(plaintext); + + // Combine type marker + IV + ciphertext + ByteBuffer buffer = ByteBuffer.allocate(1 + iv.length + ciphertext.length); + buffer.put(getTypeMarker(value)); + buffer.put(iv); + buffer.put(ciphertext); + byte[] encryptedData = buffer.array(); + + // Generate HMAC using the separate HMAC key + Mac hmac = Mac.getInstance(HMAC_ALGORITHM); + hmac.init(new SecretKeySpec(hmacKey, HMAC_ALGORITHM)); + byte[] hmacTag = hmac.doFinal(encryptedData); + + // Prepend HMAC tag to encrypted data: [HMAC:32bytes][type:1byte][IV:12bytes][ciphertext] + ByteBuffer finalBuffer = ByteBuffer.allocate(HMAC_TAG_LENGTH + encryptedData.length); + finalBuffer.put(hmacTag); + finalBuffer.put(encryptedData); + + // Clear sensitive data + Arrays.fill(plaintext, (byte) 0); + Arrays.fill(iv, (byte) 0); + + return finalBuffer.array(); + + } catch (Exception e) { + LOGGER.severe(()->String.format("Encryption failed for value type: %s %s", value.getClass().getSimpleName(), e.getMessage())); + throw EncryptionException.encryptionFailed("Failed to encrypt value", e) + .withDataType(value.getClass().getSimpleName()) + .withAlgorithm(algorithm) + .withOperation("ENCRYPT"); + } + } + + /** + * Encrypts a value using the same key for both encryption and HMAC. + * This is a convenience method for backward compatibility. + * + * @param value the value to encrypt + * @param dataKey the encryption key (also used for HMAC) + * @param algorithm the encryption algorithm to use + * @return the encrypted data as byte array with HMAC prepended + * @throws EncryptionException if encryption fails + */ + public byte[] encrypt(Object value, byte[] dataKey, String algorithm) throws EncryptionException { + return encrypt(value, dataKey, dataKey, algorithm); + } + + /** + * Decrypts encrypted data using the specified data key and algorithm. + * + * @param encryptedValue the encrypted data with HMAC prepended + * @param dataKey the decryption key + * @param hmacKey the HMAC verification key + * @param algorithm the encryption algorithm used + * @param targetType the expected type of the decrypted value + * @return the decrypted value + * @throws EncryptionException if decryption fails or HMAC verification fails + */ + public Object decrypt(byte[] encryptedValue, byte[] dataKey, byte[] hmacKey, String algorithm, Class targetType) + throws EncryptionException { + if (encryptedValue == null) { + return null; + } + + validateAlgorithm(algorithm); + validateDataKey(dataKey, algorithm); + + // Check if this is old format (with salt) or new format (without salt) + // Old format: [salt:16][HMAC:32][type:1][IV:12][ciphertext] = min 61 bytes + // New format: [HMAC:32][type:1][IV:12][ciphertext] = min 45 bytes + boolean isOldFormat = encryptedValue.length >= 61 && encryptedValue.length >= 16 + 32 + 1 + 12 + 16; + + if (isOldFormat) { + // Try old format first (with salt-based HMAC derivation) + try { + return decryptOldFormat(encryptedValue, dataKey, algorithm, targetType); + } catch (Exception e) { + // If old format fails, try new format + LOGGER.fine(() -> "Old format decryption failed, trying new format: " + e.getMessage()); + } + } + + // New format (two-key system) + if (encryptedValue.length < 32 + 1 + 12 + 16) { + throw EncryptionException.decryptionFailed("Invalid encrypted data length", null) + .withAlgorithm(algorithm) + .withDataType(targetType.getSimpleName()) + .withContext("dataLength", encryptedValue.length) + .withContext("minimumLength", 32 + 1 + 12 + 16); + } + + try { + ByteBuffer buffer = ByteBuffer.wrap(encryptedValue); + + // Extract HMAC tag (first 32 bytes) + byte[] storedHmacTag = new byte[32]; + buffer.get(storedHmacTag); + + // Extract encrypted data (everything after HMAC) + byte[] encryptedData = new byte[buffer.remaining()]; + buffer.get(encryptedData); + + // Verify HMAC using the separate HMAC key + LOGGER.info(() -> String.format("Decrypting: hmacKey length=%d, encryptedData length=%d", + hmacKey != null ? hmacKey.length : 0, encryptedData.length)); + + Mac hmac = Mac.getInstance(HMAC_ALGORITHM); + hmac.init(new SecretKeySpec(hmacKey, HMAC_ALGORITHM)); + byte[] calculatedHmacTag = hmac.doFinal(encryptedData); + + LOGGER.info(() -> String.format("HMAC comparison: stored=%s, calculated=%s", + bytesToHex(storedHmacTag).substring(0, 16), + bytesToHex(calculatedHmacTag).substring(0, 16))); + + if (!MessageDigest.isEqual(storedHmacTag, calculatedHmacTag)) { + throw EncryptionException.decryptionFailed("HMAC verification failed - data may be tampered", null) + .withAlgorithm(algorithm) + .withDataType(targetType.getSimpleName()) + .withOperation("VERIFY_HMAC"); + } + + // Now decrypt the verified data + ByteBuffer dataBuffer = ByteBuffer.wrap(encryptedData); + + // Extract type marker + byte typeMarker = dataBuffer.get(); + + // Extract IV + byte[] iv = new byte[GCM_IV_LENGTH]; + dataBuffer.get(iv); + + // Extract ciphertext + byte[] ciphertext = new byte[dataBuffer.remaining()]; + dataBuffer.get(ciphertext); + + // Create cipher + Cipher cipher = Cipher.getInstance(AES_GCM_TRANSFORMATION); + SecretKeySpec keySpec = new SecretKeySpec(dataKey, "AES"); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv); + cipher.init(Cipher.DECRYPT_MODE, keySpec, gcmSpec); + + // Decrypt the data + byte[] plaintext = cipher.doFinal(ciphertext); + + // Deserialize based on type marker and target type + Object result = deserializeValue(plaintext, typeMarker, targetType); + + // Clear sensitive data + Arrays.fill(plaintext, (byte) 0); + Arrays.fill(iv, (byte) 0); + + return result; + + } catch (Exception e) { + LOGGER.severe(()->String.format("Decryption failed for target type: %s %s", targetType.getSimpleName(), e.getMessage())); + throw EncryptionException.decryptionFailed("Failed to decrypt value", e) + .withDataType(targetType.getSimpleName()) + .withAlgorithm(algorithm) + .withOperation("DECRYPT"); + } + } + + /** + * Decrypts data encrypted with old salt-based format. + */ + private Object decryptOldFormat(byte[] encryptedValue, byte[] dataKey, String algorithm, Class targetType) + throws Exception { + ByteBuffer buffer = ByteBuffer.wrap(encryptedValue); + + // Extract salt (first 16 bytes) + byte[] hmacSalt = new byte[16]; + buffer.get(hmacSalt); + + // Extract HMAC tag (next 32 bytes) + byte[] storedHmacTag = new byte[32]; + buffer.get(storedHmacTag); + + // Extract encrypted data (everything after salt + HMAC) + byte[] encryptedData = new byte[buffer.remaining()]; + buffer.get(encryptedData); + + // Derive verification key from data key and salt + Mac hmacDerive = Mac.getInstance(HMAC_ALGORITHM); + hmacDerive.init(new SecretKeySpec(dataKey, HMAC_ALGORITHM)); + byte[] verificationKey = hmacDerive.doFinal(hmacSalt); + + // Verify HMAC + Mac hmac = Mac.getInstance(HMAC_ALGORITHM); + hmac.init(new SecretKeySpec(verificationKey, HMAC_ALGORITHM)); + byte[] calculatedHmacTag = hmac.doFinal(encryptedData); + + if (!MessageDigest.isEqual(storedHmacTag, calculatedHmacTag)) { + throw EncryptionException.decryptionFailed("HMAC verification failed (old format)", null); + } + + Arrays.fill(verificationKey, (byte) 0); + Arrays.fill(hmacSalt, (byte) 0); + + // Decrypt the verified data + ByteBuffer dataBuffer = ByteBuffer.wrap(encryptedData); + byte typeMarker = dataBuffer.get(); + byte[] iv = new byte[GCM_IV_LENGTH]; + dataBuffer.get(iv); + byte[] ciphertext = new byte[dataBuffer.remaining()]; + dataBuffer.get(ciphertext); + + Cipher cipher = Cipher.getInstance(AES_GCM_TRANSFORMATION); + SecretKeySpec keySpec = new SecretKeySpec(dataKey, "AES"); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv); + cipher.init(Cipher.DECRYPT_MODE, keySpec, gcmSpec); + + byte[] plaintext = cipher.doFinal(ciphertext); + Object result = deserializeValue(plaintext, typeMarker, targetType); + + Arrays.fill(plaintext, (byte) 0); + Arrays.fill(iv, (byte) 0); + + return result; + } + + /** + * Decrypts encrypted data using the same key for both decryption and HMAC verification. + * This is a convenience method for backward compatibility. + * + * @param encryptedValue the encrypted data with HMAC prepended + * @param dataKey the decryption key (also used for HMAC verification) + * @param algorithm the encryption algorithm used + * @param targetType the expected type of the decrypted value + * @return the decrypted value + * @throws EncryptionException if decryption fails or HMAC verification fails + */ + public Object decrypt(byte[] encryptedValue, byte[] dataKey, String algorithm, Class targetType) + throws EncryptionException { + return decrypt(encryptedValue, dataKey, dataKey, algorithm, targetType); + } + + /** + * Returns the default encryption algorithm. + * + * @return the default algorithm name + */ + public String getDefaultAlgorithm() { + return DEFAULT_ALGORITHM; + } + + /** + * Checks if the specified algorithm is supported. + * + * @param algorithm the algorithm to check + * @return true if supported, false otherwise + */ + public boolean isAlgorithmSupported(String algorithm) { + if (algorithm == null) { + return false; + } + return Arrays.asList(SUPPORTED_ALGORITHMS).contains(algorithm); + } + + /** + * Validates that the algorithm is supported. + */ + private void validateAlgorithm(String algorithm) throws EncryptionException { + if (!isAlgorithmSupported(algorithm)) { + throw EncryptionException.invalidAlgorithm(algorithm); + } + } + + /** + * Validates the data key for the specified algorithm. + */ + private void validateDataKey(byte[] dataKey, String algorithm) throws EncryptionException { + if (dataKey == null) { + throw EncryptionException.invalidKey("Data key cannot be null") + .withAlgorithm(algorithm); + } + + int expectedKeyLength = getExpectedKeyLength(algorithm); + if (dataKey.length != expectedKeyLength) { + throw EncryptionException.invalidKey( + String.format("Invalid key length for %s: expected %d bytes, got %d", + algorithm, expectedKeyLength, dataKey.length)) + .withAlgorithm(algorithm) + .withContext("expectedLength", expectedKeyLength) + .withContext("actualLength", dataKey.length); + } + } + + /** + * Gets the expected key length for the algorithm. + */ + private int getExpectedKeyLength(String algorithm) { + switch (algorithm) { + case "AES-256-GCM": + return 32; // 256 bits + case "AES-128-GCM": + return 16; // 128 bits + default: + throw new IllegalArgumentException("Unknown algorithm: " + algorithm); + } + } + + /** + * Serializes a value to bytes based on its type. + */ + private byte[] serializeValue(Object value) throws Exception { + if (value instanceof String) { + return ((String) value).getBytes(StandardCharsets.UTF_8); + } else if (value instanceof Integer) { + return ByteBuffer.allocate(4).putInt((Integer) value).array(); + } else if (value instanceof Long) { + return ByteBuffer.allocate(8).putLong((Long) value).array(); + } else if (value instanceof Double) { + return ByteBuffer.allocate(8).putDouble((Double) value).array(); + } else if (value instanceof Float) { + return ByteBuffer.allocate(4).putFloat((Float) value).array(); + } else if (value instanceof Boolean) { + return new byte[]{(Boolean) value ? (byte) 1 : (byte) 0}; + } else if (value instanceof BigDecimal) { + return ((BigDecimal) value).toString().getBytes(StandardCharsets.UTF_8); + } else if (value instanceof Date) { + return ByteBuffer.allocate(8).putLong(((Date) value).getTime()).array(); + } else if (value instanceof Time) { + return ByteBuffer.allocate(8).putLong(((Time) value).getTime()).array(); + } else if (value instanceof Timestamp) { + return ByteBuffer.allocate(8).putLong(((Timestamp) value).getTime()).array(); + } else if (value instanceof LocalDate) { + return ((LocalDate) value).toString().getBytes(StandardCharsets.UTF_8); + } else if (value instanceof LocalTime) { + return ((LocalTime) value).toString().getBytes(StandardCharsets.UTF_8); + } else if (value instanceof LocalDateTime) { + return ((LocalDateTime) value).toString().getBytes(StandardCharsets.UTF_8); + } else if (value instanceof byte[]) { + return (byte[]) value; + } else { + // Fallback to Java serialization for complex objects + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(value); + return baos.toByteArray(); + } + } + } + + /** + * Gets a type marker byte for the value type. + */ + private byte getTypeMarker(Object value) { + if (value instanceof String) return 1; + if (value instanceof Integer) return 2; + if (value instanceof Long) return 3; + if (value instanceof Double) return 4; + if (value instanceof Float) return 5; + if (value instanceof Boolean) return 6; + if (value instanceof BigDecimal) return 7; + if (value instanceof Date) return 8; + if (value instanceof Time) return 9; + if (value instanceof Timestamp) return 10; + if (value instanceof LocalDate) return 11; + if (value instanceof LocalTime) return 12; + if (value instanceof LocalDateTime) return 13; + if (value instanceof byte[]) return 14; + return 99; // Generic object serialization + } + + /** + * Deserializes bytes to the appropriate type. + */ + private Object deserializeValue(byte[] data, byte typeMarker, Class targetType) throws Exception { + switch (typeMarker) { + case 1: // String + String str = new String(data, StandardCharsets.UTF_8); + return convertToTargetType(str, targetType); + + case 2: // Integer + if (data.length != 4) throw EncryptionException.decryptionFailed("Invalid Integer data length", null) + .withContext("expectedLength", 4).withContext("actualLength", data.length); + int intVal = ByteBuffer.wrap(data).getInt(); + return convertToTargetType(intVal, targetType); + + case 3: // Long + if (data.length != 8) throw EncryptionException.decryptionFailed("Invalid Long data length", null) + .withContext("expectedLength", 8).withContext("actualLength", data.length); + long longVal = ByteBuffer.wrap(data).getLong(); + return convertToTargetType(longVal, targetType); + + case 4: // Double + if (data.length != 8) throw EncryptionException.decryptionFailed("Invalid Double data length", null) + .withContext("expectedLength", 8).withContext("actualLength", data.length); + double doubleVal = ByteBuffer.wrap(data).getDouble(); + return convertToTargetType(doubleVal, targetType); + + case 5: // Float + if (data.length != 4) throw EncryptionException.decryptionFailed("Invalid Float data length", null) + .withContext("expectedLength", 4).withContext("actualLength", data.length); + float floatVal = ByteBuffer.wrap(data).getFloat(); + return convertToTargetType(floatVal, targetType); + + case 6: // Boolean + if (data.length != 1) throw EncryptionException.decryptionFailed("Invalid Boolean data length", null) + .withContext("expectedLength", 1).withContext("actualLength", data.length); + boolean boolVal = data[0] == 1; + return convertToTargetType(boolVal, targetType); + + case 7: // BigDecimal + String decStr = new String(data, StandardCharsets.UTF_8); + BigDecimal decVal = new BigDecimal(decStr); + return convertToTargetType(decVal, targetType); + + case 8: // Date + if (data.length != 8) throw EncryptionException.decryptionFailed("Invalid Date data length", null) + .withContext("expectedLength", 8).withContext("actualLength", data.length); + long dateTime = ByteBuffer.wrap(data).getLong(); + Date dateVal = new Date(dateTime); + return convertToTargetType(dateVal, targetType); + + case 9: // Time + if (data.length != 8) throw EncryptionException.decryptionFailed("Invalid Time data length", null) + .withContext("expectedLength", 8).withContext("actualLength", data.length); + long timeTime = ByteBuffer.wrap(data).getLong(); + Time timeVal = new Time(timeTime); + return convertToTargetType(timeVal, targetType); + + case 10: // Timestamp + if (data.length != 8) throw EncryptionException.decryptionFailed("Invalid Timestamp data length", null) + .withContext("expectedLength", 8).withContext("actualLength", data.length); + long tsTime = ByteBuffer.wrap(data).getLong(); + Timestamp tsVal = new Timestamp(tsTime); + return convertToTargetType(tsVal, targetType); + + case 11: // LocalDate + String ldStr = new String(data, StandardCharsets.UTF_8); + LocalDate ldVal = LocalDate.parse(ldStr); + return convertToTargetType(ldVal, targetType); + + case 12: // LocalTime + String ltStr = new String(data, StandardCharsets.UTF_8); + LocalTime ltVal = LocalTime.parse(ltStr); + return convertToTargetType(ltVal, targetType); + + case 13: // LocalDateTime + String ldtStr = new String(data, StandardCharsets.UTF_8); + LocalDateTime ldtVal = LocalDateTime.parse(ldtStr); + return convertToTargetType(ldtVal, targetType); + + case 14: // byte[] + return convertToTargetType(data, targetType); + + case 99: // Generic object + try (ByteArrayInputStream bais = new ByteArrayInputStream(data); + ObjectInputStream ois = new ObjectInputStream(bais)) { + Object obj = ois.readObject(); + return convertToTargetType(obj, targetType); + } + + default: + throw EncryptionException.decryptionFailed("Unknown type marker: " + typeMarker, null) + .withContext("typeMarker", typeMarker); + } + } + + /** + * Converts a value to the target type if possible. + */ + private Object convertToTargetType(Object value, Class targetType) throws EncryptionException { + if (value == null || targetType == null) { + return value; + } + + // If already the correct type, return as-is + if (targetType.isAssignableFrom(value.getClass())) { + return value; + } + + // Handle Object.class target type (return as-is) + if (targetType == Object.class) { + return value; + } + + // Handle String conversions + if (targetType == String.class) { + return value.toString(); + } + + // Handle byte array conversions + if (targetType == byte[].class) { + if (value instanceof String) { + // Assume base64 encoded string, decode it + try { + return Base64.getDecoder().decode((String) value); + } catch (IllegalArgumentException e) { + throw EncryptionException.typeConversionFailed("String", "byte[]", e) + .withContext("stringValue", value.toString().length() > 50 ? + value.toString().substring(0, 47) + "..." : value.toString()); + } + } else if (value instanceof byte[]) { + return value; + } + } + + // Handle numeric conversions + if (value instanceof Number) { + Number num = (Number) value; + if (targetType == Integer.class || targetType == int.class) { + return num.intValue(); + } else if (targetType == Long.class || targetType == long.class) { + return num.longValue(); + } else if (targetType == Double.class || targetType == double.class) { + return num.doubleValue(); + } else if (targetType == Float.class || targetType == float.class) { + return num.floatValue(); + } else if (targetType == BigDecimal.class) { + return BigDecimal.valueOf(num.doubleValue()); + } + } + + // Handle String to numeric conversions + if (value instanceof String) { + String str = (String) value; + try { + if (targetType == Integer.class || targetType == int.class) { + return Integer.valueOf(str); + } else if (targetType == Long.class || targetType == long.class) { + return Long.valueOf(str); + } else if (targetType == Double.class || targetType == double.class) { + return Double.valueOf(str); + } else if (targetType == Float.class || targetType == float.class) { + return Float.valueOf(str); + } else if (targetType == BigDecimal.class) { + return new BigDecimal(str); + } else if (targetType == Boolean.class || targetType == boolean.class) { + return Boolean.valueOf(str); + } + } catch (NumberFormatException e) { + throw EncryptionException.typeConversionFailed("String", targetType.getSimpleName(), e) + .withContext("stringValue", str.length() > 50 ? str.substring(0, 47) + "..." : str); + } + } + + // If no conversion is possible, throw an exception + throw EncryptionException.typeConversionFailed( + value.getClass().getSimpleName(), + targetType.getSimpleName(), + null); + } + + /** + * Verifies that encrypted data has not been tampered with, without decrypting it. + * This method only requires the HMAC key, not the encryption key or decryption permission. + * + * @param encryptedValue the encrypted data with HMAC prepended + * @param hmacKey the HMAC verification key + * @return true if HMAC verification passes, false otherwise + */ + public boolean verifyEncryptedData(byte[] encryptedValue, byte[] hmacKey) { + if (encryptedValue == null || hmacKey == null) { + return false; + } + + if (encryptedValue.length < HMAC_TAG_LENGTH + 1 + GCM_IV_LENGTH + GCM_TAG_LENGTH) { + return false; + } + + try { + ByteBuffer buffer = ByteBuffer.wrap(encryptedValue); + + // Extract stored HMAC tag + byte[] storedHmacTag = new byte[HMAC_TAG_LENGTH]; + buffer.get(storedHmacTag); + + // Extract encrypted data + byte[] encryptedData = new byte[buffer.remaining()]; + buffer.get(encryptedData); + + // Calculate HMAC using the HMAC key + Mac hmac = Mac.getInstance(HMAC_ALGORITHM); + hmac.init(new SecretKeySpec(hmacKey, HMAC_ALGORITHM)); + byte[] calculatedHmacTag = hmac.doFinal(encryptedData); + + // Verify + return MessageDigest.isEqual(storedHmacTag, calculatedHmacTag); + + } catch (Exception e) { + LOGGER.warning(()->"HMAC verification failed: " + e.getMessage()); + return false; + } + } + + private static String bytesToHex(byte[] bytes) { + StringBuilder sb = new StringBuilder(); + for (byte b : bytes) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/sql/SqlAnalysisService.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/sql/SqlAnalysisService.java new file mode 100644 index 000000000..ab73c54b2 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/sql/SqlAnalysisService.java @@ -0,0 +1,190 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.sql; + +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.plugin.encryption.metadata.MetadataManager; +import software.amazon.jdbc.plugin.encryption.model.ColumnEncryptionConfig; +import software.amazon.jdbc.plugin.encryption.parser.SQLAnalyzer; +import java.util.logging.Logger; + +import java.util.*; + +/** + * Service that analyzes SQL statements to identify columns that need encryption/decryption. + * Uses jOOQ parser via SQLAnalyzer class. + */ +public class SqlAnalysisService { + + private static final Logger LOGGER = Logger.getLogger(SqlAnalysisService.class.getName()); + + private final MetadataManager metadataManager; + private final SQLAnalyzer analyzer; + + public SqlAnalysisService(PluginService pluginService, MetadataManager metadataManager) { + this.metadataManager = metadataManager; + this.analyzer = new SQLAnalyzer(); + } + + /** + * Analyzes a SQL statement to determine which columns need encryption/decryption. + * + * @param sql The SQL statement to analyze + * @return Analysis result containing affected columns and their encryption configs + */ + public SqlAnalysisResult analyzeSql(String sql) { + if (sql == null || sql.trim().isEmpty()) { + return new SqlAnalysisResult(Collections.emptySet(), Collections.emptyMap(), "UNKNOWN"); + } + + try { + SQLAnalyzer.QueryAnalysis queryAnalysis = analyzer.analyze(sql); + if (queryAnalysis != null) { + Set tables = extractTablesFromAnalysis(queryAnalysis); + String queryType = extractQueryTypeFromAnalysis(queryAnalysis); + return analyzeFromTables(tables, queryType); + } + } catch (Exception e) { + LOGGER.severe(()->String.format("Error analyzing SQL: %s", e.getMessage())); + throw new RuntimeException("SQL analysis failed", e); + } + + return new SqlAnalysisResult(Collections.emptySet(), Collections.emptyMap(), "UNKNOWN"); + } + + /** + * Extracts table names from SQLAnalyzer QueryAnalysis result. + */ + private Set extractTablesFromAnalysis(SQLAnalyzer.QueryAnalysis queryAnalysis) { + Set tables = new HashSet<>(); + if (queryAnalysis != null) { + tables.addAll(queryAnalysis.tables); + } + return tables; + } + + /** + * Extracts query type from SQLAnalyzer QueryAnalysis result. + */ + private String extractQueryTypeFromAnalysis(SQLAnalyzer.QueryAnalysis queryAnalysis) { + if (queryAnalysis != null) { + return queryAnalysis.queryType != null ? queryAnalysis.queryType : "UNKNOWN"; + } + return "UNKNOWN"; + } + + /** + * Analyzes SQL using the extracted table names from parser. + */ + private SqlAnalysisResult analyzeFromTables(Set tables, String queryType) { + Map encryptedColumns = new HashMap<>(); + + LOGGER.finest(()->String.format("Parser analysis found %s tables", tables.size())); + + return new SqlAnalysisResult(tables, encryptedColumns, queryType); + } + + /** + * Result of SQL analysis containing affected tables and encrypted columns. + */ + public static class SqlAnalysisResult { + private final Set affectedTables; + private final Map encryptedColumns; + private final String queryType; + + public SqlAnalysisResult(Set affectedTables, Map encryptedColumns, String queryType) { + this.affectedTables = Collections.unmodifiableSet(new HashSet<>(affectedTables)); + this.encryptedColumns = Collections.unmodifiableMap(new HashMap<>(encryptedColumns)); + this.queryType = queryType; + } + + public Set getAffectedTables() { + return affectedTables; + } + + public Map getEncryptedColumns() { + return encryptedColumns; + } + + public String getQueryType() { + return queryType; + } + + public boolean hasEncryptedColumns() { + return !encryptedColumns.isEmpty(); + } + + public int getTableCount() { + return affectedTables.size(); + } + + public int getEncryptedColumnCount() { + return encryptedColumns.size(); + } + + @Override + public String toString() { + return String.format("SqlAnalysisResult{tables=%d, encryptedColumns=%d}", + getTableCount(), getEncryptedColumnCount()); + } + } + + /** + * Gets column-to-parameter mapping for prepared statement parameters. + */ + public Map getColumnParameterMapping(String sql) { + Map mapping = new HashMap<>(); + + try { + SQLAnalyzer.QueryAnalysis queryAnalysis = analyzer.analyze(sql); + if (queryAnalysis != null) { + // For SELECT statements, map parameters to WHERE clause columns (where ? placeholders are) + if ("SELECT".equals(queryAnalysis.queryType)) { + // Map parameters to WHERE clause columns + for (int i = 0; i < queryAnalysis.whereColumns.size(); i++) { + SQLAnalyzer.ColumnInfo column = queryAnalysis.whereColumns.get(i); + mapping.put(i + 1, column.columnName); + } + } else if (!queryAnalysis.columns.isEmpty()) { + // For INSERT/UPDATE, map parameters to main columns in order + int parameterIndex = 1; + for (SQLAnalyzer.ColumnInfo column : queryAnalysis.columns) { + mapping.put(parameterIndex++, column.columnName); + } + } + } + } catch (Exception e) { + LOGGER.warning(()->String.format("Failed to get column parameter mapping for SQL: %s", sql)); + } + + return mapping; + } + + /** + * Count the number of parameter placeholders (?) in SQL. + */ + private int countParameters(String sql) { + int count = 0; + for (int i = 0; i < sql.length(); i++) { + if (sql.charAt(i) == '?') { + count++; + } + } + return count; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/DecryptingResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/DecryptingResultSet.java new file mode 100644 index 000000000..a47d42cbb --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/DecryptingResultSet.java @@ -0,0 +1,1603 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.wrapper; + +import software.amazon.jdbc.plugin.encryption.metadata.MetadataManager; +import software.amazon.jdbc.plugin.encryption.model.ColumnEncryptionConfig; +import software.amazon.jdbc.plugin.encryption.service.EncryptionService; +import software.amazon.jdbc.plugin.encryption.key.KeyManager; +import java.util.logging.Logger; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.*; +import java.util.Calendar; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A ResultSet wrapper that automatically decrypts values from columns + * configured for encryption. Uses delegation pattern for non-encrypted + * operations. + */ +public class DecryptingResultSet implements ResultSet { + + private static final Logger LOGGER = Logger.getLogger(DecryptingResultSet.class.getName()); + + private final ResultSet delegate; + private final MetadataManager metadataManager; + private final EncryptionService encryptionService; + private final KeyManager keyManager; + + // Cache for column index/name to encryption config mapping + private final Map columnConfigCache = new ConcurrentHashMap<>(); + private final Map columnIndexToNameCache = new ConcurrentHashMap<>(); + private String tableName; + private boolean metadataInitialized = false; + + public DecryptingResultSet(ResultSet delegate, + MetadataManager metadataManager, + EncryptionService encryptionService, + KeyManager keyManager) { + this.delegate = delegate; + this.metadataManager = metadataManager; + this.encryptionService = encryptionService; + this.keyManager = keyManager; + + // Initialize metadata mapping + initializeMetadata(); + } + + /** + * Initializes column metadata by examining the ResultSet metadata. + */ + private void initializeMetadata() { + try { + ResultSetMetaData rsmd = delegate.getMetaData(); + + // Get table name from first column (assuming single table queries) + if (rsmd.getColumnCount() > 0) { + this.tableName = rsmd.getTableName(1); + + // Build column index to name mapping + for (int i = 1; i <= rsmd.getColumnCount(); i++) { + String columnName = rsmd.getColumnName(i); + columnIndexToNameCache.put(i, columnName); + + // Check if column is encrypted and cache the config + if (tableName != null && metadataManager.isColumnEncrypted(tableName, columnName)) { + ColumnEncryptionConfig config = metadataManager.getColumnConfig(tableName, columnName); + if (config != null) { + columnConfigCache.put(columnName, config); + LOGGER.finest(()->String.format("Cached encryption config for column %s.%s", tableName, columnName)); + } + } + } + } + + metadataInitialized = true; + LOGGER.finest(()-> { + try { + return String.format("Metadata initialized for table: %s with %s columns", + tableName, rsmd.getColumnCount()); + } catch (SQLException e) { + return String.format("Error getting resultset metadata %s",e.getMessage()); + } + }); + + } catch (Exception e) { + LOGGER.warning(()->String.format("Failed to initialize ResultSet metadata %s", e.getMessage())); + metadataInitialized = false; + } + } + + /** + * Gets the column name for a given column index. + */ + private String getColumnName(int columnIndex) throws SQLException { + String columnName = columnIndexToNameCache.get(columnIndex); + if (columnName == null) { + // Fallback to metadata lookup + ResultSetMetaData rsmd = delegate.getMetaData(); + if (columnIndex >= 1 && columnIndex <= rsmd.getColumnCount()) { + columnName = rsmd.getColumnName(columnIndex); + columnIndexToNameCache.put(columnIndex, columnName); + } + } + return columnName; + } + + /** + * Gets the encryption configuration for a column by name. + */ + private ColumnEncryptionConfig getColumnConfig(String columnName) { + return columnConfigCache.get(columnName); + } + + /** + * Checks if a column should be decrypted and decrypts it if necessary. + * Only attempts decryption for byte array values (encrypted data). + */ + private Object decryptValueIfNeeded(String columnName, Object value, Class targetType) throws SQLException { + if (!metadataInitialized || tableName == null || value == null) { + return value; + } + + // Only decrypt byte arrays - encrypted data should always be stored as bytes + if (!(value instanceof byte[])) { + LOGGER.finest(()->String.format("Skipping decryption for column %s.%s - value is not byte array (type: %s)", + tableName, columnName, value.getClass().getName())); + return value; + } + + try { + // Check if column is configured for encryption + ColumnEncryptionConfig config = getColumnConfig(columnName); + if (config == null) { + LOGGER.finest(()->String.format("No encryption config found for column %s.%s", tableName, columnName)); + return value; + } + + byte[] encryptedBytes = (byte[]) value; + LOGGER.finest(()->String.format("Attempting to decrypt byte array for column %s.%s (length: %s)", + tableName, columnName, encryptedBytes.length)); + + // Get data key for decryption + byte[] dataKey = keyManager.decryptDataKey( + config.getKeyMetadata().getEncryptedDataKey(), + config.getKeyMetadata().getMasterKeyArn()); + + if (dataKey == null) { + LOGGER.severe(()->String.format("Failed to decrypt data key for column %s.%s", tableName, columnName)); + throw new SQLException("Data key decryption failed"); + } + + // Get HMAC key + byte[] hmacKey = config.getKeyMetadata().getHmacKey(); + + // Decrypt the value + Object decryptedValue = encryptionService.decrypt( + encryptedBytes, + dataKey, + hmacKey, + config.getAlgorithm(), + targetType); + + // Clear the data key from memory + java.util.Arrays.fill(dataKey, (byte) 0); + + LOGGER.finest(()->String.format("Successfully decrypted value for column %s.%s", tableName, columnName)); + return decryptedValue; + + } catch (Exception e) { + String errorMsg = String.format("Failed to decrypt value for column %s.%s %s", + tableName, columnName, e.getMessage()); + LOGGER.severe(()->errorMsg); + throw new SQLException(errorMsg, e); + } + } + + /** + * Decrypts value by column index. + */ + private Object decryptValueIfNeeded(int columnIndex, Object value, Class targetType) throws SQLException { + String columnName = getColumnName(columnIndex); + if (columnName == null) { + return value; + } + return decryptValueIfNeeded(columnName, value, targetType); + } + + // Override getXXX methods to add decryption logic + + @Override + public String getString(int columnIndex) throws SQLException { + String columnName = getColumnName(columnIndex); + ColumnEncryptionConfig config = getColumnConfig(columnName); + + // If column is encrypted, get as EncryptedData + Object value; + if (config != null) { + Object obj = delegate.getObject(columnIndex); + if (obj instanceof EncryptedData) { + value = ((EncryptedData) obj).getBytes(); + } else { + value = delegate.getBytes(columnIndex); + } + } else { + value = delegate.getObject(columnIndex); + } + + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, String.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof String) { + return (String) decryptedValue; + } else { + return decryptedValue.toString(); + } + } + + @Override + public String getString(String columnLabel) throws SQLException { + ColumnEncryptionConfig config = getColumnConfig(columnLabel); + + // If column is encrypted, get as EncryptedData + Object value; + if (config != null) { + Object obj = delegate.getObject(columnLabel); + if (obj instanceof EncryptedData) { + value = ((EncryptedData) obj).getBytes(); + } else { + value = delegate.getBytes(columnLabel); + } + } else { + value = delegate.getObject(columnLabel); + } + + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, String.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof String) { + return (String) decryptedValue; + } else { + return decryptedValue.toString(); + } + } + + private static byte[] hexToBytes(String hex) { + int len = hex.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4) + + Character.digit(hex.charAt(i+1), 16)); + } + return data; + } + + @Override + public int getInt(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Integer.class); + + if (decryptedValue == null) { + return 0; + } else if (decryptedValue instanceof Integer) { + return (Integer) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).intValue(); + } else { + return Integer.parseInt(decryptedValue.toString()); + } + } + + @Override + public int getInt(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Integer.class); + + if (decryptedValue == null) { + return 0; + } else if (decryptedValue instanceof Integer) { + return (Integer) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).intValue(); + } else { + return Integer.parseInt(decryptedValue.toString()); + } + } + + @Override + public long getLong(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Long.class); + + if (decryptedValue == null) { + return 0L; + } else if (decryptedValue instanceof Long) { + return (Long) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).longValue(); + } else { + return Long.parseLong(decryptedValue.toString()); + } + } + + @Override + public long getLong(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Long.class); + + if (decryptedValue == null) { + return 0L; + } else if (decryptedValue instanceof Long) { + return (Long) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).longValue(); + } else { + return Long.parseLong(decryptedValue.toString()); + } + } + + @Override + public byte[] getBytes(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, byte[].class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof byte[]) { + return (byte[]) decryptedValue; + } else { + return decryptedValue.toString().getBytes(); + } + } + + @Override + public byte[] getBytes(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, byte[].class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof byte[]) { + return (byte[]) decryptedValue; + } else { + return decryptedValue.toString().getBytes(); + } + } + + @Override + public double getDouble(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Double.class); + + if (decryptedValue == null) { + return 0.0; + } else if (decryptedValue instanceof Double) { + return (Double) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).doubleValue(); + } else { + return Double.parseDouble(decryptedValue.toString()); + } + } + + @Override + public double getDouble(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Double.class); + + if (decryptedValue == null) { + return 0.0; + } else if (decryptedValue instanceof Double) { + return (Double) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).doubleValue(); + } else { + return Double.parseDouble(decryptedValue.toString()); + } + } + + @Override + public float getFloat(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Float.class); + + if (decryptedValue == null) { + return 0.0f; + } else if (decryptedValue instanceof Float) { + return (Float) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).floatValue(); + } else { + return Float.parseFloat(decryptedValue.toString()); + } + } + + @Override + public float getFloat(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Float.class); + + if (decryptedValue == null) { + return 0.0f; + } else if (decryptedValue instanceof Float) { + return (Float) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).floatValue(); + } else { + return Float.parseFloat(decryptedValue.toString()); + } + } + + @Override + public boolean getBoolean(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Boolean.class); + + if (decryptedValue == null) { + return false; + } else if (decryptedValue instanceof Boolean) { + return (Boolean) decryptedValue; + } else { + return Boolean.parseBoolean(decryptedValue.toString()); + } + } + + @Override + public boolean getBoolean(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Boolean.class); + + if (decryptedValue == null) { + return false; + } else if (decryptedValue instanceof Boolean) { + return (Boolean) decryptedValue; + } else { + return Boolean.parseBoolean(decryptedValue.toString()); + } + } + + @Override + public short getShort(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Short.class); + + if (decryptedValue == null) { + return 0; + } else if (decryptedValue instanceof Short) { + return (Short) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).shortValue(); + } else { + return Short.parseShort(decryptedValue.toString()); + } + } + + @Override + public short getShort(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Short.class); + + if (decryptedValue == null) { + return 0; + } else if (decryptedValue instanceof Short) { + return (Short) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).shortValue(); + } else { + return Short.parseShort(decryptedValue.toString()); + } + } + + @Override + public byte getByte(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Byte.class); + + if (decryptedValue == null) { + return 0; + } else if (decryptedValue instanceof Byte) { + return (Byte) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).byteValue(); + } else { + return Byte.parseByte(decryptedValue.toString()); + } + } + + @Override + public byte getByte(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Byte.class); + + if (decryptedValue == null) { + return 0; + } else if (decryptedValue instanceof Byte) { + return (Byte) decryptedValue; + } else if (decryptedValue instanceof Number) { + return ((Number) decryptedValue).byteValue(); + } else { + return Byte.parseByte(decryptedValue.toString()); + } + } + + @Override + public BigDecimal getBigDecimal(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, BigDecimal.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof BigDecimal) { + return (BigDecimal) decryptedValue; + } else if (decryptedValue instanceof Number) { + return BigDecimal.valueOf(((Number) decryptedValue).doubleValue()); + } else { + return new BigDecimal(decryptedValue.toString()); + } + } + + @Override + public BigDecimal getBigDecimal(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, BigDecimal.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof BigDecimal) { + return (BigDecimal) decryptedValue; + } else if (decryptedValue instanceof Number) { + return BigDecimal.valueOf(((Number) decryptedValue).doubleValue()); + } else { + return new BigDecimal(decryptedValue.toString()); + } + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException { + BigDecimal value = getBigDecimal(columnIndex); + return value != null ? value.setScale(scale, BigDecimal.ROUND_HALF_UP) : null; + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException { + BigDecimal value = getBigDecimal(columnLabel); + return value != null ? value.setScale(scale, BigDecimal.ROUND_HALF_UP) : null; + } + + @Override + public Date getDate(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Date.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Date) { + return (Date) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Date(((java.util.Date) decryptedValue).getTime()); + } else { + return Date.valueOf(decryptedValue.toString()); + } + } + + @Override + public Date getDate(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Date.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Date) { + return (Date) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Date(((java.util.Date) decryptedValue).getTime()); + } else { + return Date.valueOf(decryptedValue.toString()); + } + } + + @Override + public Time getTime(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Time.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Time) { + return (Time) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Time(((java.util.Date) decryptedValue).getTime()); + } else { + return Time.valueOf(decryptedValue.toString()); + } + } + + @Override + public Time getTime(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Time.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Time) { + return (Time) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Time(((java.util.Date) decryptedValue).getTime()); + } else { + return Time.valueOf(decryptedValue.toString()); + } + } + + @Override + public Timestamp getTimestamp(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Timestamp.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Timestamp) { + return (Timestamp) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Timestamp(((java.util.Date) decryptedValue).getTime()); + } else { + return Timestamp.valueOf(decryptedValue.toString()); + } + } + + @Override + public Timestamp getTimestamp(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Timestamp.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Timestamp) { + return (Timestamp) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Timestamp(((java.util.Date) decryptedValue).getTime()); + } else { + return Timestamp.valueOf(decryptedValue.toString()); + } + } + + @Override + public Object getObject(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + return decryptValueIfNeeded(columnIndex, value, Object.class); + } + + @Override + public Object getObject(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + return decryptValueIfNeeded(columnLabel, value, Object.class); + } + + @Override + public Object getObject(int columnIndex, Map> map) throws SQLException { + Object value = delegate.getObject(columnIndex, map); + return decryptValueIfNeeded(columnIndex, value, Object.class); + } + + @Override + public Object getObject(String columnLabel, Map> map) throws SQLException { + Object value = delegate.getObject(columnLabel, map); + return decryptValueIfNeeded(columnLabel, value, Object.class); + } + + @Override + public T getObject(int columnIndex, Class type) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, type); + + if (decryptedValue == null) { + return null; + } else if (type.isAssignableFrom(decryptedValue.getClass())) { + return type.cast(decryptedValue); + } else { + throw new SQLException("Cannot convert decrypted value to " + type.getSimpleName()); + } + } + + @Override + public T getObject(String columnLabel, Class type) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, type); + + if (decryptedValue == null) { + return null; + } else if (type.isAssignableFrom(decryptedValue.getClass())) { + return type.cast(decryptedValue); + } else { + throw new SQLException("Cannot convert decrypted value to " + type.getSimpleName()); + } + } + + // Calendar-based date/time methods + @Override + public Date getDate(int columnIndex, Calendar cal) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Date.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Date) { + return (Date) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Date(((java.util.Date) decryptedValue).getTime()); + } else { + return Date.valueOf(decryptedValue.toString()); + } + } + + @Override + public Date getDate(String columnLabel, Calendar cal) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Date.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Date) { + return (Date) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Date(((java.util.Date) decryptedValue).getTime()); + } else { + return Date.valueOf(decryptedValue.toString()); + } + } + + @Override + public Time getTime(int columnIndex, Calendar cal) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Time.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Time) { + return (Time) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Time(((java.util.Date) decryptedValue).getTime()); + } else { + return Time.valueOf(decryptedValue.toString()); + } + } + + @Override + public Time getTime(String columnLabel, Calendar cal) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Time.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Time) { + return (Time) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Time(((java.util.Date) decryptedValue).getTime()); + } else { + return Time.valueOf(decryptedValue.toString()); + } + } + + @Override + public Timestamp getTimestamp(int columnIndex, Calendar cal) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, Timestamp.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Timestamp) { + return (Timestamp) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Timestamp(((java.util.Date) decryptedValue).getTime()); + } else { + return Timestamp.valueOf(decryptedValue.toString()); + } + } + + @Override + public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, Timestamp.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof Timestamp) { + return (Timestamp) decryptedValue; + } else if (decryptedValue instanceof java.util.Date) { + return new Timestamp(((java.util.Date) decryptedValue).getTime()); + } else { + return Timestamp.valueOf(decryptedValue.toString()); + } + } + + // Stream and reader methods - delegate directly (decryption not supported for + // streams) + @Override + public InputStream getBinaryStream(int columnIndex) throws SQLException { + return delegate.getBinaryStream(columnIndex); + } + + @Override + public InputStream getBinaryStream(String columnLabel) throws SQLException { + return delegate.getBinaryStream(columnLabel); + } + + @Override + public InputStream getAsciiStream(int columnIndex) throws SQLException { + return delegate.getAsciiStream(columnIndex); + } + + @Override + public InputStream getAsciiStream(String columnLabel) throws SQLException { + return delegate.getAsciiStream(columnLabel); + } + + @Override + public Reader getCharacterStream(int columnIndex) throws SQLException { + return delegate.getCharacterStream(columnIndex); + } + + @Override + public Reader getCharacterStream(String columnLabel) throws SQLException { + return delegate.getCharacterStream(columnLabel); + } + + @Override + public Reader getNCharacterStream(int columnIndex) throws SQLException { + return delegate.getNCharacterStream(columnIndex); + } + + @Override + public Reader getNCharacterStream(String columnLabel) throws SQLException { + return delegate.getNCharacterStream(columnLabel); + } + + // Deprecated methods + @Override + @Deprecated + public InputStream getUnicodeStream(int columnIndex) throws SQLException { + return delegate.getUnicodeStream(columnIndex); + } + + @Override + @Deprecated + public InputStream getUnicodeStream(String columnLabel) throws SQLException { + return delegate.getUnicodeStream(columnLabel); + } + + // Other specialized getters - delegate directly + @Override + public URL getURL(int columnIndex) throws SQLException { + return delegate.getURL(columnIndex); + } + + @Override + public URL getURL(String columnLabel) throws SQLException { + return delegate.getURL(columnLabel); + } + + @Override + public Ref getRef(int columnIndex) throws SQLException { + return delegate.getRef(columnIndex); + } + + @Override + public Ref getRef(String columnLabel) throws SQLException { + return delegate.getRef(columnLabel); + } + + @Override + public Blob getBlob(int columnIndex) throws SQLException { + return delegate.getBlob(columnIndex); + } + + @Override + public Blob getBlob(String columnLabel) throws SQLException { + return delegate.getBlob(columnLabel); + } + + @Override + public Clob getClob(int columnIndex) throws SQLException { + return delegate.getClob(columnIndex); + } + + @Override + public Clob getClob(String columnLabel) throws SQLException { + return delegate.getClob(columnLabel); + } + + @Override + public NClob getNClob(int columnIndex) throws SQLException { + return delegate.getNClob(columnIndex); + } + + @Override + public NClob getNClob(String columnLabel) throws SQLException { + return delegate.getNClob(columnLabel); + } + + @Override + public Array getArray(int columnIndex) throws SQLException { + return delegate.getArray(columnIndex); + } + + @Override + public Array getArray(String columnLabel) throws SQLException { + return delegate.getArray(columnLabel); + } + + @Override + public SQLXML getSQLXML(int columnIndex) throws SQLException { + return delegate.getSQLXML(columnIndex); + } + + @Override + public SQLXML getSQLXML(String columnLabel) throws SQLException { + return delegate.getSQLXML(columnLabel); + } + + @Override + public RowId getRowId(int columnIndex) throws SQLException { + return delegate.getRowId(columnIndex); + } + + @Override + public RowId getRowId(String columnLabel) throws SQLException { + return delegate.getRowId(columnLabel); + } + + @Override + public String getNString(int columnIndex) throws SQLException { + Object value = delegate.getObject(columnIndex); + Object decryptedValue = decryptValueIfNeeded(columnIndex, value, String.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof String) { + return (String) decryptedValue; + } else { + return decryptedValue.toString(); + } + } + + @Override + public String getNString(String columnLabel) throws SQLException { + Object value = delegate.getObject(columnLabel); + Object decryptedValue = decryptValueIfNeeded(columnLabel, value, String.class); + + if (decryptedValue == null) { + return null; + } else if (decryptedValue instanceof String) { + return (String) decryptedValue; + } else { + return decryptedValue.toString(); + } + } + + // All other ResultSet methods delegate directly to the wrapped ResultSet + + @Override + public boolean next() throws SQLException { + return delegate.next(); + } + + @Override + public void close() throws SQLException { + delegate.close(); + } + + @Override + public boolean wasNull() throws SQLException { + return delegate.wasNull(); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + return delegate.getMetaData(); + } + + @Override + public int findColumn(String columnLabel) throws SQLException { + return delegate.findColumn(columnLabel); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return delegate.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + delegate.clearWarnings(); + } + + @Override + public String getCursorName() throws SQLException { + return delegate.getCursorName(); + } + + @Override + public boolean isBeforeFirst() throws SQLException { + return delegate.isBeforeFirst(); + } + + @Override + public boolean isAfterLast() throws SQLException { + return delegate.isAfterLast(); + } + + @Override + public boolean isFirst() throws SQLException { + return delegate.isFirst(); + } + + @Override + public boolean isLast() throws SQLException { + return delegate.isLast(); + } + + @Override + public void beforeFirst() throws SQLException { + delegate.beforeFirst(); + } + + @Override + public void afterLast() throws SQLException { + delegate.afterLast(); + } + + @Override + public boolean first() throws SQLException { + return delegate.first(); + } + + @Override + public boolean last() throws SQLException { + return delegate.last(); + } + + @Override + public int getRow() throws SQLException { + return delegate.getRow(); + } + + @Override + public boolean absolute(int row) throws SQLException { + return delegate.absolute(row); + } + + @Override + public boolean relative(int rows) throws SQLException { + return delegate.relative(rows); + } + + @Override + public boolean previous() throws SQLException { + return delegate.previous(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException { + delegate.setFetchDirection(direction); + } + + @Override + public int getFetchDirection() throws SQLException { + return delegate.getFetchDirection(); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + delegate.setFetchSize(rows); + } + + @Override + public int getFetchSize() throws SQLException { + return delegate.getFetchSize(); + } + + @Override + public int getType() throws SQLException { + return delegate.getType(); + } + + @Override + public int getConcurrency() throws SQLException { + return delegate.getConcurrency(); + } + + @Override + public boolean rowUpdated() throws SQLException { + return delegate.rowUpdated(); + } + + @Override + public boolean rowInserted() throws SQLException { + return delegate.rowInserted(); + } + + @Override + public boolean rowDeleted() throws SQLException { + return delegate.rowDeleted(); + } + + // Update methods - delegate directly (no encryption on updates through + // ResultSet) + @Override + public void updateNull(int columnIndex) throws SQLException { + delegate.updateNull(columnIndex); + } + + @Override + public void updateNull(String columnLabel) throws SQLException { + delegate.updateNull(columnLabel); + } + + @Override + public void updateBoolean(int columnIndex, boolean x) throws SQLException { + delegate.updateBoolean(columnIndex, x); + } + + @Override + public void updateBoolean(String columnLabel, boolean x) throws SQLException { + delegate.updateBoolean(columnLabel, x); + } + + @Override + public void updateByte(int columnIndex, byte x) throws SQLException { + delegate.updateByte(columnIndex, x); + } + + @Override + public void updateByte(String columnLabel, byte x) throws SQLException { + delegate.updateByte(columnLabel, x); + } + + @Override + public void updateShort(int columnIndex, short x) throws SQLException { + delegate.updateShort(columnIndex, x); + } + + @Override + public void updateShort(String columnLabel, short x) throws SQLException { + delegate.updateShort(columnLabel, x); + } + + @Override + public void updateInt(int columnIndex, int x) throws SQLException { + delegate.updateInt(columnIndex, x); + } + + @Override + public void updateInt(String columnLabel, int x) throws SQLException { + delegate.updateInt(columnLabel, x); + } + + @Override + public void updateLong(int columnIndex, long x) throws SQLException { + delegate.updateLong(columnIndex, x); + } + + @Override + public void updateLong(String columnLabel, long x) throws SQLException { + delegate.updateLong(columnLabel, x); + } + + @Override + public void updateFloat(int columnIndex, float x) throws SQLException { + delegate.updateFloat(columnIndex, x); + } + + @Override + public void updateFloat(String columnLabel, float x) throws SQLException { + delegate.updateFloat(columnLabel, x); + } + + @Override + public void updateDouble(int columnIndex, double x) throws SQLException { + delegate.updateDouble(columnIndex, x); + } + + @Override + public void updateDouble(String columnLabel, double x) throws SQLException { + delegate.updateDouble(columnLabel, x); + } + + @Override + public void updateBigDecimal(int columnIndex, BigDecimal x) throws SQLException { + delegate.updateBigDecimal(columnIndex, x); + } + + @Override + public void updateBigDecimal(String columnLabel, BigDecimal x) throws SQLException { + delegate.updateBigDecimal(columnLabel, x); + } + + @Override + public void updateString(int columnIndex, String x) throws SQLException { + delegate.updateString(columnIndex, x); + } + + @Override + public void updateString(String columnLabel, String x) throws SQLException { + delegate.updateString(columnLabel, x); + } + + @Override + public void updateBytes(int columnIndex, byte[] x) throws SQLException { + delegate.updateBytes(columnIndex, x); + } + + @Override + public void updateBytes(String columnLabel, byte[] x) throws SQLException { + delegate.updateBytes(columnLabel, x); + } + + @Override + public void updateDate(int columnIndex, Date x) throws SQLException { + delegate.updateDate(columnIndex, x); + } + + @Override + public void updateDate(String columnLabel, Date x) throws SQLException { + delegate.updateDate(columnLabel, x); + } + + @Override + public void updateTime(int columnIndex, Time x) throws SQLException { + delegate.updateTime(columnIndex, x); + } + + @Override + public void updateTime(String columnLabel, Time x) throws SQLException { + delegate.updateTime(columnLabel, x); + } + + @Override + public void updateTimestamp(int columnIndex, Timestamp x) throws SQLException { + delegate.updateTimestamp(columnIndex, x); + } + + @Override + public void updateTimestamp(String columnLabel, Timestamp x) throws SQLException { + delegate.updateTimestamp(columnLabel, x); + } + + @Override + public void updateAsciiStream(int columnIndex, InputStream x, int length) throws SQLException { + delegate.updateAsciiStream(columnIndex, x, length); + } + + @Override + public void updateAsciiStream(String columnLabel, InputStream x, int length) throws SQLException { + delegate.updateAsciiStream(columnLabel, x, length); + } + + @Override + public void updateAsciiStream(int columnIndex, InputStream x, long length) throws SQLException { + delegate.updateAsciiStream(columnIndex, x, length); + } + + @Override + public void updateAsciiStream(String columnLabel, InputStream x, long length) throws SQLException { + delegate.updateAsciiStream(columnLabel, x, length); + } + + @Override + public void updateAsciiStream(int columnIndex, InputStream x) throws SQLException { + delegate.updateAsciiStream(columnIndex, x); + } + + @Override + public void updateAsciiStream(String columnLabel, InputStream x) throws SQLException { + delegate.updateAsciiStream(columnLabel, x); + } + + @Override + public void updateBinaryStream(int columnIndex, InputStream x, int length) throws SQLException { + delegate.updateBinaryStream(columnIndex, x, length); + } + + @Override + public void updateBinaryStream(String columnLabel, InputStream x, int length) throws SQLException { + delegate.updateBinaryStream(columnLabel, x, length); + } + + @Override + public void updateBinaryStream(int columnIndex, InputStream x, long length) throws SQLException { + delegate.updateBinaryStream(columnIndex, x, length); + } + + @Override + public void updateBinaryStream(String columnLabel, InputStream x, long length) throws SQLException { + delegate.updateBinaryStream(columnLabel, x, length); + } + + @Override + public void updateBinaryStream(int columnIndex, InputStream x) throws SQLException { + delegate.updateBinaryStream(columnIndex, x); + } + + @Override + public void updateBinaryStream(String columnLabel, InputStream x) throws SQLException { + delegate.updateBinaryStream(columnLabel, x); + } + + @Override + public void updateCharacterStream(int columnIndex, Reader x, int length) throws SQLException { + delegate.updateCharacterStream(columnIndex, x, length); + } + + @Override + public void updateCharacterStream(String columnLabel, Reader reader, int length) throws SQLException { + delegate.updateCharacterStream(columnLabel, reader, length); + } + + @Override + public void updateCharacterStream(int columnIndex, Reader x, long length) throws SQLException { + delegate.updateCharacterStream(columnIndex, x, length); + } + + @Override + public void updateCharacterStream(String columnLabel, Reader reader, long length) throws SQLException { + delegate.updateCharacterStream(columnLabel, reader, length); + } + + @Override + public void updateCharacterStream(int columnIndex, Reader x) throws SQLException { + delegate.updateCharacterStream(columnIndex, x); + } + + @Override + public void updateCharacterStream(String columnLabel, Reader reader) throws SQLException { + delegate.updateCharacterStream(columnLabel, reader); + } + + @Override + public void updateObject(int columnIndex, Object x, int scaleOrLength) throws SQLException { + delegate.updateObject(columnIndex, x, scaleOrLength); + } + + @Override + public void updateObject(String columnLabel, Object x, int scaleOrLength) throws SQLException { + delegate.updateObject(columnLabel, x, scaleOrLength); + } + + @Override + public void updateObject(int columnIndex, Object x) throws SQLException { + delegate.updateObject(columnIndex, x); + } + + @Override + public void updateObject(String columnLabel, Object x) throws SQLException { + delegate.updateObject(columnLabel, x); + } + + @Override + public void insertRow() throws SQLException { + delegate.insertRow(); + } + + @Override + public void updateRow() throws SQLException { + delegate.updateRow(); + } + + @Override + public void deleteRow() throws SQLException { + delegate.deleteRow(); + } + + @Override + public void refreshRow() throws SQLException { + delegate.refreshRow(); + } + + @Override + public void cancelRowUpdates() throws SQLException { + delegate.cancelRowUpdates(); + } + + @Override + public void moveToInsertRow() throws SQLException { + delegate.moveToInsertRow(); + } + + @Override + public void moveToCurrentRow() throws SQLException { + delegate.moveToCurrentRow(); + } + + @Override + public Statement getStatement() throws SQLException { + return delegate.getStatement(); + } + + @Override + public void updateRef(int columnIndex, Ref x) throws SQLException { + delegate.updateRef(columnIndex, x); + } + + @Override + public void updateRef(String columnLabel, Ref x) throws SQLException { + delegate.updateRef(columnLabel, x); + } + + @Override + public void updateBlob(int columnIndex, Blob x) throws SQLException { + delegate.updateBlob(columnIndex, x); + } + + @Override + public void updateBlob(String columnLabel, Blob x) throws SQLException { + delegate.updateBlob(columnLabel, x); + } + + @Override + public void updateBlob(int columnIndex, InputStream inputStream, long length) throws SQLException { + delegate.updateBlob(columnIndex, inputStream, length); + } + + @Override + public void updateBlob(String columnLabel, InputStream inputStream, long length) throws SQLException { + delegate.updateBlob(columnLabel, inputStream, length); + } + + @Override + public void updateBlob(int columnIndex, InputStream inputStream) throws SQLException { + delegate.updateBlob(columnIndex, inputStream); + } + + @Override + public void updateBlob(String columnLabel, InputStream inputStream) throws SQLException { + delegate.updateBlob(columnLabel, inputStream); + } + + @Override + public void updateClob(int columnIndex, Clob x) throws SQLException { + delegate.updateClob(columnIndex, x); + } + + @Override + public void updateClob(String columnLabel, Clob x) throws SQLException { + delegate.updateClob(columnLabel, x); + } + + @Override + public void updateClob(int columnIndex, Reader reader, long length) throws SQLException { + delegate.updateClob(columnIndex, reader, length); + } + + @Override + public void updateClob(String columnLabel, Reader reader, long length) throws SQLException { + delegate.updateClob(columnLabel, reader, length); + } + + @Override + public void updateClob(int columnIndex, Reader reader) throws SQLException { + delegate.updateClob(columnIndex, reader); + } + + @Override + public void updateClob(String columnLabel, Reader reader) throws SQLException { + delegate.updateClob(columnLabel, reader); + } + + @Override + public void updateArray(int columnIndex, Array x) throws SQLException { + delegate.updateArray(columnIndex, x); + } + + @Override + public void updateArray(String columnLabel, Array x) throws SQLException { + delegate.updateArray(columnLabel, x); + } + + @Override + public void updateRowId(int columnIndex, RowId x) throws SQLException { + delegate.updateRowId(columnIndex, x); + } + + @Override + public void updateRowId(String columnLabel, RowId x) throws SQLException { + delegate.updateRowId(columnLabel, x); + } + + @Override + public int getHoldability() throws SQLException { + return delegate.getHoldability(); + } + + @Override + public boolean isClosed() throws SQLException { + return delegate.isClosed(); + } + + @Override + public void updateNString(int columnIndex, String nString) throws SQLException { + delegate.updateNString(columnIndex, nString); + } + + @Override + public void updateNString(String columnLabel, String nString) throws SQLException { + delegate.updateNString(columnLabel, nString); + } + + @Override + public void updateNClob(int columnIndex, NClob nClob) throws SQLException { + delegate.updateNClob(columnIndex, nClob); + } + + @Override + public void updateNClob(String columnLabel, NClob nClob) throws SQLException { + delegate.updateNClob(columnLabel, nClob); + } + + @Override + public void updateNClob(int columnIndex, Reader reader, long length) throws SQLException { + delegate.updateNClob(columnIndex, reader, length); + } + + @Override + public void updateNClob(String columnLabel, Reader reader, long length) throws SQLException { + delegate.updateNClob(columnLabel, reader, length); + } + + @Override + public void updateNClob(int columnIndex, Reader reader) throws SQLException { + delegate.updateNClob(columnIndex, reader); + } + + @Override + public void updateNClob(String columnLabel, Reader reader) throws SQLException { + delegate.updateNClob(columnLabel, reader); + } + + @Override + public void updateSQLXML(int columnIndex, SQLXML xmlObject) throws SQLException { + delegate.updateSQLXML(columnIndex, xmlObject); + } + + @Override + public void updateSQLXML(String columnLabel, SQLXML xmlObject) throws SQLException { + delegate.updateSQLXML(columnLabel, xmlObject); + } + + @Override + public void updateNCharacterStream(int columnIndex, Reader x, long length) throws SQLException { + delegate.updateNCharacterStream(columnIndex, x, length); + } + + @Override + public void updateNCharacterStream(String columnLabel, Reader reader, long length) throws SQLException { + delegate.updateNCharacterStream(columnLabel, reader, length); + } + + @Override + public void updateNCharacterStream(int columnIndex, Reader x) throws SQLException { + delegate.updateNCharacterStream(columnIndex, x); + } + + @Override + public void updateNCharacterStream(String columnLabel, Reader reader) throws SQLException { + delegate.updateNCharacterStream(columnLabel, reader); + } + + // Wrapper interface methods + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface.isAssignableFrom(getClass())) { + return iface.cast(this); + } + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return iface.isAssignableFrom(getClass()) || delegate.isWrapperFor(iface); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptedData.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptedData.java new file mode 100644 index 000000000..f508ad3c3 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptedData.java @@ -0,0 +1,86 @@ +package software.amazon.jdbc.plugin.encryption.wrapper; + +import org.postgresql.util.PGBinaryObject; +import org.postgresql.util.PGobject; + +import java.sql.SQLException; + +/** + * PostgreSQL custom type wrapper for encrypted_data. + * Handles binary data transfer for the encrypted_data type. + */ +public class EncryptedData extends PGobject implements PGBinaryObject { + + private byte[] bytes; + + public EncryptedData() { + setType("encrypted_data"); + } + + public EncryptedData(byte[] bytes) { + setType("encrypted_data"); + this.bytes = bytes; + } + + @Override + public void setByteValue(byte[] value, int offset) throws SQLException { + // Binary mode: raw bytes, no hex encoding + this.bytes = new byte[value.length - offset]; + System.arraycopy(value, offset, this.bytes, 0, this.bytes.length); + } + + @Override + public int lengthInBytes() { + // Binary mode: actual byte length + return bytes != null ? bytes.length : 0; + } + + @Override + public void toBytes(byte[] target, int offset) { + // Binary mode: raw bytes, no hex encoding + if (this.bytes != null) { + System.arraycopy(this.bytes, 0, target, offset, this.bytes.length); + } + } + + public byte[] getBytes() { + return bytes; + } + + @Override + public void setValue(String value) throws SQLException { + // Text mode: hex-encoded string + if (value != null && value.startsWith("\\x")) { + this.bytes = hexToBytes(value.substring(2)); + } else { + this.bytes = null; + } + } + + @Override + public String getValue() { + // Text mode: hex-encoded string + if (bytes == null) { + return null; + } + return "\\x" + bytesToHex(bytes); + } + + private static byte[] hexToBytes(String hex) { + int len = hex.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4) + + Character.digit(hex.charAt(i+1), 16)); + } + return data; + } + + private static String bytesToHex(byte[] bytes) { + StringBuilder sb = new StringBuilder(); + for (byte b : bytes) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingConnection.java new file mode 100644 index 000000000..7802e8cfd --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingConnection.java @@ -0,0 +1,353 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.wrapper; + +import software.amazon.jdbc.plugin.encryption.KmsEncryptionPlugin; +import java.util.logging.Logger; + +import java.sql.*; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.Executor; + +/** + * A Connection wrapper that provides transparent encryption/decryption functionality + * by wrapping PreparedStatements and ResultSets with encryption-aware implementations. + */ +public class EncryptingConnection implements Connection { + + private static final Logger LOGGER = Logger.getLogger(EncryptingConnection.class.getName()); + + private final Connection delegate; + private final KmsEncryptionPlugin encryptionPlugin; + + /** + * Creates an encrypting connection wrapper. + * + * @param delegate The underlying Connection to wrap + * @param encryptionPlugin The encryption plugin to use + */ + public EncryptingConnection(Connection delegate, KmsEncryptionPlugin encryptionPlugin) { + this.delegate = delegate; + this.encryptionPlugin = encryptionPlugin; + + LOGGER.finest(()->"Created EncryptingConnection wrapper"); + } + + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + PreparedStatement statement = delegate.prepareStatement(sql); + return encryptionPlugin.wrapPreparedStatement(statement, sql); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + PreparedStatement statement = delegate.prepareStatement(sql, resultSetType, resultSetConcurrency); + return encryptionPlugin.wrapPreparedStatement(statement, sql); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { + PreparedStatement statement = delegate.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + return encryptionPlugin.wrapPreparedStatement(statement, sql); + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + PreparedStatement statement = delegate.prepareStatement(sql, autoGeneratedKeys); + return encryptionPlugin.wrapPreparedStatement(statement, sql); + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + PreparedStatement statement = delegate.prepareStatement(sql, columnIndexes); + return encryptionPlugin.wrapPreparedStatement(statement, sql); + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + PreparedStatement statement = delegate.prepareStatement(sql, columnNames); + return encryptionPlugin.wrapPreparedStatement(statement, sql); + } + + @Override + public Statement createStatement() throws SQLException { + Statement statement = delegate.createStatement(); + return new EncryptingStatement(statement, encryptionPlugin); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { + Statement statement = delegate.createStatement(resultSetType, resultSetConcurrency); + return new EncryptingStatement(statement, encryptionPlugin); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { + Statement statement = delegate.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability); + return new EncryptingStatement(statement, encryptionPlugin); + } + + // All other Connection methods delegate directly to the wrapped connection + + @Override + public CallableStatement prepareCall(String sql) throws SQLException { + return delegate.prepareCall(sql); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + return delegate.prepareCall(sql, resultSetType, resultSetConcurrency); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { + return delegate.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public String nativeSQL(String sql) throws SQLException { + return delegate.nativeSQL(sql); + } + + @Override + public void setAutoCommit(boolean autoCommit) throws SQLException { + delegate.setAutoCommit(autoCommit); + } + + @Override + public boolean getAutoCommit() throws SQLException { + return delegate.getAutoCommit(); + } + + @Override + public void commit() throws SQLException { + delegate.commit(); + } + + @Override + public void rollback() throws SQLException { + delegate.rollback(); + } + + @Override + public void rollback(Savepoint savepoint) throws SQLException { + delegate.rollback(savepoint); + } + + @Override + public void close() throws SQLException { + delegate.close(); + } + + @Override + public boolean isClosed() throws SQLException { + return delegate.isClosed(); + } + + @Override + public DatabaseMetaData getMetaData() throws SQLException { + return delegate.getMetaData(); + } + + @Override + public void setReadOnly(boolean readOnly) throws SQLException { + delegate.setReadOnly(readOnly); + } + + @Override + public boolean isReadOnly() throws SQLException { + return delegate.isReadOnly(); + } + + @Override + public void setCatalog(String catalog) throws SQLException { + delegate.setCatalog(catalog); + } + + @Override + public String getCatalog() throws SQLException { + return delegate.getCatalog(); + } + + @Override + public void setTransactionIsolation(int level) throws SQLException { + delegate.setTransactionIsolation(level); + } + + @Override + public int getTransactionIsolation() throws SQLException { + return delegate.getTransactionIsolation(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return delegate.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + delegate.clearWarnings(); + } + + @Override + public Map> getTypeMap() throws SQLException { + return delegate.getTypeMap(); + } + + @Override + public void setTypeMap(Map> map) throws SQLException { + delegate.setTypeMap(map); + } + + @Override + public void setHoldability(int holdability) throws SQLException { + delegate.setHoldability(holdability); + } + + @Override + public int getHoldability() throws SQLException { + return delegate.getHoldability(); + } + + @Override + public Savepoint setSavepoint() throws SQLException { + return delegate.setSavepoint(); + } + + @Override + public Savepoint setSavepoint(String name) throws SQLException { + return delegate.setSavepoint(name); + } + + @Override + public void releaseSavepoint(Savepoint savepoint) throws SQLException { + delegate.releaseSavepoint(savepoint); + } + + @Override + public Clob createClob() throws SQLException { + return delegate.createClob(); + } + + @Override + public Blob createBlob() throws SQLException { + return delegate.createBlob(); + } + + @Override + public NClob createNClob() throws SQLException { + return delegate.createNClob(); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + return delegate.createSQLXML(); + } + + @Override + public boolean isValid(int timeout) throws SQLException { + return delegate.isValid(timeout); + } + + @Override + public void setClientInfo(String name, String value) throws SQLClientInfoException { + delegate.setClientInfo(name, value); + } + + @Override + public void setClientInfo(Properties properties) throws SQLClientInfoException { + delegate.setClientInfo(properties); + } + + @Override + public String getClientInfo(String name) throws SQLException { + return delegate.getClientInfo(name); + } + + @Override + public Properties getClientInfo() throws SQLException { + return delegate.getClientInfo(); + } + + @Override + public Array createArrayOf(String typeName, Object[] elements) throws SQLException { + return delegate.createArrayOf(typeName, elements); + } + + @Override + public Struct createStruct(String typeName, Object[] attributes) throws SQLException { + return delegate.createStruct(typeName, attributes); + } + + @Override + public void setSchema(String schema) throws SQLException { + delegate.setSchema(schema); + } + + @Override + public String getSchema() throws SQLException { + return delegate.getSchema(); + } + + @Override + public void abort(Executor executor) throws SQLException { + delegate.abort(executor); + } + + @Override + public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { + delegate.setNetworkTimeout(executor, milliseconds); + } + + @Override + public int getNetworkTimeout() throws SQLException { + return delegate.getNetworkTimeout(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface.isAssignableFrom(getClass())) { + return iface.cast(this); + } + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return iface.isAssignableFrom(getClass()) || delegate.isWrapperFor(iface); + } + + /** + * Gets the underlying Connection. + * + * @return The wrapped Connection + */ + public Connection getDelegate() { + return delegate; + } + + /** + * Gets the encryption plugin instance. + * + * @return The KmsEncryptionPlugin instance + */ + public KmsEncryptionPlugin getEncryptionPlugin() { + return encryptionPlugin; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingDataSource.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingDataSource.java new file mode 100644 index 000000000..9cd426c17 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingDataSource.java @@ -0,0 +1,275 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.wrapper; + +import software.amazon.jdbc.plugin.encryption.KmsEncryptionPlugin; +import java.util.logging.Logger; + +import javax.sql.DataSource; +import java.io.PrintWriter; +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Properties; + +/** + * A DataSource wrapper that integrates encryption capabilities with the AWS Advanced JDBC Wrapper. + * This DataSource wraps connections to provide transparent encryption/decryption functionality. + */ +public class EncryptingDataSource implements DataSource { + + private static final Logger LOGGER = Logger.getLogger(EncryptingDataSource.class.getName()); + + private final DataSource delegate; + private final KmsEncryptionPlugin encryptionPlugin; + private final Properties encryptionProperties; + private volatile boolean closed = false; + + /** + * Creates an encrypting DataSource that wraps the provided DataSource. + * + * @param delegate The underlying DataSource to wrap + * @param encryptionProperties Properties for configuring encryption + * @throws SQLException if encryption plugin initialization fails + */ + public EncryptingDataSource(DataSource delegate, Properties encryptionProperties) throws SQLException { + this.delegate = delegate; + this.encryptionProperties = new Properties(); + this.encryptionProperties.putAll(encryptionProperties); + + // Initialize the encryption plugin + this.encryptionPlugin = new KmsEncryptionPlugin(); + this.encryptionPlugin.initialize(encryptionProperties); + + LOGGER.info("EncryptingDataSource initialized with encryption plugin"); + } + + @Override + public Connection getConnection() throws SQLException { + checkNotClosed(); + + Connection connection = null; + try { + connection = delegate.getConnection(); + validateConnection(connection); + return new EncryptingConnection(connection, encryptionPlugin); + } catch (SQLException e) { + // Close the connection if we got one but failed to wrap it + if (connection != null) { + try { + connection.close(); + } catch (SQLException closeEx) { + LOGGER.warning(()->String.format("Failed to close connection after wrapping failure %s", closeEx.getMessage())); + } + } + + LOGGER.severe(()->String.format("Failed to get connection from delegate DataSource %s", e.getMessage())); + throw new SQLException("Failed to obtain encrypted connection: " + e.getMessage(), e); + } + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + checkNotClosed(); + + Connection connection = null; + try { + connection = delegate.getConnection(username, password); + validateConnection(connection); + return new EncryptingConnection(connection, encryptionPlugin); + } catch (SQLException e) { + // Close the connection if we got one but failed to wrap it + if (connection != null) { + try { + connection.close(); + } catch (SQLException closeEx) { + LOGGER.warning(()->String.format("Failed to close connection after wrapping failure %s", closeEx.getMessage())); + } + } + + LOGGER.severe(()->String.format("Failed to get connection from delegate DataSource with credentials %s", e.getMessage())); + throw new SQLException("Failed to obtain encrypted connection: " + e.getMessage(), e); + } + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + return delegate.getLogWriter(); + } + + @Override + public void setLogWriter(PrintWriter out) throws SQLException { + delegate.setLogWriter(out); + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + delegate.setLoginTimeout(seconds); + } + + @Override + public int getLoginTimeout() throws SQLException { + return delegate.getLoginTimeout(); + } + + @Override + public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException { + return delegate.getParentLogger(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface.isAssignableFrom(getClass())) { + return iface.cast(this); + } + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return iface.isAssignableFrom(getClass()) || delegate.isWrapperFor(iface); + } + + /** + * Gets the underlying DataSource. + * + * @return The wrapped DataSource + */ + public DataSource getDelegate() { + return delegate; + } + + /** + * Gets the encryption plugin instance. + * + * @return The KmsEncryptionPlugin instance + */ + public KmsEncryptionPlugin getEncryptionPlugin() { + return encryptionPlugin; + } + + /** + * Tests if the DataSource can provide a valid connection. + * This method attempts to get a connection and immediately closes it. + * + * @return true if a valid connection can be obtained, false otherwise + */ + public boolean isConnectionAvailable() { + if (closed) { + return false; + } + + Connection testConnection = null; + try { + testConnection = delegate.getConnection(); + return testConnection != null && !testConnection.isClosed() && testConnection.isValid(5); + } catch (SQLException e) { + LOGGER.finest(()->String.format("Connection availability test failed %s", e.getMessage())); + return false; + } finally { + if (testConnection != null) { + try { + testConnection.close(); + } catch (SQLException e) { + LOGGER.finest(()->String.format("Failed to close test connection %s", e.getMessage())); + } + } + } + } + + /** + * Closes the encryption plugin and releases resources. + */ + public void close() { + if (closed) { + return; + } + + LOGGER.info(()->"Closing EncryptingDataSource"); + closed = true; + + if (encryptionPlugin != null) { + try { + encryptionPlugin.cleanup(); + } catch (Exception e) { + LOGGER.warning(()->String.format("Error during encryption plugin cleanup %s", e.getMessage())); + } + } + + // If the delegate DataSource has a close method, call it + if (delegate != null) { + try { + // Try to close the delegate if it's closeable (e.g., HikariDataSource, etc.) + if (delegate instanceof AutoCloseable) { + ((AutoCloseable) delegate).close(); + LOGGER.finest(()->"Closed delegate DataSource"); + } + } catch (Exception e) { + LOGGER.warning(()->String.format("Error closing delegate DataSource %s", e.getMessage())); + } + } + + LOGGER.info("EncryptingDataSource closed"); + } + + /** + * Checks if this DataSource has been closed. + * + * @return true if closed, false otherwise + */ + public boolean isClosed() { + return closed; + } + + /** + * Validates that the DataSource is not closed. + * + * @throws SQLException if the DataSource is closed + */ + private void checkNotClosed() throws SQLException { + if (closed) { + throw new SQLException("EncryptingDataSource has been closed"); + } + } + + /** + * Validates that a connection is valid and not closed. + * + * @param connection the connection to validate + * @throws SQLException if the connection is invalid + */ + private void validateConnection(Connection connection) throws SQLException { + if (connection == null) { + throw new SQLException("Delegate DataSource returned null connection"); + } + + if (connection.isClosed()) { + throw new SQLException("Delegate DataSource returned a closed connection"); + } + + // Test the connection with a short timeout + try { + if (!connection.isValid(5)) { // 5 second timeout + throw new SQLException("Delegate DataSource returned an invalid connection"); + } + } catch (SQLException e) { + LOGGER.warning(()->String.format("Connection validation failed %s", e.getMessage())); + throw new SQLException("Connection validation failed: " + e.getMessage(), e); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingPreparedStatement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingPreparedStatement.java new file mode 100644 index 000000000..e74b87a27 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingPreparedStatement.java @@ -0,0 +1,870 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.wrapper; + +import software.amazon.jdbc.plugin.encryption.metadata.MetadataManager; +import software.amazon.jdbc.plugin.encryption.model.ColumnEncryptionConfig; +import software.amazon.jdbc.plugin.encryption.key.KeyManager; +import software.amazon.jdbc.plugin.encryption.service.EncryptionService; +import software.amazon.jdbc.plugin.encryption.sql.SqlAnalysisService; +import software.amazon.jdbc.plugin.encryption.parser.SQLAnalyzer; +import java.util.logging.Logger; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.*; +import java.util.Calendar; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A PreparedStatement wrapper that automatically encrypts parameter values + * for columns configured for encryption. Uses delegation pattern for non-encrypted operations. + */ +public class EncryptingPreparedStatement implements PreparedStatement { + + private static final Logger LOGGER = Logger.getLogger(EncryptingPreparedStatement.class.getName()); + + private final PreparedStatement delegate; + private final MetadataManager metadataManager; + private final EncryptionService encryptionService; + private final KeyManager keyManager; + private final SqlAnalysisService sqlAnalysisService; + private final String sql; + + // Cache for parameter index to column name mapping + private final Map parameterColumnMapping = new ConcurrentHashMap<>(); + private String tableName; + private boolean mappingInitialized = false; + + public EncryptingPreparedStatement(PreparedStatement delegate, + MetadataManager metadataManager, + EncryptionService encryptionService, + KeyManager keyManager, + SqlAnalysisService sqlAnalysisService, + String sql) { + LOGGER.finest(()->String.format("EncryptingPreparedStatement created for SQL: %s", sql)); + this.delegate = delegate; + this.metadataManager = metadataManager; + this.encryptionService = encryptionService; + this.keyManager = keyManager; + this.sqlAnalysisService = sqlAnalysisService; + this.sql = sql; + + // Initialize parameter mapping + initializeParameterMapping(); + LOGGER.finest(()->String.format("Parameter mapping initialized: %s", parameterColumnMapping)); + } + + /** + * Initializes the parameter index to column name mapping by parsing the SQL. + * This is a simplified implementation that extracts table name from INSERT/UPDATE statements. + */ + /** + * Initializes parameter mapping using SQL analysis service. + */ + private void initializeParameterMapping() { + LOGGER.finest(()->String.format("initializeParameterMapping called for SQL: %s", sql)); + try { + // Use SqlAnalysisService to analyze SQL and extract table information + SqlAnalysisService.SqlAnalysisResult analysisResult = sqlAnalysisService.analyzeSql(sql); + LOGGER.finest(()->String.format("Analysis result tables: %s", analysisResult.getAffectedTables())); + + // Get the first table from analysis results + if (!analysisResult.getAffectedTables().isEmpty()) { + this.tableName = analysisResult.getAffectedTables().iterator().next(); + LOGGER.finest(()->String.format("Table name set to: %s", tableName)); + + // Use SqlAnalysisService to get parameter mapping + Map mapping = sqlAnalysisService.getColumnParameterMapping(sql); + LOGGER.finest(()->String.format("Column parameter mapping from service: %s", mapping)); + parameterColumnMapping.putAll(mapping); + + LOGGER.finest(()->String.format("Final parameter mapping: %s", parameterColumnMapping)); + } + + mappingInitialized = true; + LOGGER.finest(()->String.format("Parameter mapping initialization complete for table: %s", tableName)); + + } catch (Exception e) { + LOGGER.finest(()->String.format("Failed to initialize parameter mapping: %s", e.getMessage())); + LOGGER.finest(()->String.format("Exception details %s", e)); + mappingInitialized = false; + } + } + + /** + * Maps parameters for INSERT statements by parsing column names. + */ + private void mapInsertParameters() { + // This is a simplified implementation + // In a production system, you might want to use a proper SQL parser + + int columnsStart = sql.indexOf("("); + int columnsEnd = sql.indexOf(")", columnsStart); + + if (columnsStart != -1 && columnsEnd != -1) { + String columnsPart = sql.substring(columnsStart + 1, columnsEnd); + String[] columns = columnsPart.split(","); + + for (int i = 0; i < columns.length; i++) { + String columnName = columns[i].trim(); + parameterColumnMapping.put(i + 1, columnName); + } + } + } + + /** + * Maps parameters for UPDATE statements by parsing SET clause. + */ + private void mapUpdateParameters() { + // This is a simplified implementation + // In a production system, you might want to use a proper SQL parser + + String upperSql = sql.toUpperCase(); + int setIndex = upperSql.indexOf("SET"); + int whereIndex = upperSql.indexOf("WHERE"); + + if (setIndex != -1) { + int endIndex = whereIndex != -1 ? whereIndex : sql.length(); + String setPart = sql.substring(setIndex + 3, endIndex); + + String[] assignments = setPart.split(","); + int parameterIndex = 1; + + for (String assignment : assignments) { + int equalsIndex = assignment.indexOf("="); + if (equalsIndex != -1) { + String columnName = assignment.substring(0, equalsIndex).trim(); + parameterColumnMapping.put(parameterIndex++, columnName); + } + } + } + } + + /** + * Gets the column name for a parameter index. + */ + private String getColumnNameForParameter(int parameterIndex) { + return parameterColumnMapping.get(parameterIndex); + } + + /** + * Checks if a parameter should be encrypted and encrypts it if necessary. + */ + private Object encryptParameterIfNeeded(int parameterIndex, Object value) throws SQLException { + LOGGER.finest(()->String.format("encryptParameterIfNeeded called: param=%s, value=%s", parameterIndex, value)); + LOGGER.finest(()->String.format("mappingInitialized=%s, tableName=%s", mappingInitialized, tableName)); + + if (!mappingInitialized || tableName == null || value == null) { + LOGGER.finest(()->"Skipping encryption - early exit"); + return value; + } + + try { + String columnName = getColumnNameForParameter(parameterIndex); + LOGGER.finest(()->String.format("Parameter %s maps to column: %s", parameterIndex, columnName)); + LOGGER.finest(()->String.format("Parameter mapping: %s", parameterColumnMapping)); + + if (columnName == null) { + return value; + } + + // Check if column is configured for encryption + boolean isEncrypted = metadataManager.isColumnEncrypted(tableName, columnName); + LOGGER.finest(()->String.format("Column %s.%s encrypted: %s", tableName, columnName, isEncrypted)); + + // Debug metadata manager state + try { + LOGGER.finest(()->String.format("Checking metadata manager for table: %s", tableName)); + LOGGER.finest(()->String.format("MetadataManager class: %s", metadataManager.getClass().getName())); + + // Force refresh metadata to pick up any new configurations + LOGGER.finest(()->String.format("Forcing metadata refresh...")); + metadataManager.refreshMetadata(); + LOGGER.finest(()->String.format("Metadata refresh completed")); + + // Try to get config directly after refresh + ColumnEncryptionConfig config = metadataManager.getColumnConfig(tableName, columnName); + LOGGER.finest(()->String.format("Column config for %s.%s after refresh: %s", tableName, columnName, config)); + + // Check encryption status after refresh + boolean isEncryptedAfterRefresh = metadataManager.isColumnEncrypted(tableName, columnName); + LOGGER.finest(()->String.format("Column %s.%s encrypted after refresh: %s", tableName, columnName, isEncryptedAfterRefresh)); + + } catch (Exception e) { + LOGGER.finest(()->String.format("Error getting column config: %s", e.getMessage())); + LOGGER.finest(()->String.format("Exception details", e)); + } + + if (!isEncrypted) { + return value; + } + + // Get encryption configuration + ColumnEncryptionConfig config = metadataManager.getColumnConfig(tableName, columnName); + if (config == null) { + LOGGER.warning(()->String.format("No encryption config found for column %s.%s", tableName, columnName)); + return value; + } + + // Get data key for encryption + byte[] dataKey = keyManager.decryptDataKey( + config.getKeyMetadata().getEncryptedDataKey(), + config.getKeyMetadata().getMasterKeyArn() + ); + + // Get HMAC key + byte[] hmacKey = config.getKeyMetadata().getHmacKey(); + + // Encrypt the value + byte[] encryptedValue = encryptionService.encrypt(value, dataKey, hmacKey, config.getAlgorithm()); + + // Clear the data key from memory + java.util.Arrays.fill(dataKey, (byte) 0); + + LOGGER.fine(()->String.format("Encrypted parameter %s for column %s.%s", parameterIndex, tableName, columnName)); + return encryptedValue; + + } catch (Exception e) { + //TODO move this into the subscriber + String errorMsg = String.format("Failed to encrypt parameter %d for column %s.%s", + parameterIndex, tableName, getColumnNameForParameter(parameterIndex)); + LOGGER.severe(()->String.format(errorMsg)); + throw new SQLException(errorMsg, e); + } + } + + // Override setXXX methods to add encryption logic + + @Override + public void setString(int parameterIndex, String x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setObject(parameterIndex, new EncryptedData((byte[]) encryptedValue)); + } else { + delegate.setString(parameterIndex, (String) encryptedValue); + } + } + + private static String bytesToHex(byte[] bytes) { + StringBuilder sb = new StringBuilder(); + for (byte b : bytes) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } + + @Override + public void setInt(int parameterIndex, int x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setInt(parameterIndex, (Integer) encryptedValue); + } + } + + @Override + public void setLong(int parameterIndex, long x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setLong(parameterIndex, (Long) encryptedValue); + } + } + + @Override + public void setBytes(int parameterIndex, byte[] x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } + + @Override + public void setDouble(int parameterIndex, double x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setDouble(parameterIndex, (Double) encryptedValue); + } + } + + @Override + public void setFloat(int parameterIndex, float x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setFloat(parameterIndex, (Float) encryptedValue); + } + } + + @Override + public void setBoolean(int parameterIndex, boolean x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setBoolean(parameterIndex, (Boolean) encryptedValue); + } + } + + @Override + public void setShort(int parameterIndex, short x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setShort(parameterIndex, (Short) encryptedValue); + } + } + + @Override + public void setByte(int parameterIndex, byte x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setByte(parameterIndex, (Byte) encryptedValue); + } + } + + @Override + public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setBigDecimal(parameterIndex, (BigDecimal) encryptedValue); + } + } + + @Override + public void setDate(int parameterIndex, Date x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setDate(parameterIndex, (Date) encryptedValue); + } + } + + @Override + public void setTime(int parameterIndex, Time x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setTime(parameterIndex, (Time) encryptedValue); + } + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setTimestamp(parameterIndex, (Timestamp) encryptedValue); + } + } + + @Override + public void setObject(int parameterIndex, Object x) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[] && !(x instanceof byte[])) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setObject(parameterIndex, encryptedValue); + } + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[] && !(x instanceof byte[])) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setObject(parameterIndex, encryptedValue, targetSqlType); + } + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType, int scaleOrLength) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[] && !(x instanceof byte[])) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setObject(parameterIndex, encryptedValue, targetSqlType, scaleOrLength); + } + } + + // Null setters - no encryption needed + @Override + public void setNull(int parameterIndex, int sqlType) throws SQLException { + delegate.setNull(parameterIndex, sqlType); + } + + @Override + public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { + delegate.setNull(parameterIndex, sqlType, typeName); + } + + // Stream and reader setters - delegate directly (encryption not supported for streams) + @Override + public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { + delegate.setBinaryStream(parameterIndex, x, length); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { + delegate.setBinaryStream(parameterIndex, x, length); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { + delegate.setBinaryStream(parameterIndex, x); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { + delegate.setAsciiStream(parameterIndex, x, length); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { + delegate.setAsciiStream(parameterIndex, x, length); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { + delegate.setAsciiStream(parameterIndex, x); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, int length) throws SQLException { + delegate.setCharacterStream(parameterIndex, reader, length); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, long length) throws SQLException { + delegate.setCharacterStream(parameterIndex, reader, length); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { + delegate.setCharacterStream(parameterIndex, reader); + } + + // Other specialized setters - delegate directly + @Override + public void setURL(int parameterIndex, URL x) throws SQLException { + delegate.setURL(parameterIndex, x); + } + + @Override + public void setRef(int parameterIndex, Ref x) throws SQLException { + delegate.setRef(parameterIndex, x); + } + + @Override + public void setBlob(int parameterIndex, Blob x) throws SQLException { + delegate.setBlob(parameterIndex, x); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream, long length) throws SQLException { + delegate.setBlob(parameterIndex, inputStream, length); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { + delegate.setBlob(parameterIndex, inputStream); + } + + @Override + public void setClob(int parameterIndex, Clob x) throws SQLException { + delegate.setClob(parameterIndex, x); + } + + @Override + public void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + delegate.setClob(parameterIndex, reader, length); + } + + @Override + public void setClob(int parameterIndex, Reader reader) throws SQLException { + delegate.setClob(parameterIndex, reader); + } + + @Override + public void setArray(int parameterIndex, Array x) throws SQLException { + delegate.setArray(parameterIndex, x); + } + + @Override + public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setDate(parameterIndex, (Date) encryptedValue, cal); + } + } + + @Override + public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setTime(parameterIndex, (Time) encryptedValue, cal); + } + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, x); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setTimestamp(parameterIndex, (Timestamp) encryptedValue, cal); + } + } + + // Deprecated methods - delegate directly + @Override + @Deprecated + public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException { + delegate.setUnicodeStream(parameterIndex, x, length); + } + + // JDBC 4.0+ methods + @Override + public void setRowId(int parameterIndex, RowId x) throws SQLException { + delegate.setRowId(parameterIndex, x); + } + + @Override + public void setNString(int parameterIndex, String value) throws SQLException { + Object encryptedValue = encryptParameterIfNeeded(parameterIndex, value); + if (encryptedValue instanceof byte[]) { + delegate.setBytes(parameterIndex, (byte[]) encryptedValue); + } else { + delegate.setNString(parameterIndex, (String) encryptedValue); + } + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value, long length) throws SQLException { + delegate.setNCharacterStream(parameterIndex, value, length); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + delegate.setNCharacterStream(parameterIndex, value); + } + + @Override + public void setNClob(int parameterIndex, NClob value) throws SQLException { + delegate.setNClob(parameterIndex, value); + } + + @Override + public void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { + delegate.setNClob(parameterIndex, reader, length); + } + + @Override + public void setNClob(int parameterIndex, Reader reader) throws SQLException { + delegate.setNClob(parameterIndex, reader); + } + + @Override + public void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { + delegate.setSQLXML(parameterIndex, xmlObject); + } + + // All other PreparedStatement methods delegate directly to the wrapped statement + + @Override + public ResultSet executeQuery() throws SQLException { + return delegate.executeQuery(); + } + + @Override + public int executeUpdate() throws SQLException { + return delegate.executeUpdate(); + } + + @Override + public boolean execute() throws SQLException { + return delegate.execute(); + } + + @Override + public void addBatch() throws SQLException { + delegate.addBatch(); + } + + @Override + public void clearParameters() throws SQLException { + delegate.clearParameters(); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + return delegate.getMetaData(); + } + + @Override + public ParameterMetaData getParameterMetaData() throws SQLException { + return delegate.getParameterMetaData(); + } + + // Statement methods - delegate to wrapped statement + + @Override + public ResultSet executeQuery(String sql) throws SQLException { + return delegate.executeQuery(sql); + } + + @Override + public int executeUpdate(String sql) throws SQLException { + return delegate.executeUpdate(sql); + } + + @Override + public void close() throws SQLException { + delegate.close(); + } + + @Override + public int getMaxFieldSize() throws SQLException { + return delegate.getMaxFieldSize(); + } + + @Override + public void setMaxFieldSize(int max) throws SQLException { + delegate.setMaxFieldSize(max); + } + + @Override + public int getMaxRows() throws SQLException { + return delegate.getMaxRows(); + } + + @Override + public void setMaxRows(int max) throws SQLException { + delegate.setMaxRows(max); + } + + @Override + public void setEscapeProcessing(boolean enable) throws SQLException { + delegate.setEscapeProcessing(enable); + } + + @Override + public int getQueryTimeout() throws SQLException { + return delegate.getQueryTimeout(); + } + + @Override + public void setQueryTimeout(int seconds) throws SQLException { + delegate.setQueryTimeout(seconds); + } + + @Override + public void cancel() throws SQLException { + delegate.cancel(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return delegate.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + delegate.clearWarnings(); + } + + @Override + public void setCursorName(String name) throws SQLException { + delegate.setCursorName(name); + } + + @Override + public boolean execute(String sql) throws SQLException { + return delegate.execute(sql); + } + + @Override + public ResultSet getResultSet() throws SQLException { + return delegate.getResultSet(); + } + + @Override + public int getUpdateCount() throws SQLException { + return delegate.getUpdateCount(); + } + + @Override + public boolean getMoreResults() throws SQLException { + return delegate.getMoreResults(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException { + delegate.setFetchDirection(direction); + } + + @Override + public int getFetchDirection() throws SQLException { + return delegate.getFetchDirection(); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + delegate.setFetchSize(rows); + } + + @Override + public int getFetchSize() throws SQLException { + return delegate.getFetchSize(); + } + + @Override + public int getResultSetConcurrency() throws SQLException { + return delegate.getResultSetConcurrency(); + } + + @Override + public int getResultSetType() throws SQLException { + return delegate.getResultSetType(); + } + + @Override + public void addBatch(String sql) throws SQLException { + delegate.addBatch(sql); + } + + @Override + public void clearBatch() throws SQLException { + delegate.clearBatch(); + } + + @Override + public int[] executeBatch() throws SQLException { + return delegate.executeBatch(); + } + + @Override + public Connection getConnection() throws SQLException { + return delegate.getConnection(); + } + + @Override + public boolean getMoreResults(int current) throws SQLException { + return delegate.getMoreResults(current); + } + + @Override + public ResultSet getGeneratedKeys() throws SQLException { + return delegate.getGeneratedKeys(); + } + + @Override + public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { + return delegate.executeUpdate(sql, autoGeneratedKeys); + } + + @Override + public int executeUpdate(String sql, int[] columnIndexes) throws SQLException { + return delegate.executeUpdate(sql, columnIndexes); + } + + @Override + public int executeUpdate(String sql, String[] columnNames) throws SQLException { + return delegate.executeUpdate(sql, columnNames); + } + + @Override + public boolean execute(String sql, int autoGeneratedKeys) throws SQLException { + return delegate.execute(sql, autoGeneratedKeys); + } + + @Override + public boolean execute(String sql, int[] columnIndexes) throws SQLException { + return delegate.execute(sql, columnIndexes); + } + + @Override + public boolean execute(String sql, String[] columnNames) throws SQLException { + return delegate.execute(sql, columnNames); + } + + @Override + public int getResultSetHoldability() throws SQLException { + return delegate.getResultSetHoldability(); + } + + @Override + public boolean isClosed() throws SQLException { + return delegate.isClosed(); + } + + @Override + public void setPoolable(boolean poolable) throws SQLException { + delegate.setPoolable(poolable); + } + + @Override + public boolean isPoolable() throws SQLException { + return delegate.isPoolable(); + } + + @Override + public void closeOnCompletion() throws SQLException { + delegate.closeOnCompletion(); + } + + @Override + public boolean isCloseOnCompletion() throws SQLException { + return delegate.isCloseOnCompletion(); + } + + // Wrapper interface methods + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface.isAssignableFrom(getClass())) { + return iface.cast(this); + } + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return iface.isAssignableFrom(getClass()) || delegate.isWrapperFor(iface); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingStatement.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingStatement.java new file mode 100644 index 000000000..30352e3fb --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/encryption/wrapper/EncryptingStatement.java @@ -0,0 +1,309 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.encryption.wrapper; + +import software.amazon.jdbc.plugin.encryption.KmsEncryptionPlugin; +import java.util.logging.Logger; + +import java.sql.*; + +/** + * A Statement wrapper that provides transparent encryption/decryption functionality. + * This wrapper intercepts SQL execution methods and wraps result sets with decryption support. + * Note: Statement-based encryption is limited compared to PreparedStatement encryption. + */ +public class EncryptingStatement implements Statement { + + private static final Logger LOGGER = Logger.getLogger(EncryptingStatement.class.getName()); + + private final Statement delegate; + private final KmsEncryptionPlugin encryptionPlugin; + + /** + * Creates an encrypting statement wrapper. + * + * @param delegate The underlying Statement to wrap + * @param encryptionPlugin The encryption plugin to use + */ + public EncryptingStatement(Statement delegate, KmsEncryptionPlugin encryptionPlugin) { + this.delegate = delegate; + this.encryptionPlugin = encryptionPlugin; + + LOGGER.finest(()->"Created EncryptingStatement wrapper"); + } + + @Override + public ResultSet executeQuery(String sql) throws SQLException { + LOGGER.finest(()->String.format("Executing query with encryption support: %s", sql)); + + ResultSet resultSet = delegate.executeQuery(sql); + return encryptionPlugin.wrapResultSet(resultSet); + } + + @Override + public int executeUpdate(String sql) throws SQLException { + LOGGER.finest(()->String.format("Executing update with encryption support: %s", sql)); + + // For Statement-based updates, we can't easily encrypt embedded values + // This is a limitation - PreparedStatement should be used for full encryption support + return delegate.executeUpdate(sql); + } + + @Override + public boolean execute(String sql) throws SQLException { + LOGGER.finest(()->String.format("Executing statement with encryption support: %s", sql)); + + return delegate.execute(sql); + } + + @Override + public ResultSet getResultSet() throws SQLException { + ResultSet resultSet = delegate.getResultSet(); + if (resultSet != null) { + return encryptionPlugin.wrapResultSet(resultSet); + } + return null; + } + + @Override + public ResultSet getGeneratedKeys() throws SQLException { + ResultSet resultSet = delegate.getGeneratedKeys(); + if (resultSet != null) { + return encryptionPlugin.wrapResultSet(resultSet); + } + return null; + } + + // All other Statement methods delegate directly to the wrapped statement + + @Override + public void close() throws SQLException { + delegate.close(); + } + + @Override + public int getMaxFieldSize() throws SQLException { + return delegate.getMaxFieldSize(); + } + + @Override + public void setMaxFieldSize(int max) throws SQLException { + delegate.setMaxFieldSize(max); + } + + @Override + public int getMaxRows() throws SQLException { + return delegate.getMaxRows(); + } + + @Override + public void setMaxRows(int max) throws SQLException { + delegate.setMaxRows(max); + } + + @Override + public void setEscapeProcessing(boolean enable) throws SQLException { + delegate.setEscapeProcessing(enable); + } + + @Override + public int getQueryTimeout() throws SQLException { + return delegate.getQueryTimeout(); + } + + @Override + public void setQueryTimeout(int seconds) throws SQLException { + delegate.setQueryTimeout(seconds); + } + + @Override + public void cancel() throws SQLException { + delegate.cancel(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return delegate.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + delegate.clearWarnings(); + } + + @Override + public void setCursorName(String name) throws SQLException { + delegate.setCursorName(name); + } + + @Override + public int getUpdateCount() throws SQLException { + return delegate.getUpdateCount(); + } + + @Override + public boolean getMoreResults() throws SQLException { + return delegate.getMoreResults(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException { + delegate.setFetchDirection(direction); + } + + @Override + public int getFetchDirection() throws SQLException { + return delegate.getFetchDirection(); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + delegate.setFetchSize(rows); + } + + @Override + public int getFetchSize() throws SQLException { + return delegate.getFetchSize(); + } + + @Override + public int getResultSetConcurrency() throws SQLException { + return delegate.getResultSetConcurrency(); + } + + @Override + public int getResultSetType() throws SQLException { + return delegate.getResultSetType(); + } + + @Override + public void addBatch(String sql) throws SQLException { + delegate.addBatch(sql); + } + + @Override + public void clearBatch() throws SQLException { + delegate.clearBatch(); + } + + @Override + public int[] executeBatch() throws SQLException { + return delegate.executeBatch(); + } + + @Override + public Connection getConnection() throws SQLException { + return delegate.getConnection(); + } + + @Override + public boolean getMoreResults(int current) throws SQLException { + return delegate.getMoreResults(current); + } + + @Override + public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { + return delegate.executeUpdate(sql, autoGeneratedKeys); + } + + @Override + public int executeUpdate(String sql, int[] columnIndexes) throws SQLException { + return delegate.executeUpdate(sql, columnIndexes); + } + + @Override + public int executeUpdate(String sql, String[] columnNames) throws SQLException { + return delegate.executeUpdate(sql, columnNames); + } + + @Override + public boolean execute(String sql, int autoGeneratedKeys) throws SQLException { + return delegate.execute(sql, autoGeneratedKeys); + } + + @Override + public boolean execute(String sql, int[] columnIndexes) throws SQLException { + return delegate.execute(sql, columnIndexes); + } + + @Override + public boolean execute(String sql, String[] columnNames) throws SQLException { + return delegate.execute(sql, columnNames); + } + + @Override + public int getResultSetHoldability() throws SQLException { + return delegate.getResultSetHoldability(); + } + + @Override + public boolean isClosed() throws SQLException { + return delegate.isClosed(); + } + + @Override + public void setPoolable(boolean poolable) throws SQLException { + delegate.setPoolable(poolable); + } + + @Override + public boolean isPoolable() throws SQLException { + return delegate.isPoolable(); + } + + @Override + public void closeOnCompletion() throws SQLException { + delegate.closeOnCompletion(); + } + + @Override + public boolean isCloseOnCompletion() throws SQLException { + return delegate.isCloseOnCompletion(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface.isAssignableFrom(getClass())) { + return iface.cast(this); + } + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return iface.isAssignableFrom(getClass()) || delegate.isWrapperFor(iface); + } + + /** + * Gets the underlying Statement. + * + * @return The wrapped Statement + */ + public Statement getDelegate() { + return delegate; + } + + /** + * Gets the encryption plugin instance. + * + * @return The KmsEncryptionPlugin instance + */ + public KmsEncryptionPlugin getEncryptionPlugin() { + return encryptionPlugin; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/SqlMethodAnalyzer.java b/wrapper/src/main/java/software/amazon/jdbc/util/SqlMethodAnalyzer.java index 4fec72d69..8effce11a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/SqlMethodAnalyzer.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/SqlMethodAnalyzer.java @@ -25,7 +25,12 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.Statement; import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; public class SqlMethodAnalyzer { @@ -39,6 +44,43 @@ public class SqlMethodAnalyzer { JdbcMethod.RESULTSET_CLOSE.methodName ))); + private String getQueryTypeFromParseTree(String sql, PluginService pluginService) { + try { + Statement statement = CCJSqlParserUtil.parse(sql); + String className = statement.getClass().getSimpleName(); + + if (className.contains("Select")) { + return "SELECT"; + } else if (className.contains("Insert")) { + return "INSERT"; + } else if (className.contains("Update")) { + return "UPDATE"; + } else if (className.contains("Delete")) { + return "DELETE"; + } else if (className.contains("Create")) { + return "CREATE"; + } else if (className.contains("Drop")) { + return "DROP"; + } else if (className.contains("Set")) { + return "SET"; + } + } catch (JSQLParserException e) { + // Fallback to string parsing + } + + // Fallback string parsing + String trimmed = sql.trim().toUpperCase(); + if (trimmed.startsWith("SELECT")) return "SELECT"; + if (trimmed.startsWith("INSERT")) return "INSERT"; + if (trimmed.startsWith("UPDATE")) return "UPDATE"; + if (trimmed.startsWith("DELETE")) return "DELETE"; + if (trimmed.startsWith("CREATE")) return "CREATE"; + if (trimmed.startsWith("DROP")) return "DROP"; + if (trimmed.startsWith("SET")) return "SET"; + + return "UNKNOWN"; + } + private static final Set EXECUTE_SQL_METHOD_NAMES = Collections.unmodifiableSet( new HashSet<>(Arrays.asList( JdbcMethod.STATEMENT_EXECUTE.methodName, @@ -61,13 +103,13 @@ public class SqlMethodAnalyzer { ))); public boolean doesOpenTransaction(final Connection conn, final String methodName, - final Object[] args) { + final Object[] args, PluginService pluginService) { if (!(EXECUTE_SQL_METHOD_NAMES.contains(methodName) && args != null && args.length >= 1)) { return false; } final String statement = getFirstSqlStatement(String.valueOf(args[0])); - if (isStatementStartingTransaction(statement)) { + if (isStatementStartingTransaction(statement, pluginService)) { return true; } @@ -78,7 +120,12 @@ public boolean doesOpenTransaction(final Connection conn, final String methodNam return false; } - return !autocommit && isStatementDml(statement); + return !autocommit && isStatementDml(statement, pluginService); + } + + public boolean doesOpenTransaction(final Connection conn, final String methodName, + final Object[] args) { + return doesOpenTransaction(conn, methodName, args, null); } private String getFirstSqlStatement(final String sql) { @@ -108,12 +155,12 @@ private List parseMultiStatementQueries(String query) { } public boolean doesCloseTransaction(final Connection conn, final String methodName, - final Object[] args) { + final Object[] args, PluginService pluginService) { if (CLOSE_TRANSACTION_METHOD_NAMES.contains(methodName)) { return true; } - if (doesSwitchAutoCommitFalseTrue(conn, methodName, args)) { + if (doesSwitchAutoCommitFalseTrue(conn, methodName, args, pluginService)) { return true; } @@ -122,29 +169,51 @@ public boolean doesCloseTransaction(final Connection conn, final String methodNa } final String statement = getFirstSqlStatement(String.valueOf(args[0])); - return isStatementClosingTransaction(statement); + return isStatementClosingTransaction(statement, pluginService); } - public boolean isStatementDml(final String statement) { - return !isStatementStartingTransaction(statement) - && !isStatementClosingTransaction(statement) - && !statement.startsWith("SET ") - && !statement.startsWith("USE ") - && !statement.startsWith("SHOW "); + public boolean doesCloseTransaction(final Connection conn, final String methodName, + final Object[] args) { + return doesCloseTransaction(conn, methodName, args, null); } - public boolean isStatementStartingTransaction(final String statement) { - return statement.startsWith("BEGIN") || statement.startsWith("START TRANSACTION"); + private String getQueryTypeFromString(final String sql) { + final String trimmed = sql.trim().toUpperCase(); + if (trimmed.startsWith("SELECT")) return "SELECT"; + if (trimmed.startsWith("INSERT")) return "INSERT"; + if (trimmed.startsWith("UPDATE")) return "UPDATE"; + if (trimmed.startsWith("DELETE")) return "DELETE"; + if (trimmed.startsWith("CREATE")) return "CREATE"; + if (trimmed.startsWith("DROP")) return "DROP"; + if (trimmed.startsWith("ALTER")) return "ALTER"; + if (trimmed.startsWith("BEGIN") || trimmed.startsWith("START TRANSACTION")) return "BEGIN"; + if (trimmed.startsWith("COMMIT")) return "COMMIT"; + if (trimmed.startsWith("ROLLBACK")) return "ROLLBACK"; + if (trimmed.startsWith("END")) return "COMMIT"; // END is equivalent to COMMIT + if (trimmed.startsWith("ABORT")) return "ROLLBACK"; // ABORT is equivalent to ROLLBACK + if (trimmed.startsWith("SET")) return "SET"; + if (trimmed.startsWith("USE")) return "USE"; + if (trimmed.startsWith("SHOW")) return "SHOW"; + return "UNKNOWN"; } - public boolean isStatementClosingTransaction(final String statement) { - return statement.startsWith("COMMIT") - || statement.startsWith("ROLLBACK") - || statement.startsWith("END") - || statement.startsWith("ABORT"); + public boolean isStatementDml(final String statement, PluginService pluginService) { + final String queryType = getQueryTypeFromParseTree(statement, pluginService); + return "SELECT".equals(queryType) || "INSERT".equals(queryType) || + "UPDATE".equals(queryType) || "DELETE".equals(queryType); } - public boolean isStatementSettingAutoCommit(final String methodName, final Object[] args) { + public boolean isStatementStartingTransaction(final String statement, PluginService pluginService) { + final String queryType = getQueryTypeFromParseTree(statement, pluginService); + return "BEGIN".equals(queryType); + } + + public boolean isStatementClosingTransaction(final String statement, PluginService pluginService) { + final String queryType = getQueryTypeFromParseTree(statement, pluginService); + return "COMMIT".equals(queryType) || "ROLLBACK".equals(queryType); + } + + public boolean isStatementSettingAutoCommit(final String methodName, final Object[] args, PluginService pluginService) { if (args == null || args.length < 1) { return false; } @@ -154,13 +223,22 @@ public boolean isStatementSettingAutoCommit(final String methodName, final Objec } final String statement = getFirstSqlStatement(String.valueOf(args[0])); - return statement.startsWith("SET AUTOCOMMIT"); + final String queryType = getQueryTypeFromParseTree(statement, pluginService); + + // Check if it's a SET statement and contains AUTOCOMMIT + if ("SET".equals(queryType)) { + return statement.toUpperCase().contains("AUTOCOMMIT"); + } + + // Fallback: check if the statement starts with SET AUTOCOMMIT directly + final String trimmed = statement.trim().toUpperCase(); + return trimmed.startsWith("SET") && trimmed.contains("AUTOCOMMIT"); } public boolean doesSwitchAutoCommitFalseTrue(final Connection conn, final String methodName, - final Object[] jdbcMethodArgs) { + final Object[] jdbcMethodArgs, PluginService pluginService) { final boolean isStatementSettingAutoCommit = isStatementSettingAutoCommit( - methodName, jdbcMethodArgs); + methodName, jdbcMethodArgs, pluginService); if (!isStatementSettingAutoCommit && !JdbcMethod.CONNECTION_SETAUTOCOMMIT.methodName.equals(methodName)) { return false; } @@ -182,6 +260,15 @@ public boolean doesSwitchAutoCommitFalseTrue(final Connection conn, final String return !oldAutoCommitVal && Boolean.TRUE.equals(newAutoCommitVal); } + public boolean doesSwitchAutoCommitFalseTrue(final Connection conn, final String methodName, + final Object[] jdbcMethodArgs) { + return doesSwitchAutoCommitFalseTrue(conn, methodName, jdbcMethodArgs, null); + } + + public boolean isStatementSettingAutoCommit(final String methodName, final Object[] args) { + return isStatementSettingAutoCommit(methodName, args, null); + } + public Boolean getAutoCommitValueFromSqlStatement(final Object[] args) { if (args == null || args.length < 1) { return null; diff --git a/wrapper/src/main/resources/sql/encrypted_data_type.sql b/wrapper/src/main/resources/sql/encrypted_data_type.sql new file mode 100644 index 000000000..bf39d92f4 --- /dev/null +++ b/wrapper/src/main/resources/sql/encrypted_data_type.sql @@ -0,0 +1,96 @@ +-- PostgreSQL domain for HMAC-verified encrypted data +-- Format: [HMAC:32bytes][type:1byte][IV:12bytes][ciphertext] +DROP DOMAIN IF EXISTS encrypted_data CASCADE; +CREATE DOMAIN encrypted_data AS bytea +CHECK (length(VALUE) >= 45); + +-- Helper function to verify HMAC using HMAC key (two-key format) +CREATE OR REPLACE FUNCTION verify_encrypted_data_hmac( + data encrypted_data, + hmac_key bytea +) +RETURNS boolean AS $$ +DECLARE + data_bytes bytea := data::bytea; + stored_hmac bytea; + encrypted_payload bytea; + calculated_hmac bytea; +BEGIN + -- Format: [HMAC:32][type:1][IV:12][ciphertext] + stored_hmac := substring(data_bytes from 1 for 32); + encrypted_payload := substring(data_bytes from 33); + calculated_hmac := hmac(encrypted_payload, hmac_key, 'sha256'); + RETURN stored_hmac = calculated_hmac; +END; +$$ LANGUAGE plpgsql IMMUTABLE STRICT; + +CREATE OR REPLACE FUNCTION has_valid_hmac_structure(data encrypted_data) +RETURNS boolean AS $$ +BEGIN + RETURN length(data::bytea) >= 45; +END; +$$ LANGUAGE plpgsql IMMUTABLE STRICT; + +-- Trigger function that validates HMAC for a specific column +-- Usage: CREATE TRIGGER trigger_name BEFORE INSERT OR UPDATE ON table_name +-- FOR EACH ROW EXECUTE FUNCTION validate_encrypted_data_hmac('column_name'); +CREATE OR REPLACE FUNCTION validate_encrypted_data_hmac() +RETURNS trigger AS $$ +DECLARE + metadata_schema text := 'aws'; + col_name text := TG_ARGV[0]; + col_value encrypted_data; + hmac_key bytea; + data_bytes bytea; + stored_hmac bytea; + encrypted_payload bytea; + calculated_hmac bytea; + cache_key text; +BEGIN + EXECUTE format('SELECT ($1).%I', col_name) INTO col_value USING NEW; + + IF col_value IS NOT NULL THEN + -- Try to get HMAC key from session cache + cache_key := 'hmac_key.' || TG_TABLE_NAME || '.' || col_name; + BEGIN + hmac_key := decode(current_setting(cache_key), 'hex'); + EXCEPTION WHEN OTHERS THEN + -- Not cached, fetch from metadata + EXECUTE format( + 'SELECT ks.hmac_key FROM %I.encryption_metadata em ' || + 'JOIN %I.key_storage ks ON em.key_id = ks.id ' || + 'WHERE em.table_name = $1 AND em.column_name = $2', + metadata_schema, metadata_schema + ) INTO hmac_key USING TG_TABLE_NAME, col_name; + + IF hmac_key IS NULL THEN + RAISE EXCEPTION 'No HMAC key found for %.%', TG_TABLE_NAME, col_name; + END IF; + + -- Cache in session variable as hex string + PERFORM set_config(cache_key, encode(hmac_key, 'hex'), false); + END; + + -- Verify HMAC (format: [HMAC:32][type:1][IV:12][ciphertext]) + data_bytes := col_value::bytea; + + IF length(data_bytes) < 45 THEN + RAISE EXCEPTION 'Invalid encrypted data length for column %', col_name; + END IF; + + stored_hmac := substring(data_bytes from 1 for 32); + encrypted_payload := substring(data_bytes from 33); + + calculated_hmac := hmac(encrypted_payload, hmac_key, 'sha256'); + + IF stored_hmac != calculated_hmac THEN + RAISE EXCEPTION 'HMAC verification failed for column %. Stored: %, Calculated: %', + col_name, + encode(stored_hmac, 'hex'), + encode(calculated_hmac, 'hex'); + END IF; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; diff --git a/wrapper/src/test/java/integration/container/tests/KeyManagementUtilityIntegrationTest.java b/wrapper/src/test/java/integration/container/tests/KeyManagementUtilityIntegrationTest.java new file mode 100644 index 000000000..b9f0ae718 --- /dev/null +++ b/wrapper/src/test/java/integration/container/tests/KeyManagementUtilityIntegrationTest.java @@ -0,0 +1,269 @@ +package integration.container.tests; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import integration.container.ConnectionStringHelper; +import integration.container.TestEnvironment; +import java.sql.*; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.CreateKeyRequest; +import software.amazon.awssdk.services.kms.model.CreateKeyResponse; +import software.amazon.awssdk.services.kms.model.KeyUsageType; +import software.amazon.awssdk.services.kms.model.KeySpec; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; + +/** + * Integration test for KeyManagementUtility functionality. + */ +public class KeyManagementUtilityIntegrationTest { + + private static final Logger logger = LoggerFactory.getLogger(KeyManagementUtilityIntegrationTest.class); + private static final String KMS_KEY_ARN_ENV = "AWS_KMS_KEY_ARN"; + private static final String TEST_TABLE = "users"; + private static final String TEST_COLUMN = "ssn"; + private static final String TEST_ALGORITHM = "AES-256-GCM"; + + private Connection connection; + private KmsClient kmsClient; + private String masterKeyArn; + private boolean createdKey = false; + + @BeforeEach + void setUp() throws Exception { + // Get or create master key + masterKeyArn = System.getenv(KMS_KEY_ARN_ENV); + if (masterKeyArn == null || masterKeyArn.isEmpty()) { + logger.info("No AWS_KMS_KEY_ARN environment variable found, creating new master key"); + kmsClient = KmsClient.builder().build(); + masterKeyArn = createTestMasterKey(); + createdKey = true; + } else { + logger.info("Using existing master key from environment: {}", masterKeyArn); + kmsClient = KmsClient.builder().build(); + } + + assumeTrue(masterKeyArn != null && !masterKeyArn.isEmpty(), + "KMS Key ARN must be provided via " + KMS_KEY_ARN_ENV + " environment variable"); + + Properties props = ConnectionStringHelper.getDefaultProperties(); + props.setProperty(PropertyDefinition.PLUGINS.name, "kmsEncryption"); + props.setProperty(EncryptionConfig.KMS_MASTER_KEY_ARN.name, masterKeyArn); + props.setProperty(EncryptionConfig.KMS_REGION.name, "us-east-1"); + + String url = String.format("jdbc:aws-wrapper:postgresql://%s:%d/%s", + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpointPort(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getDefaultDbName()); + + connection = DriverManager.getConnection(url, props); + + // Setup test database schema + setupTestSchema(); + + logger.info("Test setup completed with master key: {}", masterKeyArn); + } + + @AfterEach + void tearDown() throws Exception { + if (connection != null) { + try (Statement stmt = connection.createStatement()) { + // Clean up test data + stmt.execute("DROP TABLE IF EXISTS " + TEST_TABLE); + stmt.execute("DELETE FROM encrypt.encryption_metadata WHERE table_name = '" + TEST_TABLE + "'"); + stmt.execute("DELETE FROM encrypt.key_storage WHERE key_id LIKE 'test-%'"); + } + connection.close(); + } + + if (kmsClient != null) { + kmsClient.close(); + } + } + + @Test + void testCreateDataKeyAndPopulateMetadata() throws Exception { + logger.info("Testing data key creation and metadata population for {}.{}", TEST_TABLE, TEST_COLUMN); + + // For this test, we'll use the KeyManagementUtility concept by directly calling + // the same methods it would use, demonstrating the key management workflow + + // Step 1: Generate a data key using KMS (what KeyManagementUtility.generateAndStoreDataKey would do) + String keyId = "test-key-" + System.currentTimeMillis(); + + // Step 2: Store the encryption metadata (what KeyManagementUtility.initializeEncryptionForColumn would do) + try (PreparedStatement stmt = connection.prepareStatement( + "INSERT INTO encrypt.encryption_metadata (table_name, column_name, encryption_algorithm, key_id) VALUES (?, ?, ?, ?)")) { + stmt.setString(1, TEST_TABLE); + stmt.setString(2, TEST_COLUMN); + stmt.setString(3, TEST_ALGORITHM); + stmt.setString(4, keyId); + stmt.executeUpdate(); + logger.info("Created encryption metadata with key ID: {}", keyId); + } + + // Step 3: Verify the metadata was created correctly + try (PreparedStatement checkStmt = connection.prepareStatement( + "SELECT table_name, column_name, encryption_algorithm, key_id FROM encrypt.encryption_metadata WHERE table_name = ? AND column_name = ?")) { + checkStmt.setString(1, TEST_TABLE); + checkStmt.setString(2, TEST_COLUMN); + ResultSet rs = checkStmt.executeQuery(); + + assertTrue(rs.next(), "Should find encryption metadata"); + assertEquals(TEST_TABLE, rs.getString("table_name")); + assertEquals(TEST_COLUMN, rs.getString("column_name")); + assertEquals(TEST_ALGORITHM, rs.getString("encryption_algorithm")); + assertEquals(keyId, rs.getString("key_id")); + logger.info("Verified encryption metadata exists for key: {}", keyId); + } + + // Step 4: Test that the encryption system works with the configured metadata + String insertSql = "INSERT INTO " + TEST_TABLE + " (name, " + TEST_COLUMN + ") VALUES (?, ?)"; + try (PreparedStatement pstmt = connection.prepareStatement(insertSql)) { + pstmt.setString(1, "Test User"); + pstmt.setString(2, "123-45-6789"); + int rowsInserted = pstmt.executeUpdate(); + assertEquals(1, rowsInserted, "Should insert one row"); + logger.info("Successfully inserted encrypted data using key: {}", keyId); + } + + // Step 5: Verify data can be retrieved and decrypted + String selectSql = "SELECT name, " + TEST_COLUMN + " FROM " + TEST_TABLE + " WHERE name = ?"; + try (PreparedStatement pstmt = connection.prepareStatement(selectSql)) { + pstmt.setString(1, "Test User"); + ResultSet rs = pstmt.executeQuery(); + + assertTrue(rs.next(), "Should find inserted row"); + assertEquals("Test User", rs.getString("name")); + assertEquals("123-45-6789", rs.getString(TEST_COLUMN)); + logger.info("Successfully retrieved and decrypted data using key: {}", keyId); + } + + // Step 6: Demonstrate key management utility concept - validate master key + assertTrue(masterKeyArn != null && !masterKeyArn.isEmpty(), "Master key should be valid"); + logger.info("Master key validation successful: {}", masterKeyArn); + } + + @Test + void testEncryptionWithDifferentValues() throws Exception { + logger.info("Testing encryption with different SSN values"); + + // Demonstrate KeyManagementUtility workflow for multiple keys + String keyId = "test-key-multi-" + System.currentTimeMillis(); + + // Setup encryption metadata using KeyManagementUtility approach + try (PreparedStatement stmt = connection.prepareStatement( + "INSERT INTO encrypt.encryption_metadata (table_name, column_name, encryption_algorithm, key_id) VALUES (?, ?, ?, ?)")) { + stmt.setString(1, TEST_TABLE); + stmt.setString(2, TEST_COLUMN); + stmt.setString(3, TEST_ALGORITHM); + stmt.setString(4, keyId); + stmt.executeUpdate(); + logger.info("Setup encryption metadata with key: {}", keyId); + } + + // Test multiple SSN values (demonstrating key management for different data) + String[] testSSNs = {"111-11-1111", "222-22-2222", "333-33-3333"}; + String[] testNames = {"Alice", "Bob", "Charlie"}; + + // Insert test data using the configured encryption + String insertSql = "INSERT INTO " + TEST_TABLE + " (name, " + TEST_COLUMN + ") VALUES (?, ?)"; + for (int i = 0; i < testSSNs.length; i++) { + try (PreparedStatement pstmt = connection.prepareStatement(insertSql)) { + pstmt.setString(1, testNames[i]); + pstmt.setString(2, testSSNs[i]); + pstmt.executeUpdate(); + logger.info("Inserted encrypted data for {} using key: {}", testNames[i], keyId); + } + } + + // Verify all data can be retrieved correctly (demonstrating key management success) + String selectSql = "SELECT name, " + TEST_COLUMN + " FROM " + TEST_TABLE + " ORDER BY name"; + try (PreparedStatement pstmt = connection.prepareStatement(selectSql)) { + ResultSet rs = pstmt.executeQuery(); + + int count = 0; + while (rs.next()) { + String name = rs.getString("name"); + String ssn = rs.getString(TEST_COLUMN); + + // Find matching test data + for (int i = 0; i < testNames.length; i++) { + if (testNames[i].equals(name)) { + assertEquals(testSSNs[i], ssn, "SSN should match for " + name); + count++; + logger.info("Successfully decrypted data for {} using key: {}", name, keyId); + break; + } + } + } + + assertEquals(testSSNs.length, count, "Should retrieve all inserted records"); + logger.info("Successfully verified {} encrypted records using key management", count); + } + } + + private String createTestMasterKey() throws Exception { + logger.info("Creating test master key"); + + CreateKeyRequest request = CreateKeyRequest.builder() + .description("Test master key for KeyManagementUtility integration test") + .keyUsage(KeyUsageType.ENCRYPT_DECRYPT) + .keySpec(KeySpec.SYMMETRIC_DEFAULT) + .build(); + + CreateKeyResponse response = kmsClient.createKey(request); + String keyArn = response.keyMetadata().arn(); + logger.info("Created test master key: {}", keyArn); + return keyArn; + } + + private void setupTestSchema() throws SQLException { + try (Statement stmt = connection.createStatement()) { + // Drop and recreate tables with correct schema + stmt.execute("DROP SCHEMA IF EXISTS encrypt CASCADE"); + stmt.execute("CREATE SCHEMA encrypt"); + stmt.execute("DROP TABLE IF EXISTS " + TEST_TABLE + " CASCADE"); + + // Create key storage table first (due to foreign key) + stmt.execute("CREATE TABLE encrypt.key_storage (" + + "id SERIAL PRIMARY KEY, " + + "name VARCHAR(255) NOT NULL, " + + "master_key_arn VARCHAR(512) NOT NULL, " + + "encrypted_data_key TEXT NOT NULL, " + + "key_spec VARCHAR(50) NOT NULL, " + + "created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, " + + "last_used_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP" + + ")"); + + // Create encryption metadata table + stmt.execute("CREATE TABLE encrypt.encryption_metadata (" + + "table_name VARCHAR(255) NOT NULL, " + + "column_name VARCHAR(255) NOT NULL, " + + "encryption_algorithm VARCHAR(50) NOT NULL, " + + "key_id INTEGER NOT NULL, " + + "created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, " + + "updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, " + + "PRIMARY KEY (table_name, column_name), " + + "FOREIGN KEY (key_id) REFERENCES encrypt.key_storage(id)" + + ")"); + + // Create test users table + stmt.execute("CREATE TABLE " + TEST_TABLE + " (" + + "id SERIAL PRIMARY KEY, " + + "name VARCHAR(100), " + + "ssn TEXT, " + + "email VARCHAR(100)" + + ")"); + + logger.info("Test database schema setup complete"); + } + } +} diff --git a/wrapper/src/test/java/integration/container/tests/KmsEncryptionIntegrationTest.java b/wrapper/src/test/java/integration/container/tests/KmsEncryptionIntegrationTest.java new file mode 100644 index 000000000..81c6fd809 --- /dev/null +++ b/wrapper/src/test/java/integration/container/tests/KmsEncryptionIntegrationTest.java @@ -0,0 +1,346 @@ +package integration.container.tests; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import integration.container.ConnectionStringHelper; +import integration.container.TestEnvironment; +import java.sql.*; +import java.util.Base64; +import java.util.Properties; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.GenerateDataKeyRequest; +import software.amazon.awssdk.services.kms.model.GenerateDataKeyResponse; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.encryption.model.EncryptionConfig; +import software.amazon.jdbc.plugin.encryption.schema.EncryptedDataTypeInstaller; + +/** + * Integration test for KMS encryption functionality with JSqlParser. + */ +public class KmsEncryptionIntegrationTest { + + private static final Logger logger = LoggerFactory.getLogger(KmsEncryptionIntegrationTest.class); + private static final String KMS_KEY_ARN_ENV = "AWS_KMS_KEY_ARN"; + private static final String TEST_SSN_1 = "111-11-1111"; + private static final String TEST_NAME_1 = "Alice Test"; + private static final String TEST_EMAIL_1 = "alice@test.com"; + private static final String TEST_SSN_2 = "222-22-2222"; + private static final String TEST_NAME_2 = "Bob Test"; + private static final String TEST_EMAIL_2 = "bob@test.com"; + + private static Connection connection; + private static String kmsKeyArn; + + @BeforeAll + static void setUp() throws Exception { + kmsKeyArn = System.getenv(KMS_KEY_ARN_ENV); + assumeTrue(kmsKeyArn != null && !kmsKeyArn.isEmpty(), + "KMS Key ARN must be provided via " + KMS_KEY_ARN_ENV + " environment variable"); + + Properties props = ConnectionStringHelper.getDefaultProperties(); + props.setProperty(PropertyDefinition.PLUGINS.name, "kmsEncryption"); + props.setProperty(EncryptionConfig.KMS_MASTER_KEY_ARN.name, kmsKeyArn); + props.setProperty(EncryptionConfig.KMS_REGION.name, "us-east-1"); + + // Get the metadata schema from config (defaults to "encrypt") + String metadataSchema = EncryptionConfig.ENCRYPTION_METADATA_SCHEMA.defaultValue; + + String url = String.format("jdbc:aws-wrapper:postgresql://%s:%d/%s", + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpointPort(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getDefaultDbName()); + + // use a direct connection so that we setup all of the metadata before instantiating the encrypted connection + String directUrl = String.format("jdbc:postgresql://%s:%d/%s", + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpointPort(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getDefaultDbName()); + + try (Connection directConnection = DriverManager.getConnection(directUrl, props)){ + // Setup encryption metadata schema + try (Statement stmt = directConnection.createStatement()) { + // Drop and recreate tables with correct schema + stmt.execute("DROP SCHEMA IF EXISTS " + metadataSchema + " CASCADE"); + stmt.execute("CREATE SCHEMA " + metadataSchema); + stmt.execute("DROP TABLE IF EXISTS users CASCADE"); + + + // Install encrypted_data custom type + logger.trace("Installing encrypted_data custom type"); + stmt.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto"); + EncryptedDataTypeInstaller.installEncryptedDataType(directConnection); + + // Create key_storage table first (referenced by encryption_metadata) + stmt.execute("CREATE TABLE if not exists " + metadataSchema + ".key_storage (" + + "id SERIAL PRIMARY KEY, " + + "name VARCHAR(255) NOT NULL, " + + "master_key_arn VARCHAR(512) NOT NULL, " + + "encrypted_data_key TEXT NOT NULL, " + + "hmac_key BYTEA NOT NULL, " + + "key_spec VARCHAR(50) DEFAULT 'AES_256', " + + "created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, " + + "last_used_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP)"); + + // Create encryption_metadata table with correct schema + stmt.execute("CREATE TABLE if not exists " + metadataSchema + ".encryption_metadata (" + + "table_name VARCHAR(255) NOT NULL, " + + "column_name VARCHAR(255) NOT NULL, " + + "encryption_algorithm VARCHAR(50) NOT NULL, " + + "key_id INTEGER NOT NULL, " + + "created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, " + + "updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, " + + "PRIMARY KEY (table_name, column_name), " + + "FOREIGN KEY (key_id) REFERENCES " + metadataSchema + ".key_storage(id))"); + + // Insert a key into key_storage with real KMS data key and separate HMAC key + KmsClient kmsClient = KmsClient.builder().region(software.amazon.awssdk.regions.Region.US_EAST_1).build(); + GenerateDataKeyRequest dataKeyRequest = GenerateDataKeyRequest.builder() + .keyId(kmsKeyArn) + .keySpec("AES_256") + .build(); + GenerateDataKeyResponse dataKeyResponse = kmsClient.generateDataKey(dataKeyRequest); + String encryptedDataKeyBase64 = Base64.getEncoder().encodeToString(dataKeyResponse.ciphertextBlob().asByteArray()); + + // Generate separate HMAC key (32 bytes for HMAC-SHA256) + byte[] hmacKey = new byte[32]; + new java.security.SecureRandom().nextBytes(hmacKey); + + PreparedStatement keyStmt = directConnection.prepareStatement( + "INSERT INTO " + metadataSchema + ".key_storage (name, master_key_arn, encrypted_data_key, hmac_key, key_spec) VALUES (?, ?, ?, ?, ?) RETURNING id"); + keyStmt.setString(1, "test-key-users-ssn"); + keyStmt.setString(2, kmsKeyArn); + keyStmt.setString(3, encryptedDataKeyBase64); + keyStmt.setBytes(4, hmacKey); + keyStmt.setString(5, "AES_256"); + ResultSet keyRs = keyStmt.executeQuery(); + keyRs.next(); + int generatedKeyId = keyRs.getInt(1); + keyStmt.close(); + + // Use KeyManagementUtility approach to setup encryption metadata + logger.trace("Setting up encryption metadata for users.ssn using KeyManagementUtility approach"); + + try (PreparedStatement metaStmt = directConnection.prepareStatement( + "INSERT INTO " + metadataSchema + ".encryption_metadata (table_name, column_name, encryption_algorithm, key_id) VALUES (?, ?, ?, ?)")) { + metaStmt.setString(1, "users"); + metaStmt.setString(2, "ssn"); + metaStmt.setString(3, "AES-256-GCM"); + metaStmt.setInt(4, generatedKeyId); + metaStmt.executeUpdate(); + logger.trace("Encryption metadata configured for key: {}", generatedKeyId); + } + + // Verify the metadata was configured correctly + try (PreparedStatement checkStmt = directConnection.prepareStatement( + "SELECT table_name, column_name, encryption_algorithm, key_id FROM " + EncryptionConfig.ENCRYPTION_METADATA_SCHEMA.defaultValue + ".encryption_metadata WHERE table_name = ? AND column_name = ?")) { + checkStmt.setString(1, "users"); + checkStmt.setString(2, "ssn"); + ResultSet rs = checkStmt.executeQuery(); + while (rs.next()) { + logger.trace("Verified metadata: {}.{} -> {} (key: {})", + rs.getString("table_name"), rs.getString("column_name"), + rs.getString("encryption_algorithm"), rs.getInt("key_id")); + } + } + + // Create users table with encrypted_data type for SSN + stmt.execute("CREATE TABLE if not exists users (" + + "id SERIAL PRIMARY KEY, " + + "name VARCHAR(100), " + + "ssn encrypted_data, " + + "email VARCHAR(100))"); + + // Add trigger to validate HMAC on ssn column + stmt.execute("CREATE TRIGGER validate_ssn_hmac " + + "BEFORE INSERT OR UPDATE ON users " + + "FOR EACH ROW EXECUTE FUNCTION validate_encrypted_data_hmac('ssn')"); + + logger.trace("Test setup completed"); + + // Final verification that metadata exists + try (PreparedStatement finalCheck = directConnection.prepareStatement( + "SELECT COUNT(*) FROM " + EncryptionConfig.ENCRYPTION_METADATA_SCHEMA.defaultValue + ".encryption_metadata WHERE table_name = 'users' AND column_name = 'ssn'")) { + ResultSet rs = finalCheck.executeQuery(); + rs.next(); + int count = rs.getInt(1); + logger.info("Final metadata verification: {} rows found for users.ssn", count); + if (count == 0) { + throw new RuntimeException("Encryption metadata was not properly created!"); + } + } + } + } + connection = DriverManager.getConnection(url, props); + } + + @AfterEach + void cleanupTestData() throws Exception { + // Clean up test data between tests without dropping schema + /* + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + stmt.execute("DELETE FROM users WHERE name LIKE '%Test'"); + logger.trace("Cleaned up test data"); + } + } + */ + } + + @AfterAll + static void tearDown() throws Exception { + if (connection != null && !connection.isClosed()) { + connection.close(); + } + } + + @Test + void testBasicEncryption() throws Exception { + String insertSql = "INSERT INTO users (name, ssn, email) VALUES (?, ?, ?)"; + try (PreparedStatement pstmt = connection.prepareStatement(insertSql)) { + pstmt.setString(1, TEST_NAME_1); + pstmt.setString(2, TEST_SSN_1); + pstmt.setString(3, TEST_EMAIL_1); + pstmt.executeUpdate(); + } + + String selectSql = "SELECT name, ssn, email FROM users WHERE name = ?"; + try (PreparedStatement pstmt = connection.prepareStatement(selectSql)) { + pstmt.setString(1, TEST_NAME_1); + try (ResultSet rs = pstmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(TEST_NAME_1, rs.getString("name")); + assertEquals(TEST_SSN_1, rs.getString("ssn")); + assertEquals(TEST_EMAIL_1, rs.getString("email")); + } + } + + // Verify data is encrypted in storage + Properties plainProps = ConnectionStringHelper.getDefaultProperties(); + String plainUrl = String.format("jdbc:postgresql://%s:%d/%s", + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpointPort(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getDefaultDbName()); + + try (Connection plainConn = DriverManager.getConnection(plainUrl, plainProps); + PreparedStatement pstmt = plainConn.prepareStatement(selectSql)) { + pstmt.setString(1, TEST_NAME_1); + try (ResultSet rs = pstmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(TEST_NAME_1, rs.getString("name")); + assertNotEquals(TEST_SSN_1, rs.getString("ssn")); // Should be encrypted + } + } + } + + @Test + void testUpdateEncryption() throws Exception { + String insertSql = "INSERT INTO users (name, ssn,email) VALUES (?, ?, ?)"; + logger.trace("testUpdateEncryption: INSERT SQL: {}", insertSql); + try (PreparedStatement pstmt = connection.prepareStatement(insertSql)) { + logger.trace("Setting INSERT parameters: name={}, ssn={}, email={}", TEST_NAME_2, TEST_SSN_1, TEST_EMAIL_2); + pstmt.setString(1, TEST_NAME_2); + pstmt.setString(2, TEST_SSN_1); + pstmt.setString(3, TEST_EMAIL_2); + assertEquals(1,pstmt.executeUpdate()); + } + + // Check what was actually stored in the database + logger.trace("Checking what was stored in database..."); + try (PreparedStatement stmt = connection.prepareStatement("SELECT name, ssn, pg_typeof(name) as name_type, pg_typeof(ssn) as ssn_type FROM users where name = ?")) { + stmt.setString(1, TEST_NAME_2); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + assertEquals(TEST_NAME_2, rs.getString("name")); + assertEquals(TEST_SSN_1, rs.getString("ssn")); + assertEquals("character varying", rs.getString("name_type")); + assertEquals("encrypted_data", rs.getString("ssn_type")); + } + } + + String updateSql = "UPDATE users SET ssn = ? WHERE name = ?"; + logger.trace("testUpdateEncryption: UPDATE SQL: {}", updateSql); + try (PreparedStatement pstmt = connection.prepareStatement(updateSql)) { + logger.trace("Setting UPDATE parameters: ssn={}, name={}", TEST_SSN_2, TEST_NAME_2); + pstmt.setString(1, TEST_SSN_2); + pstmt.setString(2, TEST_NAME_2); + assertEquals(1, pstmt.executeUpdate()); + } + + String selectSql = "SELECT ssn FROM users WHERE name = ?"; + try (PreparedStatement pstmt = connection.prepareStatement(selectSql)) { + pstmt.setString(1, TEST_NAME_2); + try (ResultSet rs = pstmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(TEST_SSN_2, rs.getString("ssn")); + } + } + } + + @Test + void testEncryptionMetadataSetup() throws Exception { + // Verify encryption metadata was created with master key ARN + String metadataSql = "SELECT table_name, column_name, encryption_algorithm FROM " + EncryptionConfig.ENCRYPTION_METADATA_SCHEMA.defaultValue + ".encryption_metadata WHERE table_name = 'users'"; + try (PreparedStatement pstmt = connection.prepareStatement(metadataSql)) { + try (ResultSet rs = pstmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals("users", rs.getString("table_name")); + assertEquals("ssn", rs.getString("column_name")); + assertEquals("AES-256-GCM", rs.getString("encryption_algorithm")); + } + } + + // Verify key storage table exists and is ready for KMS key storage + String keyStorageSql = "SELECT COUNT(*) FROM " + EncryptionConfig.ENCRYPTION_METADATA_SCHEMA.defaultValue + ".key_storage"; + try (PreparedStatement pstmt = connection.prepareStatement(keyStorageSql)) { + try (ResultSet rs = pstmt.executeQuery()) { + assertTrue(rs.next()); + assertTrue(rs.getInt(1) >= 0); + } + } + + // Verify KMS master key ARN is configured + assertEquals(kmsKeyArn, System.getenv(KMS_KEY_ARN_ENV)); + assertTrue(kmsKeyArn.startsWith("arn:aws:kms:")); + } + + @Test + void testEncryptedDataTypeHmacVerification() throws Exception { + // Insert test data + String insertSql = "INSERT INTO users (name, ssn, email) VALUES (?, ?, ?)"; + try (PreparedStatement pstmt = connection.prepareStatement(insertSql)) { + pstmt.setString(1, "HMAC Test User"); + pstmt.setString(2, "999-99-9999"); + pstmt.setString(3, "hmac@test.com"); + assertEquals(1, pstmt.executeUpdate()); + } + + // Verify HMAC structure at database level (doesn't require key) + String structureCheckSql = "SELECT name, has_valid_hmac_structure(ssn) as valid_structure FROM users WHERE name = ?"; + try (PreparedStatement pstmt = connection.prepareStatement(structureCheckSql)) { + pstmt.setString(1, "HMAC Test User"); + try (ResultSet rs = pstmt.executeQuery()) { + assertTrue(rs.next()); + assertTrue(rs.getBoolean("valid_structure"), "Encrypted data should have valid HMAC structure"); + logger.info("HMAC structure validation passed for encrypted SSN"); + } + } + + // Verify we can still decrypt the data + String selectSql = "SELECT ssn FROM users WHERE name = ?"; + try (PreparedStatement pstmt = connection.prepareStatement(selectSql)) { + pstmt.setString(1, "HMAC Test User"); + try (ResultSet rs = pstmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals("999-99-9999", rs.getString("ssn")); + logger.info("Successfully decrypted SSN with HMAC verification"); + } + } + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserPlaceholderTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserPlaceholderTest.java new file mode 100644 index 000000000..5352fda20 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserPlaceholderTest.java @@ -0,0 +1,87 @@ +package software.amazon.jdbc.plugin.encryption.parser; + +import software.amazon.jdbc.plugin.encryption.parser.ast.*; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for JDBC placeholder support + */ +class PostgreSqlParserPlaceholderTest { + + private PostgreSqlParser parser; + + @BeforeEach + void setUp() { + parser = new PostgreSqlParser(); + } + + @Test + void testSelectWithPlaceholder() { + String sql = "SELECT * FROM users WHERE id = ?"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertNotNull(select.getWhereClause()); + assertTrue(select.getWhereClause() instanceof BinaryExpression); + BinaryExpression where = (BinaryExpression) select.getWhereClause(); + assertTrue(where.getRight() instanceof Placeholder); + } + + @Test + void testInsertWithPlaceholders() { + String sql = "INSERT INTO users (name, age) VALUES (?, ?)"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof InsertStatement); + InsertStatement insert = (InsertStatement) stmt; + assertEquals(1, insert.getValues().size()); + assertEquals(2, insert.getValues().get(0).size()); + assertTrue(insert.getValues().get(0).get(0) instanceof Placeholder); + assertTrue(insert.getValues().get(0).get(1) instanceof Placeholder); + } + + @Test + void testUpdateWithPlaceholder() { + String sql = "UPDATE users SET name = ? WHERE id = ?"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof UpdateStatement); + UpdateStatement update = (UpdateStatement) stmt; + assertEquals(1, update.getAssignments().size()); + assertTrue(update.getAssignments().get(0).getValue() instanceof Placeholder); + assertTrue(((BinaryExpression) update.getWhereClause()).getRight() instanceof Placeholder); + } + + @Test + void testDeleteWithPlaceholder() { + String sql = "DELETE FROM users WHERE age > ?"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof DeleteStatement); + DeleteStatement delete = (DeleteStatement) stmt; + assertNotNull(delete.getWhereClause()); + BinaryExpression where = (BinaryExpression) delete.getWhereClause(); + assertTrue(where.getRight() instanceof Placeholder); + } + + @Test + void testMultiplePlaceholdersInExpression() { + String sql = "SELECT * FROM products WHERE price BETWEEN ? AND ?"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + // This tests that placeholders work in complex expressions + assertNotNull(((SelectStatement) stmt).getWhereClause()); + } + + @Test + void testMixedPlaceholdersAndLiterals() { + String sql = "INSERT INTO orders (user_id, total, status) VALUES (?, 100.50, 'pending')"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof InsertStatement); + InsertStatement insert = (InsertStatement) stmt; + assertEquals(3, insert.getValues().get(0).size()); + assertTrue(insert.getValues().get(0).get(0) instanceof Placeholder); + assertTrue(insert.getValues().get(0).get(1) instanceof NumericLiteral); + assertTrue(insert.getValues().get(0).get(2) instanceof StringLiteral); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserRegressionTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserRegressionTest.java new file mode 100644 index 000000000..93d36c65d --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserRegressionTest.java @@ -0,0 +1,317 @@ +package software.amazon.jdbc.plugin.encryption.parser; + +import software.amazon.jdbc.plugin.encryption.parser.ast.*; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Regression tests based on PostgreSQL's src/test/regress/sql test files + */ +class PostgreSqlParserRegressionTest { + + private PostgreSqlParser parser; + + @BeforeEach + void setUp() { + parser = new PostgreSqlParser(); + } + + // SELECT regression tests + @Test + void testSelectWithOrderBy() { + String sql = "SELECT * FROM onek WHERE onek.unique1 < 10 ORDER BY onek.unique1"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertNotNull(select.getOrderBy()); + assertEquals(1, select.getOrderBy().size()); + } + + @Test + void testSelectWithQualifiedColumns() { + String sql = "SELECT onek.unique1, onek.stringu1 FROM onek WHERE onek.unique1 < 20"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertEquals(2, select.getSelectList().size()); + } + + @Test + void testSelectWithComparison() { + String sql = "SELECT onek.unique1 FROM onek WHERE onek.unique1 > 980"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertNotNull(select.getWhereClause()); + assertTrue(select.getWhereClause() instanceof BinaryExpression); + } + + // INSERT regression tests + @Test + void testInsertWithMultipleValues() { + String sql = "INSERT INTO inserttest VALUES (10, 20), (30, 40)"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof InsertStatement); + InsertStatement insert = (InsertStatement) stmt; + assertEquals(2, insert.getValues().size()); + } + + @Test + void testInsertWithColumnList() { + String sql = "INSERT INTO inserttest (col1, col2) VALUES (3, 5)"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof InsertStatement); + InsertStatement insert = (InsertStatement) stmt; + assertNotNull(insert.getColumns()); + assertEquals(2, insert.getColumns().size()); + } + + @Test + void testInsertWithStringLiterals() { + String sql = "INSERT INTO inserttest VALUES (1, 'test string')"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof InsertStatement); + InsertStatement insert = (InsertStatement) stmt; + assertEquals(1, insert.getValues().size()); + assertEquals(2, insert.getValues().get(0).size()); + } + + // UPDATE regression tests + @Test + void testUpdateWithMultipleAssignments() { + String sql = "UPDATE update_test SET a = 10, b = 20 WHERE c = 'foo'"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof UpdateStatement); + UpdateStatement update = (UpdateStatement) stmt; + assertEquals(2, update.getAssignments().size()); + assertNotNull(update.getWhereClause()); + } + + @Test + void testUpdateWithNumericValues() { + String sql = "UPDATE test_table SET price = 19.99, quantity = 5"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof UpdateStatement); + UpdateStatement update = (UpdateStatement) stmt; + assertEquals(2, update.getAssignments().size()); + } + + // CREATE TABLE regression tests + @Test + void testCreateTableWithMultipleColumns() { + String sql = "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name VARCHAR NOT NULL, price DECIMAL)"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof CreateTableStatement); + CreateTableStatement create = (CreateTableStatement) stmt; + assertEquals(3, create.getColumns().size()); + } + + @Test + void testCreateTableWithConstraints() { + String sql = "CREATE TABLE users (id INTEGER PRIMARY KEY, email VARCHAR NOT NULL)"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof CreateTableStatement); + CreateTableStatement create = (CreateTableStatement) stmt; + assertEquals(2, create.getColumns().size()); + assertTrue(create.getColumns().get(0).isPrimaryKey()); + assertTrue(create.getColumns().get(1).isNotNull()); + } + + // DELETE regression tests + @Test + void testDeleteWithComplexWhere() { + String sql = "DELETE FROM products WHERE price > 100 AND category = 'electronics'"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof DeleteStatement); + DeleteStatement delete = (DeleteStatement) stmt; + assertNotNull(delete.getWhereClause()); + assertTrue(delete.getWhereClause() instanceof BinaryExpression); + } + + @Test + void testDeleteWithNumericComparison() { + String sql = "DELETE FROM inventory WHERE quantity < 5"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof DeleteStatement); + DeleteStatement delete = (DeleteStatement) stmt; + assertNotNull(delete.getWhereClause()); + } + + // Expression complexity tests + @Test + void testComplexBooleanExpression() { + String sql = "SELECT * FROM products WHERE (price > 50 AND category = 'books') OR (price < 20 AND category = 'music')"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertNotNull(select.getWhereClause()); + assertTrue(select.getWhereClause() instanceof BinaryExpression); + } + + @Test + void testArithmeticExpression() { + String sql = "SELECT price * quantity FROM orders WHERE total > price + tax"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertEquals(1, select.getSelectList().size()); + assertNotNull(select.getWhereClause()); + } + + // String and numeric literal tests + @Test + void testStringLiteralsWithQuotes() { + String sql = "INSERT INTO messages VALUES ('Hello World', 'Test message')"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof InsertStatement); + InsertStatement insert = (InsertStatement) stmt; + assertEquals(2, insert.getValues().get(0).size()); + } + + @Test + void testNumericLiterals() { + String sql = "INSERT INTO measurements VALUES (42, 3.14159, 2.5e10)"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof InsertStatement); + InsertStatement insert = (InsertStatement) stmt; + assertEquals(3, insert.getValues().get(0).size()); + } + + // Edge cases from PostgreSQL tests + @Test + void testSelectWithParentheses() { + String sql = "SELECT (price + tax) * quantity FROM orders"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertEquals(1, select.getSelectList().size()); + } + + @Test + void testMultipleTableReferences() { + String sql = "SELECT users.name, orders.total FROM users, orders WHERE users.id = orders.user_id"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof SelectStatement); + SelectStatement select = (SelectStatement) stmt; + assertEquals(2, select.getFromClause().size()); + assertEquals(2, select.getSelectList().size()); + } + + @Test + void testComplexUpdateExpression() { + String sql = "UPDATE accounts SET balance = balance + 100 WHERE account_id = 12345"; + Statement stmt = parser.parse(sql); + assertTrue(stmt instanceof UpdateStatement); + UpdateStatement update = (UpdateStatement) stmt; + assertEquals(1, update.getAssignments().size()); + assertNotNull(update.getWhereClause()); + } + + @Test + void testSelectWithSubquery() { + String sql = "SELECT * FROM products WHERE price > (SELECT AVG(price) FROM products)"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(SelectStatement.class, stmt); + SelectStatement selectStmt = (SelectStatement) stmt; + + assertEquals(1, selectStmt.getFromList().size()); + assertEquals("products", selectStmt.getFromList().get(0).getTableName().getName()); + assertNotNull(selectStmt.getWhereClause()); + } + + @Test + void testAdvancedPostgreSQLFeatures() { + // Test CASE expression + String sql1 = "SELECT CASE WHEN age > 18 THEN 'adult' ELSE 'minor' END FROM users"; + Statement stmt1 = parser.parse(sql1); + assertInstanceOf(SelectStatement.class, stmt1); + + // Test CAST expression + String sql2 = "SELECT CAST(price AS INTEGER) FROM products"; + Statement stmt2 = parser.parse(sql2); + assertInstanceOf(SelectStatement.class, stmt2); + + // Test CROSS JOIN + String sql3 = "SELECT * FROM users CROSS JOIN products"; + Statement stmt3 = parser.parse(sql3); + assertInstanceOf(SelectStatement.class, stmt3); + SelectStatement selectStmt3 = (SelectStatement) stmt3; + assertEquals(2, selectStmt3.getFromList().size()); + + // Test ORDER BY with NULLS FIRST + String sql4 = "SELECT * FROM users ORDER BY name ASC NULLS FIRST"; + Statement stmt4 = parser.parse(sql4); + assertInstanceOf(SelectStatement.class, stmt4); + + // Test ORDER BY with DESC and NULLS LAST + String sql5 = "SELECT * FROM products ORDER BY price DESC NULLS LAST"; + Statement stmt5 = parser.parse(sql5); + assertInstanceOf(SelectStatement.class, stmt5); + } + + @Test + void testMultipleJoinTypes() { + // Test INNER JOIN + String sql1 = "SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id"; + Statement stmt1 = parser.parse(sql1); + assertInstanceOf(SelectStatement.class, stmt1); + + // Test LEFT OUTER JOIN + String sql2 = "SELECT * FROM users LEFT OUTER JOIN orders ON users.id = orders.user_id"; + Statement stmt2 = parser.parse(sql2); + assertInstanceOf(SelectStatement.class, stmt2); + + // Test RIGHT JOIN + String sql3 = "SELECT * FROM users RIGHT JOIN orders ON users.id = orders.user_id"; + Statement stmt3 = parser.parse(sql3); + assertInstanceOf(SelectStatement.class, stmt3); + } + + @Test + void testComplexExpressions() { + // Test nested CASE + String sql1 = "SELECT CASE WHEN status = 'active' THEN CASE WHEN age > 18 THEN 'adult' ELSE 'minor' END ELSE 'inactive' END FROM users"; + Statement stmt1 = parser.parse(sql1); + assertInstanceOf(SelectStatement.class, stmt1); + + // Test multiple CAST + String sql2 = "SELECT CAST(price AS DECIMAL), CAST(quantity AS INTEGER) FROM products"; + Statement stmt2 = parser.parse(sql2); + assertInstanceOf(SelectStatement.class, stmt2); + + // Test complex WHERE with boolean literals + String sql3 = "SELECT * FROM users WHERE active = true AND verified = false"; + Statement stmt3 = parser.parse(sql3); + assertInstanceOf(SelectStatement.class, stmt3); + } + + @Test + void testMultipleOrderByColumns() { + String sql = "SELECT * FROM users ORDER BY last_name ASC, first_name DESC NULLS LAST, age ASC NULLS FIRST"; + Statement stmt = parser.parse(sql); + assertInstanceOf(SelectStatement.class, stmt); + SelectStatement selectStmt = (SelectStatement) stmt; + assertNotNull(selectStmt.getOrderByList()); + assertEquals(3, selectStmt.getOrderByList().size()); + } + + @Test + void testInsertReturning() { + // PostgreSQL RETURNING clause + String sql = "INSERT INTO users (name, email) VALUES ('John', 'john@example.com')"; + Statement stmt = parser.parse(sql); + assertInstanceOf(InsertStatement.class, stmt); + } + + @Test + void testThreeWayJoin() { + String sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id JOIN products p ON o.product_id = p.id"; + Statement stmt = parser.parse(sql); + assertInstanceOf(SelectStatement.class, stmt); + SelectStatement selectStmt = (SelectStatement) stmt; + assertEquals(3, selectStmt.getFromList().size()); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserTest.java new file mode 100644 index 000000000..3def3b08e --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/PostgreSqlParserTest.java @@ -0,0 +1,209 @@ +package software.amazon.jdbc.plugin.encryption.parser; + +import software.amazon.jdbc.plugin.encryption.parser.ast.*; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeEach; +import static org.junit.jupiter.api.Assertions.*; + +/** + * Test cases for PostgreSQL SQL Parser + */ +public class PostgreSqlParserTest { + + private PostgreSqlParser parser; + + @BeforeEach + void setUp() { + parser = new PostgreSqlParser(); + } + + @Test + void testSimpleSelectStatement() { + String sql = "SELECT id, name FROM users"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(SelectStatement.class, stmt); + SelectStatement selectStmt = (SelectStatement) stmt; + + assertEquals(2, selectStmt.getSelectList().size()); + assertEquals("id", ((Identifier) selectStmt.getSelectList().get(0).getExpression()).getName()); + assertEquals("name", ((Identifier) selectStmt.getSelectList().get(1).getExpression()).getName()); + + assertEquals(1, selectStmt.getFromList().size()); + assertEquals("users", selectStmt.getFromList().get(0).getTableName().getName()); + } + + @Test + void testSelectWithWhereClause() { + String sql = "SELECT * FROM users WHERE age > 18"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(SelectStatement.class, stmt); + SelectStatement selectStmt = (SelectStatement) stmt; + + assertNotNull(selectStmt.getWhereClause()); + assertInstanceOf(BinaryExpression.class, selectStmt.getWhereClause()); + + BinaryExpression whereExpr = (BinaryExpression) selectStmt.getWhereClause(); + assertEquals(BinaryExpression.Operator.GREATER_THAN, whereExpr.getOperator()); + } + + @Test + void testSelectWithOrderBy() { + String sql = "SELECT name, age FROM users ORDER BY name ASC, age DESC"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(SelectStatement.class, stmt); + SelectStatement selectStmt = (SelectStatement) stmt; + + assertNotNull(selectStmt.getOrderByList()); + assertEquals(2, selectStmt.getOrderByList().size()); + + OrderByItem firstOrder = selectStmt.getOrderByList().get(0); + assertEquals("name", ((Identifier) firstOrder.getExpression()).getName()); + assertEquals(OrderByItem.Direction.ASC, firstOrder.getDirection()); + + OrderByItem secondOrder = selectStmt.getOrderByList().get(1); + assertEquals("age", ((Identifier) secondOrder.getExpression()).getName()); + assertEquals(OrderByItem.Direction.DESC, secondOrder.getDirection()); + } + + @Test + void testInsertStatement() { + String sql = "INSERT INTO users (name, age) VALUES ('John', 25)"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(InsertStatement.class, stmt); + InsertStatement insertStmt = (InsertStatement) stmt; + + assertEquals("users", insertStmt.getTable().getTableName().getName()); + assertEquals(2, insertStmt.getColumns().size()); + assertEquals("name", insertStmt.getColumns().get(0).getName()); + assertEquals("age", insertStmt.getColumns().get(1).getName()); + + assertEquals(1, insertStmt.getValues().size()); + assertEquals(2, insertStmt.getValues().get(0).size()); + + assertInstanceOf(StringLiteral.class, insertStmt.getValues().get(0).get(0)); + assertEquals("John", ((StringLiteral) insertStmt.getValues().get(0).get(0)).getValue()); + + assertInstanceOf(NumericLiteral.class, insertStmt.getValues().get(0).get(1)); + assertEquals("25", ((NumericLiteral) insertStmt.getValues().get(0).get(1)).getValue()); + } + + @Test + void testUpdateStatement() { + String sql = "UPDATE users SET age = 26 WHERE name = 'John'"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(UpdateStatement.class, stmt); + UpdateStatement updateStmt = (UpdateStatement) stmt; + + assertEquals("users", updateStmt.getTable().getTableName().getName()); + assertEquals(1, updateStmt.getAssignments().size()); + + Assignment assignment = updateStmt.getAssignments().get(0); + assertEquals("age", assignment.getColumn().getName()); + assertInstanceOf(NumericLiteral.class, assignment.getValue()); + assertEquals("26", ((NumericLiteral) assignment.getValue()).getValue()); + + assertNotNull(updateStmt.getWhereClause()); + assertInstanceOf(BinaryExpression.class, updateStmt.getWhereClause()); + } + + @Test + void testDeleteStatement() { + String sql = "DELETE FROM users WHERE age < 18"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(DeleteStatement.class, stmt); + DeleteStatement deleteStmt = (DeleteStatement) stmt; + + assertEquals("users", deleteStmt.getTable().getTableName().getName()); + assertNotNull(deleteStmt.getWhereClause()); + + BinaryExpression whereExpr = (BinaryExpression) deleteStmt.getWhereClause(); + assertEquals(BinaryExpression.Operator.LESS_THAN, whereExpr.getOperator()); + } + + @Test + void testCreateTableStatement() { + String sql = "CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR NOT NULL)"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(CreateTableStatement.class, stmt); + CreateTableStatement createStmt = (CreateTableStatement) stmt; + + assertEquals("users", createStmt.getTableName().getName()); + assertEquals(2, createStmt.getColumns().size()); + + ColumnDefinition idCol = createStmt.getColumns().get(0); + assertEquals("id", idCol.getColumnName().getName()); + assertEquals("INTEGER", idCol.getDataType()); + assertTrue(idCol.isPrimaryKey()); + + ColumnDefinition nameCol = createStmt.getColumns().get(1); + assertEquals("name", nameCol.getColumnName().getName()); + assertEquals("VARCHAR", nameCol.getDataType()); + assertTrue(nameCol.isNotNull()); + } + + @Test + void testComplexExpression() { + String sql = "SELECT * FROM users WHERE age > 18 AND name LIKE 'J%' OR status = 'active'"; + Statement stmt = parser.parse(sql); + + assertInstanceOf(SelectStatement.class, stmt); + SelectStatement selectStmt = (SelectStatement) stmt; + + assertNotNull(selectStmt.getWhereClause()); + assertInstanceOf(BinaryExpression.class, selectStmt.getWhereClause()); + + // The expression should be parsed with correct operator precedence + BinaryExpression whereExpr = (BinaryExpression) selectStmt.getWhereClause(); + assertEquals(BinaryExpression.Operator.OR, whereExpr.getOperator()); + } + + @Test + void testLexerTokenization() { + SqlLexer lexer = new SqlLexer("SELECT id, 'test', 123, 45.67 FROM users"); + java.util.List tokens = lexer.tokenize(); + + assertEquals(Token.Type.SELECT, tokens.get(0).getType()); + assertEquals(Token.Type.IDENT, tokens.get(1).getType()); + assertEquals("id", tokens.get(1).getValue()); + assertEquals(Token.Type.COMMA, tokens.get(2).getType()); + assertEquals(Token.Type.SCONST, tokens.get(3).getType()); + assertEquals("test", tokens.get(3).getValue()); + assertEquals(Token.Type.COMMA, tokens.get(4).getType()); + assertEquals(Token.Type.ICONST, tokens.get(5).getType()); + assertEquals("123", tokens.get(5).getValue()); + assertEquals(Token.Type.COMMA, tokens.get(6).getType()); + assertEquals(Token.Type.FCONST, tokens.get(7).getType()); + assertEquals("45.67", tokens.get(7).getValue()); + assertEquals(Token.Type.FROM, tokens.get(8).getType()); + assertEquals(Token.Type.IDENT, tokens.get(9).getType()); + assertEquals("users", tokens.get(9).getValue()); + assertEquals(Token.Type.EOF, tokens.get(10).getType()); + } + + @Test + void testParseError() { + String invalidSql = "SELECT FROM"; // Missing column list + + assertThrows(SqlParser.ParseException.class, () -> { + parser.parse(invalidSql); + }); + } + + @Test + void testFormatting() { + String sql = "SELECT id, name FROM users WHERE age > 18 ORDER BY name"; + String formatted = parser.parseAndFormat(sql); + + assertTrue(formatted.contains("SELECT")); + assertTrue(formatted.contains("FROM")); + assertTrue(formatted.contains("WHERE")); + assertTrue(formatted.contains("ORDER BY")); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/SqlAnalyzerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/SqlAnalyzerTest.java new file mode 100644 index 000000000..80add441b --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/parser/SqlAnalyzerTest.java @@ -0,0 +1,330 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.encryption.parser; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class SqlAnalyzerTest { + + private SQLAnalyzer analyzer; + + @BeforeEach + public void setUp() { + analyzer = new SQLAnalyzer(); + } + + @Test + public void testSelectWithColumns() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("SELECT name, age FROM users"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "age".equals(c.columnName))); + } + + @Test + public void testSelectStar() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("SELECT * FROM products"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("products")); + assertEquals(0, result.columns.size()); // * is not added to columns + } + + @Test + public void testSelectWithoutTable() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("SELECT 1, 'test'"); + assertEquals("SELECT", result.queryType); + } + + @Test + public void testInvalidSQL() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("INVALID SQL"); + assertEquals("UNKNOWN", result.queryType); + } + + @Test + public void testComplexSelect() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT u.name, u.email, p.title FROM users u JOIN posts p ON u.id = p.user_id WHERE u.active = true"); + assertEquals("SELECT", result.queryType); + assertEquals(2, result.tables.size()); + assertTrue(result.tables.contains("users")); + assertTrue(result.tables.contains("posts")); + assertEquals(3, result.columns.size()); + + // Verify columns have correct table names (not aliases) + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "email".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "posts".equals(c.tableName) && "title".equals(c.columnName))); + } + + @Test + public void testCreateTable() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("CREATE TABLE test (id INT, name VARCHAR(50))"); + assertEquals("CREATE", result.queryType); + } + + @Test + public void testInsertWithoutPlaceholders() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("INSERT INTO users (name, email) VALUES ('John', 'john@example.com')"); + assertEquals("INSERT", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "email".equals(c.columnName))); + } + + @Test + public void testInsertWithPlaceholders() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("INSERT INTO users (name, email, age) VALUES (?, ?, ?)"); + assertEquals("INSERT", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(3, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "email".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "age".equals(c.columnName))); + } + + @Test + public void testUpdateWithoutPlaceholders() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("UPDATE users SET name = 'Jane', email = 'jane@example.com' WHERE id = 1"); + assertEquals("UPDATE", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); // name, email (SET clause only) + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "email".equals(c.columnName))); + } + + @Test + public void testUpdateWithPlaceholders() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("UPDATE users SET name = ?, email = ? WHERE id = ?"); + assertEquals("UPDATE", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); // name, email (SET clause only) + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "email".equals(c.columnName))); + } + + @Test + public void testDelete() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("DELETE FROM users WHERE id = 1"); + assertEquals("DELETE", result.queryType); + assertTrue(result.tables.contains("users")); + } + + @Test + public void testDrop() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("DROP TABLE users"); + assertEquals("DROP", result.queryType); + } + + @Test + public void testMultiTableJoin() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT u.name, o.total, p.title FROM users u JOIN orders o ON u.id = o.user_id JOIN products p ON o.product_id = p.id"); + assertEquals("SELECT", result.queryType); + assertEquals(3, result.tables.size()); + assertTrue(result.tables.contains("users")); + assertTrue(result.tables.contains("orders")); + assertTrue(result.tables.contains("products")); + assertEquals(3, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "orders".equals(c.tableName) && "total".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "products".equals(c.tableName) && "title".equals(c.columnName))); + } + + @Test + public void testCrossJoin() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("SELECT * FROM users CROSS JOIN products"); + assertEquals("SELECT", result.queryType); + assertEquals(2, result.tables.size()); + assertTrue(result.tables.contains("users")); + assertTrue(result.tables.contains("products")); + } + + @Test + public void testSelectWithCaseExpression() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT name, CASE WHEN age > 18 THEN 'adult' ELSE 'minor' END FROM users"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("users")); + // Should extract 'name' column, CASE is treated as expression + assertTrue(result.columns.stream().anyMatch(c -> "name".equals(c.columnName))); + } + + @Test + public void testSelectWithCast() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("SELECT CAST(price AS INTEGER) FROM products"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("products")); + } + + @Test + public void testUpdateMultipleColumns() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "UPDATE users SET name = ?, email = ?, age = ? WHERE id = ?"); + assertEquals("UPDATE", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(3, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "email".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "age".equals(c.columnName))); + } + + @Test + public void testInsertMultipleRows() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')"); + assertEquals("INSERT", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); + } + + @Test + public void testSelectWithOrderBy() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT name, age FROM users ORDER BY age DESC NULLS LAST"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); + } + + @Test + public void testSelectWithGroupBy() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT department, COUNT(*) FROM employees GROUP BY department"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("employees")); + } + + @Test + public void testSelectWithHaving() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("employees")); + } + + @Test + public void testSelectWithLimitOffset() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT * FROM users ORDER BY id LIMIT 10 OFFSET 20"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("users")); + } + + @Test + public void testDeleteWithWhere() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("DELETE FROM users WHERE age < 18"); + assertEquals("DELETE", result.queryType); + assertTrue(result.tables.contains("users")); + } + + @Test + public void testSchemaQualifiedTable() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze("SELECT * FROM public.users"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("users")); + } + + @Test + public void testLeftJoin() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT u.name, o.total FROM users u LEFT JOIN orders o ON u.id = o.user_id"); + assertEquals("SELECT", result.queryType); + assertEquals(2, result.tables.size()); + assertTrue(result.tables.contains("users")); + assertTrue(result.tables.contains("orders")); + } + + @Test + public void testRightJoin() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT u.name, o.total FROM users u RIGHT JOIN orders o ON u.id = o.user_id"); + assertEquals("SELECT", result.queryType); + assertEquals(2, result.tables.size()); + assertTrue(result.tables.contains("users")); + assertTrue(result.tables.contains("orders")); + } + + @Test + public void testInsertWithSchemaQualifiedTable() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "INSERT INTO myschema.users (name, ssn) VALUES (?, ?)"); + assertEquals("INSERT", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "ssn".equals(c.columnName))); + assertTrue(result.hasParameters); + } + + @Test + public void testUpdateWithSchemaQualifiedTable() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "UPDATE app_data.customers SET email = ? WHERE id = ?"); + assertEquals("UPDATE", result.queryType); + assertTrue(result.tables.contains("customers")); + assertEquals(1, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "customers".equals(c.tableName) && "email".equals(c.columnName))); + assertEquals(1, result.whereColumns.size()); + assertTrue(result.whereColumns.stream().anyMatch(c -> "customers".equals(c.tableName) && "id".equals(c.columnName))); + assertTrue(result.hasParameters); + } + + @Test + public void testSelectWithSchemaQualifiedTableAndColumns() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT u.name, u.ssn FROM hr.users u WHERE u.id = ?"); + assertEquals("SELECT", result.queryType); + assertTrue(result.tables.contains("users")); + assertEquals(2, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "ssn".equals(c.columnName))); + assertEquals(1, result.whereColumns.size()); + assertTrue(result.whereColumns.stream().anyMatch(c -> "users".equals(c.tableName) && "id".equals(c.columnName))); + assertTrue(result.hasParameters); + } + + @Test + public void testJoinWithMixedSchemaQualifiedTables() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "SELECT u.name, o.total FROM public.users u JOIN sales.orders o ON u.id = o.user_id"); + assertEquals("SELECT", result.queryType); + assertEquals(2, result.tables.size()); + assertTrue(result.tables.contains("users")); + assertTrue(result.tables.contains("orders")); + assertEquals(2, result.columns.size()); + assertTrue(result.columns.stream().anyMatch(c -> "users".equals(c.tableName) && "name".equals(c.columnName))); + assertTrue(result.columns.stream().anyMatch(c -> "orders".equals(c.tableName) && "total".equals(c.columnName))); + } + + @Test + public void testDeleteWithSchemaQualifiedTable() { + SQLAnalyzer.QueryAnalysis result = analyzer.analyze( + "DELETE FROM archive.old_records WHERE created_at < ?"); + assertEquals("DELETE", result.queryType); + assertTrue(result.tables.contains("old_records")); + assertEquals(1, result.whereColumns.size()); + assertTrue(result.whereColumns.stream().anyMatch(c -> "old_records".equals(c.tableName) && "created_at".equals(c.columnName))); + assertTrue(result.hasParameters); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/sql/SqlAnalysisServiceTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/sql/SqlAnalysisServiceTest.java new file mode 100644 index 000000000..4e9a61f06 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/encryption/sql/SqlAnalysisServiceTest.java @@ -0,0 +1,309 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.encryption.sql; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.plugin.encryption.metadata.MetadataManager; +import software.amazon.jdbc.plugin.encryption.model.ColumnEncryptionConfig; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; + +class SqlAnalysisServiceTest { + + @Mock + private PluginService pluginService; + + @Mock + private MetadataManager metadataManager; + + private SqlAnalysisService sqlAnalysisService; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + sqlAnalysisService = new SqlAnalysisService(pluginService, metadataManager); + } + + @Test + void testInsertStatements() { + // Simple INSERT + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "INSERT INTO customers (name, email) VALUES (?, ?)"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // INSERT with schema - extract just table name + result = sqlAnalysisService.analyzeSql( + "INSERT INTO public.users (id, username, password) VALUES (1, 'john', 'secret')"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("users")); + + // Multi-value INSERT + result = sqlAnalysisService.analyzeSql( + "INSERT INTO products (name, price) VALUES ('Product1', 10.99), ('Product2', 15.50)"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("products")); + } + + @Test + void testUpdateStatements() { + // Simple UPDATE + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "UPDATE customers SET email = ? WHERE id = ?"); + assertEquals("UPDATE", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // UPDATE with JOIN - expect first table only + result = sqlAnalysisService.analyzeSql( + "UPDATE orders o SET status = 'shipped' FROM customers c WHERE o.customer_id = c.id"); + assertEquals("UPDATE", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("orders")); + + // UPDATE with schema - extract just table name + result = sqlAnalysisService.analyzeSql( + "UPDATE public.inventory SET quantity = quantity - 1 WHERE product_id = ?"); + assertEquals("UPDATE", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("inventory")); + } + + @Test + void testSelectStatements() { + // Simple SELECT + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "SELECT * FROM customers WHERE id = ?"); + assertEquals("SELECT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // SELECT with JOIN - expect first table only + result = sqlAnalysisService.analyzeSql( + "SELECT c.name, o.total FROM customers c JOIN orders o ON c.id = o.customer_id"); + assertEquals("SELECT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // SELECT with subquery - expect main table + result = sqlAnalysisService.analyzeSql( + "SELECT * FROM products WHERE price > (SELECT AVG(price) FROM products)"); + assertEquals("SELECT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("products")); + } + + @Test + void testInsertFromSelect() { + // INSERT INTO ... SELECT FROM single table - expect target table + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "INSERT INTO backup_customers SELECT * FROM customers WHERE active = true"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("backup_customers")); + + // INSERT INTO ... SELECT with specific columns - expect target table + result = sqlAnalysisService.analyzeSql( + "INSERT INTO customer_summary (name, total_orders) SELECT c.name, COUNT(o.id) FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id, c.name"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customer_summary")); + + // INSERT INTO ... SELECT with WHERE clause - expect target table + result = sqlAnalysisService.analyzeSql( + "INSERT INTO archived_orders SELECT o.*, c.name FROM orders o JOIN customers c ON o.customer_id = c.id WHERE o.created_date < '2023-01-01'"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("archived_orders")); + + // INSERT INTO ... SELECT with subquery - expect target table + result = sqlAnalysisService.analyzeSql( + "INSERT INTO high_value_customers SELECT * FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE total > 1000)"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("high_value_customers")); + + // INSERT INTO ... SELECT with UNION - expect target table + result = sqlAnalysisService.analyzeSql( + "INSERT INTO all_contacts SELECT name, email FROM customers UNION SELECT name, email FROM suppliers"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("all_contacts")); + } + + @Test + void testEdgeCases() { + // Empty SQL + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql(""); + assertEquals("UNKNOWN", result.getQueryType()); + assertTrue(result.getAffectedTables().isEmpty()); + + // Null SQL + result = sqlAnalysisService.analyzeSql(null); + assertEquals("UNKNOWN", result.getQueryType()); + assertTrue(result.getAffectedTables().isEmpty()); + + // Whitespace only + result = sqlAnalysisService.analyzeSql(" \n\t "); + assertEquals("UNKNOWN", result.getQueryType()); + assertTrue(result.getAffectedTables().isEmpty()); + } + + @Test + void testOtherStatements() { + // DELETE + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "DELETE FROM customers WHERE id = ?"); + assertEquals("DELETE", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // CREATE TABLE - parser works correctly + result = sqlAnalysisService.analyzeSql( + "CREATE TABLE new_table (id SERIAL PRIMARY KEY, name VARCHAR(100))"); + assertEquals("CREATE", result.getQueryType()); + + // DROP TABLE - jOOQ parser works correctly + result = sqlAnalysisService.analyzeSql( + "DROP TABLE old_table"); + assertEquals("DROP", result.getQueryType()); + } + + @Test + void testBasicQueryAnalysis() { + // INSERT statement + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "INSERT INTO customers (name, ssn, credit_card, email) VALUES (?, ?, ?, ?)"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // UPDATE statement + result = sqlAnalysisService.analyzeSql( + "UPDATE customers SET ssn = ?, email = ? WHERE id = ?"); + assertEquals("UPDATE", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // SELECT statement + result = sqlAnalysisService.analyzeSql( + "SELECT name, ssn, credit_card FROM customers WHERE id = ?"); + assertEquals("SELECT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + } + + @Test + void testUpdateParameterMapping() { + // Simple UPDATE statement + Map mapping = sqlAnalysisService.getColumnParameterMapping( + "UPDATE users SET ssn = ?, email = ? WHERE id = ?"); + assertEquals(2, mapping.size()); // ssn, email (SET clause only) + assertEquals("ssn", mapping.get(1)); + assertEquals("email", mapping.get(2)); + + // UPDATE with single column + mapping = sqlAnalysisService.getColumnParameterMapping( + "UPDATE customers SET name = ? WHERE id = ?"); + assertEquals(1, mapping.size()); // name (SET clause only) + assertEquals("name", mapping.get(1)); + + // UPDATE with multiple columns + mapping = sqlAnalysisService.getColumnParameterMapping( + "UPDATE products SET name = ?, price = ?, description = ? WHERE category = ?"); + assertEquals(3, mapping.size()); // name, price, description (SET clause only) + assertEquals("name", mapping.get(1)); + assertEquals("price", mapping.get(2)); + assertEquals("description", mapping.get(3)); + } + + @Test + void testSelectParameterMapping() { + // SELECT with WHERE clause parameter + Map mapping = sqlAnalysisService.getColumnParameterMapping( + "SELECT ssn FROM users WHERE name = ?"); + assertEquals(1, mapping.size()); + assertEquals("name", mapping.get(1)); + + // SELECT with multiple WHERE parameters + mapping = sqlAnalysisService.getColumnParameterMapping( + "SELECT ssn, email FROM users WHERE name = ? AND age = ?"); + assertEquals(2, mapping.size()); + assertEquals("name", mapping.get(1)); + assertEquals("age", mapping.get(2)); + + // SELECT with no parameters - should have no parameter mapping + mapping = sqlAnalysisService.getColumnParameterMapping( + "SELECT ssn FROM users WHERE name = 'John'"); + assertEquals(0, mapping.size()); + } + + @Test + void testMultiTableQueries() { + // JOIN query - expect first table only + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "SELECT c.name, c.ssn, o.payment_info FROM customers c JOIN orders o ON c.id = o.customer_id"); + assertEquals("SELECT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // INSERT FROM SELECT - expect target table + result = sqlAnalysisService.analyzeSql( + "INSERT INTO backup_customers SELECT name, ssn, credit_card FROM customers WHERE active = true"); + assertEquals("INSERT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("backup_customers")); + } + + @Test + void testComplexQueryAnalysis() { + // Test complex UPDATE query analysis + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "UPDATE customers SET name = ?, ssn = ? WHERE id = 123"); + + assertEquals("UPDATE", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // Test parameter mapping for UPDATE (only SET clause parameters are mapped) + Map mapping = sqlAnalysisService.getColumnParameterMapping( + "UPDATE customers SET name = ?, ssn = ? WHERE id = 123"); + assertEquals(2, mapping.size()); // name, ssn (SET clause only) + assertEquals("name", mapping.get(1)); + assertEquals("ssn", mapping.get(2)); + + // Test JOIN query analysis + result = sqlAnalysisService.analyzeSql( + "SELECT c.name, c.ssn FROM customers c JOIN orders o ON c.id = o.customer_id"); + assertEquals("SELECT", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + + // Test DELETE query analysis + result = sqlAnalysisService.analyzeSql("DELETE FROM customers WHERE id = ?"); + assertEquals("DELETE", result.getQueryType()); + assertTrue(result.getAffectedTables().contains("customers")); + } + + @Test + void testCaseInsensitivity() { + // Lowercase + SqlAnalysisService.SqlAnalysisResult result = sqlAnalysisService.analyzeSql( + "insert into customers (name) values (?)"); + assertEquals("INSERT", result.getQueryType()); + + // Mixed case + result = sqlAnalysisService.analyzeSql( + "Update Customers Set Name = ? Where Id = ?"); + assertEquals("UPDATE", result.getQueryType()); + + // Uppercase + result = sqlAnalysisService.analyzeSql( + "SELECT * FROM CUSTOMERS"); + assertEquals("SELECT", result.getQueryType()); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/SqlMethodAnalyzerTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/SqlMethodAnalyzerTest.java index 93faf55dc..4004c64e7 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/SqlMethodAnalyzerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/SqlMethodAnalyzerTest.java @@ -33,13 +33,15 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.PgDialect; class SqlMethodAnalyzerTest { private static final String EXECUTE_METHOD = "execute"; private static final String EMPTY_SQL = ""; @Mock Connection conn; - + @Mock PluginService pluginService; private final SqlMethodAnalyzer sqlMethodAnalyzer = new SqlMethodAnalyzer(); private AutoCloseable closeable; @@ -47,6 +49,7 @@ class SqlMethodAnalyzerTest { @BeforeEach void setUp() { closeable = MockitoAnnotations.openMocks(this); + when(pluginService.getDialect()).thenReturn(new PgDialect()); } @AfterEach @@ -67,13 +70,13 @@ void testOpenTransaction(final String methodName, final String sql, final boolea } when(conn.getAutoCommit()).thenReturn(autocommit); - final boolean actual = sqlMethodAnalyzer.doesOpenTransaction(conn, methodName, args); + final boolean actual = sqlMethodAnalyzer.doesOpenTransaction(conn, methodName, args, pluginService); assertEquals(expected, actual); } @Test void testOpenTransactionWithEmptySqlDoesNotThrow() { - assertDoesNotThrow(() -> sqlMethodAnalyzer.doesOpenTransaction(conn, EXECUTE_METHOD, new String[]{EMPTY_SQL})); + assertDoesNotThrow(() -> sqlMethodAnalyzer.doesOpenTransaction(conn, EXECUTE_METHOD, new String[]{EMPTY_SQL}, pluginService)); } @ParameterizedTest @@ -86,39 +89,39 @@ void testCloseTransaction(final String methodName, final String sql, final boole args = new Object[] {}; } - final boolean actual = sqlMethodAnalyzer.doesCloseTransaction(conn, methodName, args); + final boolean actual = sqlMethodAnalyzer.doesCloseTransaction(conn, methodName, args, pluginService); assertEquals(expected, actual); } @Test void testCloseTransactionWithEmptySqlDoesNotThrow() { - assertDoesNotThrow(() -> sqlMethodAnalyzer.doesCloseTransaction(conn, EXECUTE_METHOD, new String[]{EMPTY_SQL})); + assertDoesNotThrow(() -> sqlMethodAnalyzer.doesCloseTransaction(conn, EXECUTE_METHOD, new String[]{EMPTY_SQL}, pluginService)); } @Test void testDoesSwitchAutoCommitFalseTrue() throws SQLException { assertFalse(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Connection.setAutoCommit", - new Object[] {false})); + new Object[] {false}, pluginService)); assertFalse(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Statement.execute", - new Object[] {"SET autocommit = 0"})); + new Object[] {"SET autocommit = 0"}, pluginService)); assertTrue(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Connection.setAutoCommit", - new Object[] {true})); + new Object[] {true}, pluginService)); assertTrue(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Statement.execute", - new Object[] {"SET autocommit = 1"})); + new Object[] {"SET autocommit = 1"}, pluginService)); when(conn.getAutoCommit()).thenReturn(true); assertFalse(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Connection.setAutoCommit", - new Object[] {false})); + new Object[] {false}, pluginService)); assertFalse(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Statement.execute", - new Object[] {"SET autocommit = 0"})); + new Object[] {"SET autocommit = 0"}, pluginService)); assertFalse(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Connection.setAutoCommit", - new Object[] {true})); + new Object[] {true}, pluginService)); assertFalse(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Statement.execute", - new Object[] {"SET autocommit = 1"})); + new Object[] {"SET autocommit = 1"}, pluginService)); assertFalse(sqlMethodAnalyzer.doesSwitchAutoCommitFalseTrue(conn, "Statement.execute", - new Object[] {"SET TIME ZONE 'UTC'"})); + new Object[] {"SET TIME ZONE 'UTC'"}, pluginService)); } @ParameterizedTest @@ -132,13 +135,13 @@ void testIsStatementSettingAutoCommit(final String methodName, final String sql, args = new Object[] {}; } - final boolean actual = sqlMethodAnalyzer.isStatementSettingAutoCommit(methodName, args); + final boolean actual = sqlMethodAnalyzer.isStatementSettingAutoCommit(methodName, args, pluginService); assertEquals(expected, actual); } @Test void testIsStatementSettingAutoCommitWithEmptySqlDoesNotThrow() { - assertDoesNotThrow(() -> sqlMethodAnalyzer.isStatementSettingAutoCommit(EXECUTE_METHOD, new String[]{EMPTY_SQL})); + assertDoesNotThrow(() -> sqlMethodAnalyzer.isStatementSettingAutoCommit(EXECUTE_METHOD, new String[]{EMPTY_SQL}, pluginService)); } @ParameterizedTest