diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..83ca4e9c7 --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +DB_CONNECTION_STRING=jdbc:aws-wrapper:postgresql://localhost:5432/dbname +CACHE_RW_SERVER_ADDR=localhost:6379 +CACHE_RO_SERVER_ADDR=localhost:6380 +DB_USERNAME=postgres +DB_PASSWORD=admin diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index 359765fe5..09c2809ae 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -25,6 +25,10 @@ dependencies { implementation("org.mariadb.jdbc:mariadb-java-client:3.5.6") implementation("com.zaxxer:HikariCP:4.0.3") implementation("org.checkerframework:checker-qual:3.49.5") + implementation("io.lettuce:lettuce-core:6.6.0.RELEASE") + implementation("org.apache.commons:commons-pool2:2.11.1") + annotationProcessor("org.openjdk.jmh:jmh-core:1.36") + jmhAnnotationProcessor ("org.openjdk.jmh:jmh-generator-annprocess:1.36") testImplementation("org.junit.jupiter:junit-jupiter-api:5.12.2") testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8 diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java new file mode 100644 index 000000000..8c18d4c44 --- /dev/null +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java @@ -0,0 +1,144 @@ +package software.amazon.jdbc.benchmarks; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import java.sql.*; +import java.util.*; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.profile.GCProfiler; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** + * Performance benchmark program against PG. + * + * This test program runs JMH benchmark tests the performance of the remote cache plugin against a + * a remote PG database and a remote cache server for both indexed queries and non-indexed queries. + * + * The database table schema is as follows: + * + * postgres=# CREATE TABLE test (id SERIAL PRIMARY KEY, int_col INTEGER, varchar_col varchar(50) NOT NULL, text_col TEXT, + * num_col DOUBLE PRECISION, date_col date, time_col TIME WITHOUT TIME ZONE, time_tz TIME WITH TIME ZONE, + * ts_col TIMESTAMP WITHOUT TIME ZONE, ts_tz TIMESTAMP WITH TIME ZONE, description TEXT); + * CREATE TABLE + * postgres=# select * from test; + * id | int_col | varchar_col | text_col | num_col | date_col | time_col | time_tz | ts_col | ts_tz | description + * ----+---------+-------------+----------+---------+----------+----------+---------+--------+-------+-------------- + * (0 rows) + * + */ +@State(Scope.Thread) +@Fork(1) +@Warmup(iterations = 1) +@Measurement(iterations = 60, time = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class PgCacheBenchmarks { + private static final String DB_CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-0.XYZ.us-east-2.rds.amazonaws.com:5432/postgres"; + private static final String CACHE_RW_SERVER_ADDR = "cache-0.XYZ.us-east-2.rds.amazonaws.com:6379"; + private static final String CACHE_RO_SERVER_ADDR = "cache-0.XYZ.us-east-2.rds.amazonaws.com:6380"; + + private Connection connection; + private int counter; + long startTime; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(PgCacheBenchmarks.class.getSimpleName()) + .addProfiler(GCProfiler.class) + .detectJvmArgs() + .build(); + + new Runner(opt).run(); + } + + @Setup(Level.Trial) + public void setup() throws SQLException { + try { + software.amazon.jdbc.Driver.register(); + } catch (IllegalStateException e) { + System.out.println("exception during register() is " + e.getMessage()); + } + Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "dataRemoteCache"); + properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); + properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + properties.setProperty("wrapperLogUnclosedConnections", "true"); + counter = 0; + connection = DriverManager.getConnection(DB_CONNECTION_STRING, properties); + startTime = System.currentTimeMillis(); + } + + @TearDown(Level.Trial) + public void tearDown() throws SQLException { + connection.close(); + } + + // Code to warm up the data in the table + public void warmUpDataSet() throws SQLException { + String desc_1KB = "mP48pHrR5vreBo3N6ecmlDgvfEAz0kQEOUQ89U3Rh05BTG9LhB8R0HBFBp5RIqc8vVcrphu89kW1OE2c2xApwpczFMdDAuk2SxOl9OrLvfk9zGYrdfzedcepT8LVeE6NTtYDeP3yo6UFC6AiOeqRBY5NEaNcZ8fuoXVpqOrqAhz910v5XrFxeXUyPDFxuaKFLaHfEFq7BRasUc9nfhP8gblKAGfEEmgYBpUKio27Rfo0xnavfVJQkAA2kME2PT4qZRSqeDkLmn7VBAzT9ghHqe9D4kQLQKjIyIPKqYoS8kW3ShW44VqYENwPSRAXw7UqOJqlKJ4pnmx4sPZO2kI4NYOl1JZXNlbGaSzJR0cOloKiY0z2OmUNvmD0Wju1DC9TT4OY6a6DOfFvk265BfDVxT6ufN68YG9sZuVsl7jq8SZSJg3x2cqlJuAtdSTIoKmJT1a6cEXxVusmdO27kRRp1BfWR4gz4w9HawYf9nBQOq76ObctlNvj0fYUUG3I49s3iP33CL8qZjj9RnyNUus6ieiZgta6L3mZuMRYOgCLyJrAKUYEL9KND7qirCPzVgmJHWIOnVewu8mldYFhroL89yvV3bZx4MGeyPU4KvbCsRgdORCTN0XhuLYUdiehHXnDBfuZ5yyR0saWLh8gjkLV5GkxTeKpOhpoK1o1cMiCDPYqTa64g5JundlW707c9zxc3Xnf2pW7E74YJl5oBu5vWEyPqXtYOtZOjOIRxxDY8QpoW8mpbQXxgB8DjkZZMiUCe0qHZYxvktVZJmHoaYBwpYpXVTZCfq9WajmkIOdIad1VnH5HpaECLRs6loa259yH8qesak2feDiKjfb8p3uj3s7WZUvPJwAWX9PIW1p7x6OiszXQCntOFRC3bQFNz1c98wlCBJnBSxbbYhU057TDNnoaib1h9bH7LAcqD1caE5KwLMAc5HqugkkRzT5NszkdJcpF0SxakdrAQLOKS6sNwDUzBJA76F775vmaqe3XIYecPmGtfoAKMychfEI4vfNr"; + for (int i = 0; i < 400000; i++) { + Statement stmt = connection.createStatement(); + String description = "description " + i; + String text = "here is my text data " + i; + String query = "insert into test values (" + i + ", " + i * 10 + ", '" + description + "', '" + text + "', " + i * 100 + 0.1234 + ", '2024-01-10', '10:00:00', '10:00:00-07', '2025-07-15 10:00:00', '2025-07-15 10:00:00-07'" + ", '" + desc_1KB + "');"; + int rs = stmt.executeUpdate(query); + assert rs == 1; + } + } + + private void validateResultSet(ResultSet rs, Blackhole b) throws SQLException { + while (rs.next()) { + b.consume(rs.getInt(1)); + b.consume(rs.getInt(2)); + b.consume(rs.getString(3)); + b.consume(rs.getString(4)); + b.consume(rs.getDouble(5)); + b.consume(rs.getDate(6)); + b.consume(rs.getTime(7)); + b.consume(rs.getTime(8)); + b.consume(rs.getTimestamp(9)); + b.consume(rs.getTimestamp(10)); + b.consume(rs.wasNull()); + } + } + + @Benchmark + public void runBenchmarkPrimaryKeyLookupNoCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM test where id = " + counter)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkNonIndexedLookupNoCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM test where int_col = " + counter*10)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkPrimaryKeyLookupWithCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where id = " + counter)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkNonIndexedLookupWithCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where int_col = " + counter*10)) { + validateResultSet(rs, b); + } + counter++; + } +} diff --git a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md index dea51924f..b8069fc24 100644 --- a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md +++ b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md @@ -220,7 +220,7 @@ The AWS JDBC Driver has several built-in plugins that are available to use. Plea [^2]: Federated Identity and Okta rely on IAM. Due to [^1], RDS Multi-AZ Clusters are not supported. > [!NOTE]\ -> To see information logged by plugins such as `DataCacheConnectionPlugin` and `LogQueryConnectionPlugin`, see the [Logging](#logging) section. +> To see information logged by plugins such as `DataLocalCacheConnectionPlugin` and `LogQueryConnectionPlugin`, see the [Logging](#logging) section. In addition to the built-in plugins, you can also create custom plugins more suitable for your needs. For more information, see [Custom Plugins](../development-guide/LoadablePlugins.md#using-custom-plugins). diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index 08fee3f19..ae4b9ab61 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -29,6 +29,8 @@ dependencies { implementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") implementation("org.jsoup:jsoup:1.21.1") implementation("com.mchange:c3p0:0.11.0") + implementation("io.lettuce:lettuce-core:6.6.0.RELEASE") + implementation("org.apache.commons:commons-pool2:2.11.1") } tasks.withType { diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java new file mode 100644 index 000000000..c199beb1e --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java @@ -0,0 +1,84 @@ +package software.amazon; + +import software.amazon.util.EnvLoader; +import java.sql.*; +import java.util.*; +import java.util.logging.Logger; + +public class DatabaseConnectionWithCacheExample { + + private static final EnvLoader env = new EnvLoader(); + + private static final String DB_CONNECTION_STRING = env.get("DB_CONNECTION_STRING"); + private static final String CACHE_RW_SERVER_ADDR = env.get("CACHE_RW_SERVER_ADDR"); + private static final String CACHE_RO_SERVER_ADDR = env.get("CACHE_RO_SERVER_ADDR"); + // If the cache server is authenticated with IAM + private static final String CACHE_NAME = env.get("CACHE_NAME"); + // Both IAM and traditional auth uses the same CACHE_USERNAME + private static final String CACHE_USERNAME = env.get("CACHE_USERNAME"); // e.g., "iam-user-01" / "username" + private static final String CACHE_IAM_REGION = env.get("CACHE_IAM_REGION"); // e.g., "us-west-2" + private static final String CACHE_USE_SSL = env.get("CACHE_USE_SSL"); + // If the cache server is authenticated with traditional username password + // private static final String CACHE_PASSWORD = env.get("CACHE_PASSWORD"); + private static final String USERNAME = env.get("DB_USERNAME"); + private static final String PASSWORD = env.get("DB_PASSWORD"); + private static final int THREAD_COUNT = 8; //Use 8 Threads + private static final long TEST_DURATION_MS = 16000; //Test duration for 16 seconds + private static final String CACHE_CONNECTION_TIMEOUT = env.get("CACHE_CONNECTION_TIMEOUT"); //Set connection timeout in milliseconds + private static final String CACHE_CONNECTION_POOL_SIZE = env.get("CACHE_CONNECTION_POOL_SIZE"); //Set connection pool size + + public static void main(String[] args) throws SQLException { + final Properties properties = new Properties(); + final Logger LOGGER = Logger.getLogger(DatabaseConnectionWithCacheExample.class.getName()); + + // Configuring connection properties for the underlying JDBC driver. + properties.setProperty("user", USERNAME); + properties.setProperty("password", PASSWORD); + + // Configuring connection properties for the JDBC Wrapper. + properties.setProperty("wrapperPlugins", "dataRemoteCache"); + properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); + properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + // If the cache server is authenticated with IAM + properties.setProperty("cacheName", CACHE_NAME); + properties.setProperty("cacheUsername", CACHE_USERNAME); + properties.setProperty("cacheIamRegion", CACHE_IAM_REGION); + // If the cache server is authenticated with traditional username password + // properties.setProperty("cachePassword", PASSWORD); + properties.setProperty("cacheUseSSL", CACHE_USE_SSL); // "true" or "false" + properties.setProperty("wrapperLogUnclosedConnections", "true"); + properties.setProperty("cacheConnectionTimeout", CACHE_CONNECTION_TIMEOUT); + properties.setProperty("cacheConnectionPoolSize", CACHE_CONNECTION_POOL_SIZE); + String queryStr = "/*+ CACHE_PARAM(ttl=300s) */ select * from cinemas"; + + // Create threads for concurrent connection testing + Thread[] threads = new Thread[THREAD_COUNT]; + for (int t = 0; t < THREAD_COUNT; t++) { + // Each thread uses a single connection for multiple queries + threads[t] = new Thread(() -> { + try { + try (Connection conn = DriverManager.getConnection(DB_CONNECTION_STRING, properties)) { + long endTime = System.currentTimeMillis() + TEST_DURATION_MS; + try (Statement stmt = conn.createStatement()) { + while (System.currentTimeMillis() < endTime) { + ResultSet rs = stmt.executeQuery(queryStr); + System.out.println("Executed the SQL query with result sets: " + rs.toString()); + } + } + } + } catch (Exception e) { + LOGGER.warning("Error: " + e.getMessage()); + } + }); + threads[t].start(); + } + // Wait for all threads to complete + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { + LOGGER.warning("Thread interrupted: " + e.getMessage()); + } + } + } +} diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java b/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java new file mode 100644 index 000000000..7b12d91f5 --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java @@ -0,0 +1,83 @@ +package software.amazon.util; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +/** + * A simple utility class to load environment variables from a .env file. + */ +public class EnvLoader { + private final Map envVars = new HashMap<>(); + + /** + * Loads environment variables from a .env file in the current directory. + */ + public EnvLoader() { + this(Paths.get(".env")); + } + + /** + * Loads environment variables from the specified file path. + * + * @param envPath Path to the .env file + */ + public EnvLoader(Path envPath) { + if (Files.exists(envPath)) { + try (BufferedReader reader = new BufferedReader(new FileReader(envPath.toFile()))) { + String line; + while ((line = reader.readLine()) != null) { + parseLine(line); + } + } catch (IOException e) { + System.err.println("Error reading .env file: " + e.getMessage()); + } + } + } + + private void parseLine(String line) { + line = line.trim(); + // Skip empty lines and comments + if (line.isEmpty() || line.startsWith("#")) { + return; + } + + // Split on the first equals sign + int delimiterPos = line.indexOf('='); + if (delimiterPos > 0) { + String key = line.substring(0, delimiterPos).trim(); + String value = line.substring(delimiterPos + 1).trim(); + + // Remove quotes if present + if ((value.startsWith("\"") && value.endsWith("\"")) || + (value.startsWith("'") && value.endsWith("'"))) { + value = value.substring(1, value.length() - 1); + } + + envVars.put(key, value); + } + } + + /** + * Gets the value of an environment variable. + * + * @param key The name of the environment variable + * @return The value of the environment variable, or null if not found + */ + public String get(String key) { + // First check the loaded .env file + String value = envVars.get(key); + + // If not found, check system environment variables + if (value == null) { + value = System.getenv(key); + } + + return value; + } +} diff --git a/examples/AWSDriverExample/src/main/resources/logback.xml b/examples/AWSDriverExample/src/main/resources/logback.xml new file mode 100644 index 000000000..e03eaf554 --- /dev/null +++ b/examples/AWSDriverExample/src/main/resources/logback.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 0e6fb29f1..7de1e94d1 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -44,8 +44,10 @@ dependencies { optionalImplementation("com.mchange:c3p0:0.11.0") optionalImplementation("org.apache.httpcomponents:httpclient:4.5.14") optionalImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") + optionalImplementation("org.apache.commons:commons-pool2:2.11.1") optionalImplementation("org.jsoup:jsoup:1.21.1") optionalImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") + optionalImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") optionalImplementation("io.opentelemetry:opentelemetry-api:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") @@ -98,10 +100,12 @@ dependencies { testImplementation("org.slf4j:slf4j-simple:2.0.17") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") testImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") + testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("io.opentelemetry:opentelemetry-api:1.52.0") testImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") testImplementation("io.opentelemetry:opentelemetry-exporter-otlp:1.52.0") + testImplementation("org.apache.commons:commons-pool2:2.11.1") testImplementation("org.jsoup:jsoup:1.21.1") testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 @@ -208,7 +212,7 @@ if (useJacoco) { "software/amazon/jdbc/wrapper/*", "software/amazon/jdbc/util/*", "software/amazon/jdbc/profile/*", - "software/amazon/jdbc/plugin/DataCacheConnectionPlugin*" + "software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin*" ) } })) @@ -223,7 +227,7 @@ if (useJacoco) { "software/amazon/jdbc/wrapper/*", "software/amazon/jdbc/util/*", "software/amazon/jdbc/profile/*", - "software/amazon/jdbc/plugin/DataCacheConnectionPlugin*" + "software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin*" ) } })) diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 952b00936..ef07e611c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -33,7 +33,8 @@ import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPluginFactory; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPluginFactory; import software.amazon.jdbc.plugin.ConnectTimeConnectionPluginFactory; -import software.amazon.jdbc.plugin.DataCacheConnectionPluginFactory; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPluginFactory; +import software.amazon.jdbc.plugin.cache.DataRemoteCachePluginFactory; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.DriverMetaDataConnectionPluginFactory; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPluginFactory; @@ -68,7 +69,8 @@ public class ConnectionPluginChainBuilder { { put("executionTime", new ExecutionTimeConnectionPluginFactory()); put("logQuery", new LogQueryConnectionPluginFactory()); - put("dataCache", new DataCacheConnectionPluginFactory()); + put("dataCache", new DataLocalCacheConnectionPluginFactory()); + put("dataRemoteCache", new DataRemoteCachePluginFactory()); put("customEndpoint", new CustomEndpointPluginFactory()); put("efm", new HostMonitoringConnectionPluginFactory()); put("efm2", new software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPluginFactory()); @@ -100,7 +102,8 @@ public class ConnectionPluginChainBuilder { new HashMap, Integer>() { { put(DriverMetaDataConnectionPluginFactory.class, 100); - put(DataCacheConnectionPluginFactory.class, 200); + put(DataLocalCacheConnectionPluginFactory.class, 200); + put(DataRemoteCachePluginFactory.class, 250); put(CustomEndpointPluginFactory.class, 380); put(AuroraInitialConnectionStrategyPluginFactory.class, 390); put(AuroraConnectionTrackerPluginFactory.class, 400); diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 2697c5b03..b711f4617 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -33,7 +33,8 @@ import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPlugin; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataRemoteCachePlugin; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPlugin; import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; @@ -72,7 +73,8 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { put(ExecutionTimeConnectionPlugin.class, "plugin:executionTime"); put(AuroraConnectionTrackerPlugin.class, "plugin:auroraConnectionTracker"); put(LogQueryConnectionPlugin.class, "plugin:logQuery"); - put(DataCacheConnectionPlugin.class, "plugin:dataCache"); + put(DataLocalCacheConnectionPlugin.class, "plugin:dataCache"); + put(DataRemoteCachePlugin.class, "plugin:dataRemoteCache"); put(HostMonitoringConnectionPlugin.class, "plugin:efm"); put(software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin.class, "plugin:efm2"); put(FailoverConnectionPlugin.class, "plugin:failover"); diff --git a/wrapper/src/main/java/software/amazon/jdbc/Driver.java b/wrapper/src/main/java/software/amazon/jdbc/Driver.java index 7d59e83ff..60db2c8f4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/Driver.java +++ b/wrapper/src/main/java/software/amazon/jdbc/Driver.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider; import software.amazon.jdbc.plugin.AwsSecretsManagerCacheHolder; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; import software.amazon.jdbc.plugin.OpenedConnectionTracker; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; @@ -430,7 +430,7 @@ public static void clearCaches() { CustomEndpointMonitorImpl.clearCache(); OpenedConnectionTracker.clearCache(); AwsSecretsManagerCacheHolder.clearCache(); - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); FederatedAuthCacheHolder.clearCache(); OktaAuthCacheHolder.clearCache(); IamAuthCacheHolder.clearCache(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java deleted file mode 100644 index a7cf53a9c..000000000 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java +++ /dev/null @@ -1,1239 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin; - -import java.io.InputStream; -import java.io.Reader; -import java.math.BigDecimal; -import java.net.URL; -import java.sql.Array; -import java.sql.Blob; -import java.sql.Clob; -import java.sql.Date; -import java.sql.NClob; -import java.sql.Ref; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.RowId; -import java.sql.SQLException; -import java.sql.SQLWarning; -import java.sql.SQLXML; -import java.sql.Statement; -import java.sql.Time; -import java.sql.Timestamp; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Calendar; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.logging.Logger; -import software.amazon.jdbc.AwsWrapperProperty; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; - -public class DataCacheConnectionPlugin extends AbstractConnectionPlugin { - - private static final Logger LOGGER = Logger.getLogger(DataCacheConnectionPlugin.class.getName()); - - private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( - Arrays.asList( - JdbcMethod.STATEMENT_EXECUTEQUERY.methodName, - JdbcMethod.STATEMENT_EXECUTE.methodName, - JdbcMethod.PREPAREDSTATEMENT_EXECUTE.methodName, - JdbcMethod.PREPAREDSTATEMENT_EXECUTEQUERY.methodName, - JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName, - JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName - ))); - - public static final AwsWrapperProperty DATA_CACHE_TRIGGER_CONDITION = new AwsWrapperProperty( - "dataCacheTriggerCondition", "false", - "A regular expression that, if it's matched, allows the plugin to cache SQL results."); - - protected static final Map dataCache = new ConcurrentHashMap<>(); - - protected final String dataCacheTriggerCondition; - - static { - PropertyDefinition.registerPluginProperties(DataCacheConnectionPlugin.class); - } - - private final TelemetryFactory telemetryFactory; - private final TelemetryCounter hitCounter; - private final TelemetryCounter missCounter; - private final TelemetryCounter totalCallsCounter; - private final TelemetryGauge cacheSizeGauge; - - public DataCacheConnectionPlugin(final PluginService pluginService, final Properties props) { - this.telemetryFactory = pluginService.getTelemetryFactory(); - this.dataCacheTriggerCondition = DATA_CACHE_TRIGGER_CONDITION.getString(props); - - this.hitCounter = telemetryFactory.createCounter("dataCache.cache.hit"); - this.missCounter = telemetryFactory.createCounter("dataCache.cache.miss"); - this.totalCallsCounter = telemetryFactory.createCounter("dataCache.cache.totalCalls"); - this.cacheSizeGauge = telemetryFactory.createGauge("dataCache.cache.size", () -> (long) dataCache.size()); - } - - public static void clearCache() { - dataCache.clear(); - } - - @Override - public Set getSubscribedMethods() { - return subscribedMethods; - } - - @Override - public T execute( - final Class resultClass, - final Class exceptionClass, - final Object methodInvokeOn, - final String methodName, - final JdbcCallable jdbcMethodFunc, - final Object[] jdbcMethodArgs) - throws E { - - if (StringUtils.isNullOrEmpty(this.dataCacheTriggerCondition) || resultClass != ResultSet.class) { - return jdbcMethodFunc.call(); - } - - if (this.totalCallsCounter != null) { - this.totalCallsCounter.inc(); - } - - ResultSet result; - boolean needToCache = false; - final String sql = getQuery(jdbcMethodArgs); - - if (!StringUtils.isNullOrEmpty(sql) && sql.matches(this.dataCacheTriggerCondition)) { - result = dataCache.get(sql); - if (result == null) { - needToCache = true; - if (this.missCounter != null) { - this.missCounter.inc(); - } - LOGGER.finest( - () -> Messages.get( - "DataCacheConnectionPlugin.queryResultsCached", - new Object[]{methodName, sql})); - } else { - if (this.hitCounter != null) { - this.hitCounter.inc(); - } - try { - result.beforeFirst(); - } catch (final SQLException ex) { - if (exceptionClass.isAssignableFrom(ex.getClass())) { - throw exceptionClass.cast(ex); - } - throw new RuntimeException(ex); - } - return resultClass.cast(result); - } - } - - result = (ResultSet) jdbcMethodFunc.call(); - - if (needToCache) { - final ResultSet cachedResultSet; - try { - cachedResultSet = new CachedResultSet(result); - dataCache.put(sql, cachedResultSet); - cachedResultSet.beforeFirst(); - return resultClass.cast(cachedResultSet); - } catch (final SQLException ex) { - // ignore exception - } - } - - return resultClass.cast(result); - } - - protected String getQuery(final Object[] jdbcMethodArgs) { - - // Get query from method argument - if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { - return jdbcMethodArgs[0].toString(); - } - return null; - } - - public static class CachedRow { - protected final HashMap columnByIndex = new HashMap<>(); - protected final HashMap columnByName = new HashMap<>(); - - public void put(final int columnIndex, final String columnName, final Object columnValue) { - columnByIndex.put(columnIndex, columnValue); - columnByName.put(columnName, columnValue); - } - - @SuppressWarnings("unused") - public Object get(final int columnIndex) { - return columnByIndex.get(columnIndex); - } - - @SuppressWarnings("unused") - public Object get(final String columnName) { - return columnByName.get(columnName); - } - } - - @SuppressWarnings({"RedundantThrows", "checkstyle:OverloadMethodsDeclarationOrder"}) - public static class CachedResultSet implements ResultSet { - - protected ArrayList rows; - protected int currentRow; - - public CachedResultSet(final ResultSet resultSet) throws SQLException { - - final ResultSetMetaData md = resultSet.getMetaData(); - final int columns = md.getColumnCount(); - rows = new ArrayList<>(); - - while (resultSet.next()) { - final CachedRow row = new CachedRow(); - for (int i = 1; i <= columns; ++i) { - row.put(i, md.getColumnName(i), resultSet.getObject(i)); - } - rows.add(row); - } - currentRow = -1; - } - - @Override - public boolean next() throws SQLException { - if (rows.size() == 0 || isLast()) { - return false; - } - currentRow++; - return true; - } - - @Override - public void close() throws SQLException { - currentRow = rows.size() - 1; - } - - @Override - public boolean wasNull() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getAsciiStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public InputStream getUnicodeStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getAsciiStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public InputStream getUnicodeStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLWarning getWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void clearWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getCursorName() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public ResultSetMetaData getMetaData() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final int columnIndex) throws SQLException { - if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries - } - final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByIndex.containsKey(columnIndex)) { - return null; // column index out of boundaries - } - return row.columnByIndex.get(columnIndex); - } - - @Override - public Object getObject(final String columnLabel) throws SQLException { - if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries - } - final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByName.containsKey(columnLabel)) { - return null; // column name not found - } - return row.columnByName.get(columnLabel); - } - - @Override - public int findColumn(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getCharacterStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getCharacterStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getBigDecimal(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getBigDecimal(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isBeforeFirst() throws SQLException { - return this.currentRow < 0; - } - - @Override - public boolean isAfterLast() throws SQLException { - return this.currentRow >= this.rows.size(); - } - - @Override - public boolean isFirst() throws SQLException { - return this.currentRow == 0 && this.rows.size() > 0; - } - - @Override - public boolean isLast() throws SQLException { - return this.currentRow == (this.rows.size() - 1) && this.rows.size() > 0; - } - - @Override - public void beforeFirst() throws SQLException { - this.currentRow = -1; - } - - @Override - public void afterLast() throws SQLException { - this.currentRow = this.rows.size(); - } - - @Override - public boolean first() throws SQLException { - this.currentRow = 0; - return this.currentRow < this.rows.size(); - } - - @Override - public boolean last() throws SQLException { - this.currentRow = this.rows.size() - 1; - return this.currentRow >= 0; - } - - @Override - public int getRow() throws SQLException { - return this.currentRow + 1; - } - - @Override - public boolean absolute(final int row) throws SQLException { - if (row > 0) { - this.currentRow = row - 1; - } else { - this.currentRow = this.rows.size() + row; - } - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public boolean relative(final int rows) throws SQLException { - this.currentRow += rows; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public boolean previous() throws SQLException { - this.currentRow--; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public void setFetchDirection(final int direction) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchDirection() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setFetchSize(final int rows) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchSize() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getType() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getConcurrency() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowUpdated() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowInserted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowDeleted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(final int columnIndex, final boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(final int columnIndex, final byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(final int columnIndex, final short x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateInt(final int columnIndex, final int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(final int columnIndex, final long x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateFloat(final int columnIndex, final float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(final int columnIndex, final double x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBigDecimal(final int columnIndex, final BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateString(final int columnIndex, final String x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBytes(final int columnIndex, final byte[] x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDate(final int columnIndex, final Date x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTime(final int columnIndex, final Time x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTimestamp(final int columnIndex, final Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final int columnIndex, final Object x, final int scaleOrLength) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final int columnIndex, final Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(final String columnLabel, final boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(final String columnLabel, final byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(final String columnLabel, final short x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateInt(final String columnLabel, final int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(final String columnLabel, final long x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateFloat(final String columnLabel, final float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(final String columnLabel, final double x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBigDecimal(final String columnLabel, final BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateString(final String columnLabel, final String x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBytes(final String columnLabel, final byte[] x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDate(final String columnLabel, final Date x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTime(final String columnLabel, final Time x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTimestamp(final String columnLabel, final Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final String columnLabel, final Object x, final int scaleOrLength) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final String columnLabel, final Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void insertRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void deleteRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void refreshRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void cancelRowUpdates() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToInsertRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToCurrentRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Statement getStatement() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final int columnIndex, final Map> map) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final String columnLabel, final Map> map) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(final int columnIndex, final Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(final String columnLabel, final Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(final int columnIndex, final Array x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(final String columnLabel, final Array x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(final int columnIndex, final RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(final String columnLabel, final RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getHoldability() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isClosed() throws SQLException { - return false; - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNString(final int columnIndex, final String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNString(final String columnLabel, final String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNClob(final int columnIndex, final NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNClob(final String columnLabel, final NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings("checkstyle:MethodName") - public NClob getNClob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public NClob getNClob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(final int columnIndex, final SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(final String columnLabel, final SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final int columnIndex, final Reader x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final InputStream inputStream, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final InputStream inputStream, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final int columnIndex, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final int columnIndex, final Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final int columnIndex, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public T getObject(final int columnIndex, final Class type) throws SQLException { - return type.cast(getObject(columnIndex)); - } - - @Override - public T getObject(final String columnLabel, final Class type) throws SQLException { - return type.cast(getObject(columnLabel)); - } - - @Override - public T unwrap(final Class iface) throws SQLException { - return iface == ResultSet.class ? iface.cast(this) : null; - } - - @Override - public boolean isWrapperFor(final Class iface) throws SQLException { - return iface != null && iface.isAssignableFrom(this.getClass()); - } - } -} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java new file mode 100644 index 000000000..1c2d1aac7 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -0,0 +1,427 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisCredentials; +import io.lettuce.core.RedisCredentialsProvider; +import io.lettuce.core.RedisURI; +import io.lettuce.core.RedisCommandExecutionException; +import io.lettuce.core.SetArgs; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.async.RedisAsyncCommands; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.resource.ClientResources; +import reactor.core.publisher.Mono; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; +import java.util.logging.Logger; +import org.apache.commons.pool2.BasePooledObjectFactory; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.commons.pool2.PooledObject; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.authentication.AwsCredentialsManager; +import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility; +import software.amazon.jdbc.util.StringUtils; + +// Abstraction layer on top of a connection to a remote cache server +public class CacheConnection { + private static final Logger LOGGER = Logger.getLogger(CacheConnection.class.getName()); + + private static final int DEFAULT_POOL_MIN_IDLE = 0; + private static final int DEFAULT_MAX_POOL_SIZE = 200; + private static final long DEFAULT_MAX_BORROW_WAIT_MS = 100; + private static final long TOKEN_CACHE_DURATION = 15 * 60 - 30; + + private static final ReentrantLock READ_LOCK = new ReentrantLock(); + private static final ReentrantLock WRITE_LOCK = new ReentrantLock(); + + private final String cacheRwServerAddr; // read-write cache server + private final String cacheRoServerAddr; // read-only cache server + private final String[] defaultCacheServerHostAndPort; + private MessageDigest msgHashDigest = null; + + protected static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR = + new AwsWrapperProperty( + "cacheEndpointAddrRw", + null, + "The cache read-write server endpoint address."); + + private static final AwsWrapperProperty CACHE_RO_ENDPOINT_ADDR = + new AwsWrapperProperty( + "cacheEndpointAddrRo", + null, + "The cache read-only server endpoint address."); + + protected static final AwsWrapperProperty CACHE_USE_SSL = + new AwsWrapperProperty( + "cacheUseSSL", + "true", + "Whether to use SSL for cache connections."); + + protected static final AwsWrapperProperty CACHE_IAM_REGION = + new AwsWrapperProperty( + "cacheIamRegion", + null, + "AWS region for ElastiCache IAM authentication."); + + protected static final AwsWrapperProperty CACHE_USERNAME = + new AwsWrapperProperty( + "cacheUsername", + null, + "Username for ElastiCache regular authentication."); + + protected static final AwsWrapperProperty CACHE_PASSWORD = + new AwsWrapperProperty( + "cachePassword", + null, + "Password for ElastiCache regular authentication."); + + protected static final AwsWrapperProperty CACHE_NAME = + new AwsWrapperProperty( + "cacheName", + null, + "Explicit cache name for ElastiCache IAM authentication. "); + + protected static final AwsWrapperProperty CACHE_CONNECTION_TIMEOUT = + new AwsWrapperProperty( + "cacheConnectionTimeout", + "2000", + "Cache connection request timeout duration in milliseconds."); + + protected static final AwsWrapperProperty CACHE_CONNECTION_POOL_SIZE = + new AwsWrapperProperty( + "cacheConnectionPoolSize", + "20", + "Cache connection pool size."); + + // Adding support for read and write connection pools to the remote cache server + private static volatile GenericObjectPool> readConnectionPool; + private static volatile GenericObjectPool> writeConnectionPool; + private static final GenericObjectPoolConfig> poolConfig = createPoolConfig(); + + private final boolean useSSL; + private final boolean iamAuthEnabled; + private final String cacheIamRegion; + private final String cacheUsername; + private final String cacheName; + private final String cachePassword; + private final Duration cacheConnectionTimeout; + private final int cacheConnectionPoolSize; + private final Properties awsProfileProperties; + private final AwsCredentialsProvider credentialsProvider; + + static { + PropertyDefinition.registerPluginProperties(CacheConnection.class); + } + + public CacheConnection(final Properties properties) { + this.cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); + this.cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR.getString(properties); + this.useSSL = Boolean.parseBoolean(CACHE_USE_SSL.getString(properties)); + this.cacheName = CACHE_NAME.getString(properties); + this.cacheIamRegion = CACHE_IAM_REGION.getString(properties); + this.cacheUsername = CACHE_USERNAME.getString(properties); + this.cachePassword = CACHE_PASSWORD.getString(properties); + this.cacheConnectionTimeout = Duration.ofMillis(CACHE_CONNECTION_TIMEOUT.getInteger(properties)); + this.cacheConnectionPoolSize = CACHE_CONNECTION_POOL_SIZE.getInteger(properties); + if (this.cacheConnectionPoolSize <= 0 || this.cacheConnectionPoolSize > DEFAULT_MAX_POOL_SIZE) { + throw new IllegalArgumentException( + "Cache connection pool size must be within valid range: 1-" + DEFAULT_MAX_POOL_SIZE + ", but was: " + this.cacheConnectionPoolSize); + } + // Update the static poolConfig with user values + poolConfig.setMaxTotal(this.cacheConnectionPoolSize); + poolConfig.setMaxIdle(this.cacheConnectionPoolSize); + this.iamAuthEnabled = !StringUtils.isNullOrEmpty(this.cacheIamRegion); + boolean hasTraditionalAuth = !StringUtils.isNullOrEmpty(this.cachePassword); + // Validate authentication configuration + if (this.iamAuthEnabled && hasTraditionalAuth) { + throw new IllegalArgumentException( + "Cannot specify both IAM authentication (cacheIamRegion) and traditional authentication (cachePassword). Choose one authentication method."); + } + if (this.cacheRwServerAddr == null) { + throw new IllegalArgumentException("Cache endpoint address is required"); + } + this.defaultCacheServerHostAndPort = getHostnameAndPort(this.cacheRwServerAddr); + if (this.iamAuthEnabled) { + if (this.cacheUsername == null || this.defaultCacheServerHostAndPort[0] == null || this.cacheName == null) { + throw new IllegalArgumentException("IAM authentication requires cache name, username, region, and hostname"); + } + } + if (PropertyDefinition.AWS_PROFILE.getString(properties) != null) { + this.awsProfileProperties = new Properties(); + this.awsProfileProperties.setProperty( + PropertyDefinition.AWS_PROFILE.name, + PropertyDefinition.AWS_PROFILE.getString(properties) + ); + } else { + this.awsProfileProperties = null; + } + if (this.iamAuthEnabled) { + // Handle null case + Properties propsToPass = (this.awsProfileProperties != null) + ? this.awsProfileProperties + : new Properties(); + this.credentialsProvider = AwsCredentialsManager.getProvider(null, propsToPass); + } else { + this.credentialsProvider = null; + } + } + + /* Here we check if we need to initialise connection pool for read or write to cache. + With isRead we check if we need to initialise connection pool for read or write to cache. + If isRead is true, we initialise connection pool for read. + If isRead is false, we initialise connection pool for write. + */ + private void initializeCacheConnectionIfNeeded(boolean isRead) { + if (StringUtils.isNullOrEmpty(cacheRwServerAddr)) return; + // Initialize the message digest + if (msgHashDigest == null) { + try { + msgHashDigest = MessageDigest.getInstance("SHA-384"); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-384 not supported", e); + } + } + + GenericObjectPool> cacheConnectionPool = + isRead ? readConnectionPool : writeConnectionPool; + if (cacheConnectionPool == null) { + ReentrantLock connectionPoolLock = isRead ? READ_LOCK : WRITE_LOCK; + connectionPoolLock.lock(); + try { + if ((isRead && readConnectionPool == null) || (!isRead && writeConnectionPool == null)) { + createConnectionPool(isRead); + } + } finally { + connectionPoolLock.unlock(); + } + } + } + + private void createConnectionPool(boolean isRead) { + ClientResources resources = ClientResources.builder().build(); + try { + // cache server addr string is in the format ":" + String serverAddr = cacheRwServerAddr; + // If read-only server is specified, use it for the read-only connections + if (isRead && !StringUtils.isNullOrEmpty(cacheRoServerAddr)) { + serverAddr = cacheRoServerAddr; + } + String[] hostnameAndPort = getHostnameAndPort(serverAddr); + RedisURI redisUriCluster = buildRedisURI(hostnameAndPort[0], Integer.parseInt(hostnameAndPort[1])); + + RedisClient client = RedisClient.create(resources, redisUriCluster); + GenericObjectPool> pool = new GenericObjectPool<>( + new BasePooledObjectFactory>() { + public StatefulRedisConnection create() { + + StatefulRedisConnection connection = client.connect(new ByteArrayCodec()); + // In cluster mode, we need to send READONLY command to the server for reading from replica. + // Note: we gracefully ignore ERR reply to support non cluster mode. + if (isRead) { + try { + connection.sync().readOnly(); + } catch (RedisCommandExecutionException e) { + if (e.getMessage().contains("ERR This instance has cluster support disabled")) { + LOGGER.fine("------ Note: this cache cluster has cluster support disabled ------"); + } else { + throw e; + } + } + } + return connection; + } + public PooledObject> wrap(StatefulRedisConnection connection) { + return new DefaultPooledObject<>(connection); + } + }, poolConfig); + + if (isRead) { + readConnectionPool = pool; + } else { + writeConnectionPool = pool; + } + } catch (Exception e) { + String poolType = isRead ? "read" : "write"; + String errorMsg = String.format("Failed to create Cache %s connection pool", poolType); + LOGGER.warning(errorMsg + ": " + e.getMessage()); + throw new RuntimeException(errorMsg, e); + } + } + + private static GenericObjectPoolConfig> createPoolConfig() { + GenericObjectPoolConfig> poolConfig = new GenericObjectPoolConfig<>(); + poolConfig.setMinIdle(DEFAULT_POOL_MIN_IDLE); + poolConfig.setMaxWait(Duration.ofMillis(DEFAULT_MAX_BORROW_WAIT_MS)); + return poolConfig; + } + + // Get the hash digest of the given key. + private byte[] computeHashDigest(byte[] key) { + msgHashDigest.update(key); + return msgHashDigest.digest(); + } + + public byte[] readFromCache(String key) { + boolean isBroken = false; + StatefulRedisConnection conn = null; + // get a connection from the read connection pool + try { + initializeCacheConnectionIfNeeded(true); + conn = readConnectionPool.borrowObject(); + return conn.sync().get(computeHashDigest(key.getBytes(StandardCharsets.UTF_8))); + } catch (Exception e) { + if (conn != null) { + isBroken = true; + } + LOGGER.warning("Failed to read result from cache. Treating it as a cache miss: " + e.getMessage()); + return null; + } finally { + if (conn != null && readConnectionPool != null) { + try { + this.returnConnectionBackToPool(conn, isBroken, true); + } catch (Exception ex) { + LOGGER.warning("Error closing read connection: " + ex.getMessage()); + } + } + } + } + + protected void handleCompletedCacheWrite(StatefulRedisConnection conn, Throwable ex) { + // Note: this callback upon completion of cache write is on a different thread + if (ex != null) { + LOGGER.warning("Failed to write to cache: " + ex.getMessage()); + if (writeConnectionPool != null) { + try { + returnConnectionBackToPool(conn, true, false); + } catch (Exception e) { + LOGGER.warning("Error returning broken write connection back to pool: " + e.getMessage()); + } + } + } else { + if (writeConnectionPool != null) { + try { + returnConnectionBackToPool(conn, false, false); + } catch (Exception e) { + LOGGER.warning("Error returning write connection back to pool: " + e.getMessage()); + } + } + } + } + + public void writeToCache(String key, byte[] value, int expiry) { + StatefulRedisConnection conn = null; + try { + initializeCacheConnectionIfNeeded(false); + // get a connection from the write connection pool + conn = writeConnectionPool.borrowObject(); + // Write to the cache is async. + RedisAsyncCommands asyncCommands = conn.async(); + byte[] keyHash = computeHashDigest(key.getBytes(StandardCharsets.UTF_8)); + StatefulRedisConnection finalConn = conn; + asyncCommands.set(keyHash, value, SetArgs.Builder.ex(expiry)) + .whenComplete((result, exception) -> handleCompletedCacheWrite(finalConn, exception)); + } catch (Exception e) { + // Failed to trigger the async write to the cache, return the cache connection to the pool as broken + LOGGER.warning("Unable to start writing to cache: " + e.getMessage()); + if (conn != null && writeConnectionPool != null) { + try { + returnConnectionBackToPool(conn, true, false); + } catch (Exception ex) { + LOGGER.warning("Error closing write connection: " + ex.getMessage()); + } + } + } + } + + private void returnConnectionBackToPool(StatefulRedisConnection connection, boolean isConnectionBroken, boolean isRead) { + GenericObjectPool> pool = isRead ? readConnectionPool : writeConnectionPool; + if (isConnectionBroken) { + try { + pool.invalidateObject(connection); + } catch (Exception e) { + throw new RuntimeException("Could not invalidate connection for the pool", e); + } + } else { + pool.returnObject(connection); + } + } + + // Used for unit testing only + protected void setConnectionPools(GenericObjectPool> readPool, + GenericObjectPool> writePool) { + readConnectionPool = readPool; + writeConnectionPool = writePool; + } + + // Used for unit testing only + protected void triggerPoolInit(boolean isRead) { + initializeCacheConnectionIfNeeded(isRead); + } + + protected RedisURI buildRedisURI(String hostname, int port) { + RedisURI.Builder uriBuilder = RedisURI.Builder.redis(hostname) + .withPort(port) + .withSsl(useSSL) + .withVerifyPeer(false) + .withLibraryName("aws-sql-jdbc-lettuce") + .withTimeout(cacheConnectionTimeout); + + if (this.iamAuthEnabled) { + // Create a credentials provider that Lettuce will call whenever authentication is needed + RedisCredentialsProvider credentialsProvider = () -> { + // Create a cached token supplier that automatically refreshes tokens every 14.5 minutes + Supplier tokenSupplier = CachedSupplier.memoizeWithExpiration( + () -> { + ElastiCacheIamTokenUtility tokenUtility = new ElastiCacheIamTokenUtility(this.cacheName); + return tokenUtility.generateAuthenticationToken( + this.credentialsProvider, + Region.of(this.cacheIamRegion), + this.defaultCacheServerHostAndPort[0], + Integer.parseInt(this.defaultCacheServerHostAndPort[1]), + this.cacheUsername + ); + }, + TOKEN_CACHE_DURATION, + TimeUnit.SECONDS + ); + // Package the username and token (from cache or freshly generated) into Redis credentials + return Mono.just(RedisCredentials.just(this.cacheUsername, tokenSupplier.get())); + }; + uriBuilder.withAuthentication(credentialsProvider); + } else if (!StringUtils.isNullOrEmpty(this.cachePassword)) { + uriBuilder.withAuthentication(this.cacheUsername, this.cachePassword); + } + return uriBuilder.build(); + } + + private String[] getHostnameAndPort(String serverAddr) { + return serverAddr.split(":"); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java new file mode 100644 index 000000000..b4ecbfaf1 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -0,0 +1,1431 @@ +package software.amazon.jdbc.plugin.cache; + +import org.checkerframework.checker.nullness.qual.Nullable; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.IOException; +import java.io.Reader; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.net.MalformedURLException; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalTime; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.OffsetTime; +import java.time.ZonedDateTime; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.Calendar; +import java.util.TimeZone; + +public class CachedResultSet implements ResultSet { + + public static class CachedRow { + private final Object[] rowData; + final byte[] @Nullable [] rawData; + + public CachedRow(int numColumns) { + rowData = new Object[numColumns]; + rawData = new byte[numColumns][]; + } + + private void checkColumnIndex(final int columnIndex) throws SQLException { + if (columnIndex < 1 || columnIndex > rowData.length) { + throw new SQLException("Invalid Column Index when operating CachedRow: " + columnIndex); + } + } + + public void put(final int columnIndex, final Object columnValue) throws SQLException { + checkColumnIndex(columnIndex); + rowData[columnIndex-1] = columnValue; + } + + public void putRaw(final int columnIndex, final byte[] rawColumnValue) throws SQLException { + checkColumnIndex(columnIndex); + rawData[columnIndex-1] = rawColumnValue; + } + + public Object get(final int columnIndex) throws SQLException { + checkColumnIndex(columnIndex); + // De-serialize the data object from raw bytes if needed. + if (rowData[columnIndex-1] == null && rawData[columnIndex-1] != null) { + try (ByteArrayInputStream bis = new ByteArrayInputStream(rawData[columnIndex - 1]); + ObjectInputStream ois = new ObjectInputStream(bis)) { + rowData[columnIndex - 1] = ois.readObject(); + rawData[columnIndex - 1] = null; + } catch (ClassNotFoundException e) { + throw new SQLException("ClassNotFoundException while de-serializing caching resultSet for column: " + columnIndex, e); + } catch (IOException e) { + throw new SQLException("IOException while de-serializing caching resultSet for column: " + columnIndex, e); + } + } + return rowData[columnIndex - 1]; + } + } + + protected ArrayList rows; + protected int currentRow; + protected boolean wasNullFlag; + private final CachedResultSetMetaData metadata; + protected static final ZoneId defaultTimeZoneId = ZoneId.systemDefault(); + protected static final TimeZone defaultTimeZone = TimeZone.getDefault(); + private final HashMap columnNames; + private volatile boolean closed; + + /** + * Create a CachedResultSet out of the original ResultSet queried from the database. + * @param resultSet The ResultSet queried from the underlying database (not a CachedResultSet). + * @return CachedResultSet that captures the metadata and the rows of the input ResultSet. + * @throws SQLException + */ + public CachedResultSet(final ResultSet resultSet) throws SQLException { + ResultSetMetaData srcMetadata = resultSet.getMetaData(); + final int numColumns = srcMetadata.getColumnCount(); + CachedResultSetMetaData.Field[] fields = new CachedResultSetMetaData.Field[numColumns]; + for (int i = 0; i < numColumns; i++) { + fields[i] = new CachedResultSetMetaData.Field(srcMetadata, i+1); + } + metadata = new CachedResultSetMetaData(fields); + rows = new ArrayList<>(); + this.columnNames = new HashMap<>(); + for (int i = 1; i <= numColumns; i++) { + this.columnNames.put(srcMetadata.getColumnLabel(i), i); + } + while (resultSet.next()) { + final CachedRow row = new CachedRow(numColumns); + for (int i = 1; i <= numColumns; ++i) { + Object rowObj = resultSet.getObject(i); + // For SQLXML object, convert into CachedSQLXML object that is serializable + if (rowObj instanceof SQLXML) { + rowObj = new CachedSQLXML(((SQLXML)rowObj).getString()); + } + row.put(i, rowObj); + } + rows.add(row); + } + currentRow = -1; + closed = false; + wasNullFlag = false; + } + + private CachedResultSet(final CachedResultSetMetaData md, final ArrayList resultRows) throws SQLException { + int numColumns = md.getColumnCount(); + this.columnNames = new HashMap<>(); + for (int i = 1; i <= numColumns; i++) { + this.columnNames.put(md.getColumnLabel(i), i); + } + currentRow = -1; + rows = resultRows; + metadata = md; + closed = false; + wasNullFlag = false; + } + + // Serialize the content of metadata and data rows for the current CachedResultSet into a byte array + public byte[] serializeIntoByteArray() throws SQLException { + // Serialize the metadata and then the rows + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream output = new ObjectOutputStream(baos)) { + output.writeObject(metadata); + output.writeInt(rows.size()); + int numColumns = metadata.getColumnCount(); + while (this.next()) { + // serialize individual column fields in each row + CachedRow row = rows.get(currentRow); + for (int i = 0; i < numColumns; i++) { + try (ByteArrayOutputStream objBytes = new ByteArrayOutputStream(); + ObjectOutputStream objStream = new ObjectOutputStream(objBytes)) { + objStream.writeObject(row.get(i + 1)); + objStream.flush(); + byte[] dataByteArray = objBytes.toByteArray(); + int serializedLength = dataByteArray.length; + output.writeInt(serializedLength); + output.write(dataByteArray, 0, serializedLength); + } + } + } + output.flush(); + return baos.toByteArray(); + } catch (IOException e) { + throw new SQLException("Error while serializing the ResultSet for caching: ", e); + } + } + + /** + * Form a ResultSet from the raw data from the cache server. Each of the column objects are stored as + * raw bytes and the actual de-serialization into Java objects will happen lazily upon access later on. + */ + public static ResultSet deserializeFromByteArray(byte[] data) throws SQLException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(data); + ObjectInputStream ois = new ObjectInputStream(bis)) { + CachedResultSetMetaData metadata = (CachedResultSetMetaData) ois.readObject(); + int numRows = ois.readInt(); + int numColumns = metadata.getColumnCount(); + ArrayList resultRows = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + // Store the raw bytes for each column object in CachedRow + final CachedRow row = new CachedRow(numColumns); + for(int j = 0; j < numColumns; j++) { + int nextObjSize = ois.readInt(); // The size of the next serialized object in its raw bytes form + byte[] objData = new byte[nextObjSize]; + int lengthRead = 0; + while (lengthRead < nextObjSize) { + int bytesRead = ois.read(objData, lengthRead, nextObjSize-lengthRead); + if (bytesRead == -1) { + throw new SQLException("End of stream reached when reading the data for CachedResultSet"); + } + lengthRead += bytesRead; + } + row.putRaw(j+1, objData); + } + resultRows.add(row); + } + return new CachedResultSet(metadata, resultRows); + } catch (ClassNotFoundException e) { + throw new SQLException("ClassNotFoundException while de-serializing resultSet for caching", e); + } catch (IOException e) { + throw new SQLException("IOException while de-serializing resultSet for caching", e); + } + } + + @Override + public boolean next() throws SQLException { + if (rows.isEmpty()) return false; + if (this.currentRow >= rows.size() - 1) { + afterLast(); + return false; + } + currentRow++; + return true; + } + + @Override + public void close() throws SQLException { + currentRow = rows.size() - 1; + closed = true; + } + + @Override + public boolean wasNull() throws SQLException { + if (isClosed()) { + throw new SQLException("This result set is closed"); + } + return this.wasNullFlag; + } + + @Override + public String getString(final int columnIndex) throws SQLException { + Object value = checkAndGetColumnValue(columnIndex); + if (value == null) return null; + return value.toString(); + } + + @Override + public boolean getBoolean(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return false; + if (val instanceof Boolean) return (Boolean) val; + if (val instanceof Number) return ((Number) val).intValue() != 0; + return Boolean.parseBoolean(val.toString()); + } + + @Override + public byte getByte(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Byte) return (Byte) val; + if (val instanceof Number) return ((Number) val).byteValue(); + return Byte.parseByte(val.toString()); + } + + @Override + public short getShort(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Short) return (Short) val; + if (val instanceof Number) return ((Number) val).shortValue(); + return Short.parseShort(val.toString()); + } + + @Override + public int getInt(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Integer) return (Integer) val; + if (val instanceof Number) return ((Number) val).intValue(); + return Integer.parseInt(val.toString()); + } + + @Override + public long getLong(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Long) return (Long) val; + if (val instanceof Number) return ((Number) val).longValue(); + return Long.parseLong(val.toString()); + } + + @Override + public float getFloat(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Float) return (Float) val; + if (val instanceof Number) return ((Number) val).floatValue(); + return Float.parseFloat(val.toString()); + } + + @Override + public double getDouble(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Double) return (Double) val; + if (val instanceof Number) return ((Number) val).doubleValue(); + return Double.parseDouble(val.toString()); + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof BigDecimal) return (BigDecimal) val; + if (val instanceof Number) return new BigDecimal(((Number)val).doubleValue()).setScale(scale, RoundingMode.HALF_UP); + return new BigDecimal(Double.parseDouble(val.toString())).setScale(scale, RoundingMode.HALF_UP); + } + + @Override + public byte[] getBytes(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof byte[]) return (byte[]) val; + // Convert non-byte data to string, then to bytes (standard JDBC behavior) + return val.toString().getBytes(); + } + + private Date convertToDate(Object dateObj, Calendar cal) throws SQLException { + if (dateObj == null) return null; + if (dateObj instanceof Date) return (Date)dateObj; + if (dateObj instanceof Number) return new Date(((Number)dateObj).longValue()); + if (dateObj instanceof LocalDate) { + // Convert the LocalDate for the specified time zone into Date representing + // the same instant of time for the default time zone. + LocalDate localDate = (LocalDate)dateObj; + if (cal == null) return Date.valueOf(localDate); + LocalDateTime localDateTime = localDate.atStartOfDay(); + ZonedDateTime originalZonedDateTime = localDateTime.atZone(cal.getTimeZone().toZoneId()); + ZonedDateTime targetZonedDateTime = originalZonedDateTime.withZoneSameInstant(defaultTimeZoneId); + return Date.valueOf(targetZonedDateTime.toLocalDate()); + } + if (dateObj instanceof Timestamp) { + Timestamp timestamp = (Timestamp) dateObj; + long millis = timestamp.getTime(); + if (cal == null) return new Date(millis); + long adjustedMillis = millis - cal.getTimeZone().getOffset(millis) + + defaultTimeZone.getOffset(millis); + return new Date(adjustedMillis); + } + + // Note: normally the user should properly store the Date object in the DB column and + // the underlying PG/MySQL/MariaDB driver would convert it into Date already in getObject() + // prior to reaching this point in our caching logic. This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Date. We try to do a + // best-effort string parsing into Date with standard format "YYYY-MM-DD". The user is then + // expected to handle parsing failure and implement custom logic to fetch this as String. + return Date.valueOf(dateObj.toString()); + } + + @Override + public Date getDate(final int columnIndex) throws SQLException { + // The value cached is the string representation of epoch time in milliseconds + return convertToDate(checkAndGetColumnValue(columnIndex), null); + } + + private Time convertToTime(Object timeObj, Calendar cal) throws SQLException { + if (timeObj == null) return null; + if (timeObj instanceof Time) return (Time) timeObj; + if (timeObj instanceof Number) return new Time(((Number)timeObj).longValue()); // TODO: test + if (timeObj instanceof LocalTime) { + // Convert the LocalTime for the specified time zone into Time representing + // the same instant of time for the default time zone. + LocalTime localTime = (LocalTime)timeObj; + if (cal == null) return Time.valueOf(localTime); + LocalDateTime localDateTime = LocalDateTime.of(LocalDate.now(), localTime); + ZonedDateTime originalZonedDateTime = localDateTime.atZone(cal.getTimeZone().toZoneId()); + ZonedDateTime targetZonedDateTime = originalZonedDateTime.withZoneSameInstant(defaultTimeZoneId); + return Time.valueOf(targetZonedDateTime.toLocalTime()); + } + if (timeObj instanceof OffsetTime) { + OffsetTime localTime = ((OffsetTime)timeObj).withOffsetSameInstant(OffsetDateTime.now().getOffset()); + return Time.valueOf(localTime.toLocalTime()); + } + if (timeObj instanceof Timestamp) { + Timestamp timestamp = (Timestamp) timeObj; + long millis = timestamp.getTime(); + if (cal == null) return new Time(millis); + long adjustedMillis = millis - cal.getTimeZone().getOffset(millis) + + defaultTimeZone.getOffset(millis); + return new Time(adjustedMillis); + } + + // Note: normally the user should properly store the Time object in the DB column and + // the underlying PG/MySQL/MariaDB driver would convert it into Time already in getObject() + // prior to reaching this point in our caching logic. This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Time. We try to do a + // best-effort string parsing into Time with standard format "HH:MM:SS". The user is then + // expected to handle parsing failure and implement custom logic to fetch this as String. + return Time.valueOf(timeObj.toString()); + } + + @Override + public Time getTime(final int columnIndex) throws SQLException { + return convertToTime(checkAndGetColumnValue(columnIndex), null); + } + + private Timestamp convertToTimestamp(Object timestampObj, Calendar calendar) { + if (timestampObj == null) return null; + if (timestampObj instanceof Timestamp) return (Timestamp) timestampObj; + if (timestampObj instanceof Number) return new Timestamp(((Number)timestampObj).longValue()); + if (timestampObj instanceof LocalDateTime) { + // Convert LocalDateTime based on the specified calendar time zone info into a + // Timestamp based on the JVM's default time zone representing the same instant + long epochTimeInMillis; + LocalDateTime localTime = (LocalDateTime)timestampObj; + if (calendar != null) { + epochTimeInMillis = localTime.atZone(calendar.getTimeZone().toZoneId()).toInstant().toEpochMilli(); + } else { + epochTimeInMillis = localTime.atZone(defaultTimeZoneId).toInstant().toEpochMilli(); + } + return new Timestamp(epochTimeInMillis); + } + if (timestampObj instanceof OffsetDateTime) { + return Timestamp.from(((OffsetDateTime)timestampObj).toInstant()); + } + if (timestampObj instanceof ZonedDateTime) { + return Timestamp.from(((ZonedDateTime)timestampObj).toInstant()); + } + + // Note: normally the user should properly store the Timestamp/DateTime object in the DB column and + // the underlying PG/MySQL/MariaDB driver would convert it into Timestamp already in getObject() + // prior to reaching this point in our caching logic. This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Timestamp. We try to do a + // best-effort string parsing into Timestamp with standard format "YYYY-MM-DD HH:MM:SS". The user is + // then expected to handle parsing failure and implement custom logic to fetch this as String. + return Timestamp.valueOf(timestampObj.toString()); + } + + @Override + public Timestamp getTimestamp(final int columnIndex) throws SQLException { + return convertToTimestamp(checkAndGetColumnValue(columnIndex), null); + } + + @Override + public InputStream getAsciiStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public InputStream getUnicodeStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(final String columnLabel) throws SQLException { + return getString(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public boolean getBoolean(final String columnLabel) throws SQLException { + return getBoolean(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public byte getByte(final String columnLabel) throws SQLException { + return getByte(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public short getShort(final String columnLabel) throws SQLException { + return getShort(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public int getInt(final String columnLabel) throws SQLException { + return getInt(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public long getLong(final String columnLabel) throws SQLException { + return getLong(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public float getFloat(final String columnLabel) throws SQLException { + return getFloat(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public double getDouble(final String columnLabel) throws SQLException { + return getDouble(checkAndGetColumnIndex(columnLabel)); + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { + return getBigDecimal(checkAndGetColumnIndex(columnLabel), scale); + } + + @Override + public byte[] getBytes(final String columnLabel) throws SQLException { + return getBytes(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public Date getDate(final String columnLabel) throws SQLException { + return getDate(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public Time getTime(final String columnLabel) throws SQLException { + return getTime(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public Timestamp getTimestamp(final String columnLabel) throws SQLException { + return getTimestamp(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public InputStream getAsciiStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public InputStream getUnicodeStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return null; + } + + @Override + public void clearWarnings() throws SQLException { + // no-op + } + + @Override + public String getCursorName() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + return metadata; + } + + private void checkCurrentRow() throws SQLException { + if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { + throw new SQLException("The current row index " + this.currentRow + " is out of range."); + } + } + + @Override + public Object getObject(final int columnIndex) throws SQLException { + checkCurrentRow(); + return checkAndGetColumnValue(columnIndex); + } + + @Override + public Object getObject(final String columnLabel) throws SQLException { + checkCurrentRow(); + return checkAndGetColumnValue(checkAndGetColumnIndex(columnLabel)); + } + + // Check the column index passed in is proper, and return the value of the column from the current row + private Object checkAndGetColumnValue(final int columnIndex) throws SQLException { + if (columnIndex == 0 || columnIndex > this.columnNames.size()) throw new SQLException("Column out of bounds"); + final CachedRow row = this.rows.get(this.currentRow); + final Object val = row.get(columnIndex); + this.wasNullFlag = (val == null); + return val; + } + + // Check column label exists and returns the column index corresponding to the column name + private int checkAndGetColumnIndex(final String columnLabel) throws SQLException { + final Integer colIndex = columnNames.get(columnLabel); + if (colIndex == null) throw new SQLException("Column not found: " + columnLabel); + return colIndex; + } + + @Override + public int findColumn(final String columnLabel) throws SQLException { + final Integer colIndex = columnNames.get(columnLabel); + if (colIndex == null) { + throw new SQLException("The column " + columnLabel + " is not found in this ResultSet."); + } + return colIndex; + } + + @Override + public Reader getCharacterStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getCharacterStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getBigDecimal(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof BigDecimal) return (BigDecimal) val; + if (val instanceof Number) return BigDecimal.valueOf(((Number) val).doubleValue()); + return new BigDecimal(val.toString()); + } + + @Override + public BigDecimal getBigDecimal(final String columnLabel) throws SQLException { + return getBigDecimal(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public boolean isBeforeFirst() throws SQLException { + return this.currentRow < 0; + } + + @Override + public boolean isAfterLast() throws SQLException { + return this.currentRow >= this.rows.size(); + } + + @Override + public boolean isFirst() throws SQLException { + return this.currentRow == 0 && !this.rows.isEmpty(); + } + + @Override + public boolean isLast() throws SQLException { + return this.currentRow == (this.rows.size() - 1) && !this.rows.isEmpty(); + } + + @Override + public void beforeFirst() throws SQLException { + this.currentRow = -1; + } + + @Override + public void afterLast() throws SQLException { + this.currentRow = this.rows.size(); + } + + @Override + public boolean first() throws SQLException { + this.currentRow = 0; + return this.currentRow < this.rows.size(); + } + + @Override + public boolean last() throws SQLException { + this.currentRow = this.rows.size() - 1; + return this.currentRow >= 0; + } + + @Override + public int getRow() throws SQLException { + if (this.currentRow >= 0 && this.currentRow < this.rows.size()) { + return this.currentRow + 1; + } + return 0; + } + + @Override + public boolean absolute(final int row) throws SQLException { + if (row == 0) { + this.beforeFirst(); + return false; + } else { + int rowsSize = this.rows.size(); + if (row < 0) { + if (row < -rowsSize) { + this.beforeFirst(); + return false; + } + this.currentRow = rowsSize + row; + } else { // row > 0 + if (row > rowsSize) { + this.afterLast(); + return false; + } + this.currentRow = row - 1; + } + } + return true; + } + + @Override + public boolean relative(final int rows) throws SQLException { + this.currentRow += rows; + if (this.currentRow < 0) { + this.beforeFirst(); + return false; + } else if (this.currentRow >= this.rows.size()) { + this.afterLast(); + return false; + } + return true; + } + + @Override + public boolean previous() throws SQLException { + if (this.currentRow < 1) { + this.beforeFirst(); + return false; + } + this.currentRow--; + return true; + } + + @Override + public void setFetchDirection(final int direction) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchDirection() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setFetchSize(final int rows) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchSize() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getType() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getConcurrency() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowUpdated() throws SQLException { + return false; + } + + @Override + public boolean rowInserted() throws SQLException { + return false; + } + + @Override + public boolean rowDeleted() throws SQLException { + return false; + } + + @Override + public void updateNull(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(final int columnIndex, final boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(final int columnIndex, final byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(final int columnIndex, final short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(final int columnIndex, final int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(final int columnIndex, final long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(final int columnIndex, final float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDouble(final int columnIndex, final double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(final int columnIndex, final BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(final int columnIndex, final String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(final int columnIndex, final byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(final int columnIndex, final Date x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(final int columnIndex, final Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(final int columnIndex, final Timestamp x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final int columnIndex, final Object x, final int scaleOrLength) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final int columnIndex, final Object x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNull(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(final String columnLabel, final boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(final String columnLabel, final byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(final String columnLabel, final short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(final String columnLabel, final int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(final String columnLabel, final long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(final String columnLabel, final float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDouble(final String columnLabel, final double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(final String columnLabel, final BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(final String columnLabel, final String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(final String columnLabel, final byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(final String columnLabel, final Date x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(final String columnLabel, final Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(final String columnLabel, final Timestamp x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x, final int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final String columnLabel, final InputStream x, final int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final String columnLabel, final Reader reader, final int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final String columnLabel, final Object x, final int scaleOrLength) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final String columnLabel, final Object x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void insertRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void deleteRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void refreshRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void cancelRowUpdates() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void moveToInsertRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void moveToCurrentRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Statement getStatement() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Object getObject(final int columnIndex, final Map> map) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Ref getRef(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Blob getBlob(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Clob getClob(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Array getArray(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Object getObject(final String columnLabel, final Map> map) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Ref getRef(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Blob getBlob(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Clob getClob(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Array getArray(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(final int columnIndex, final Calendar cal) throws SQLException { + return convertToDate(checkAndGetColumnValue(columnIndex), cal); + } + + @Override + public Date getDate(final String columnLabel, final Calendar cal) throws SQLException { + return getDate(checkAndGetColumnIndex(columnLabel), cal); + } + + @Override + public Time getTime(final int columnIndex, final Calendar cal) throws SQLException { + return convertToTime(checkAndGetColumnValue(columnIndex), cal); + } + + @Override + public Time getTime(final String columnLabel, final Calendar cal) throws SQLException { + return getTime(checkAndGetColumnIndex(columnLabel), cal); + } + + @Override + public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { + return convertToTimestamp(checkAndGetColumnValue(columnIndex), cal); + } + + @Override + public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { + return getTimestamp(checkAndGetColumnIndex(columnLabel), cal); + } + + @Override + public URL getURL(final int columnIndex) throws SQLException { + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof URL) return (URL) val; + try { + return new URL(val.toString()); + } catch (MalformedURLException e) { + throw new SQLException("Cannot extract url: " + val, e); + } + } + + @Override + public URL getURL(final String columnLabel) throws SQLException { + return getURL(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public void updateRef(final int columnIndex, final Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRef(final String columnLabel, final Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final int columnIndex, final Blob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final String columnLabel, final Blob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final int columnIndex, final Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final String columnLabel, final Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(final int columnIndex, final Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(final String columnLabel, final Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public RowId getRowId(final int columnIndex) throws SQLException { + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof RowId) return (RowId) val; + throw new SQLException("Cannot extract rowId: " + val); + } + + @Override + public RowId getRowId(final String columnLabel) throws SQLException { + return getRowId(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public void updateRowId(final int columnIndex, final RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRowId(final String columnLabel, final RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getHoldability() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClosed() throws SQLException { + return closed; + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNString(final int columnIndex, final String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNString(final String columnLabel, final String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNClob(final int columnIndex, final NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNClob(final String columnLabel, final NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings("checkstyle:MethodName") + public NClob getNClob(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public NClob getNClob(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLXML getSQLXML(final int columnIndex) throws SQLException { + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof SQLXML) return (SQLXML) val; + return new CachedSQLXML(val.toString()); + } + + @Override + public SQLXML getSQLXML(final String columnLabel) throws SQLException { + return getSQLXML(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public void updateSQLXML(final int columnIndex, final SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateSQLXML(final String columnLabel, final SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getNString(final int columnIndex) throws SQLException { + return getString(columnIndex); + } + + @Override + public String getNString(final String columnLabel) throws SQLException { + return getString(columnLabel); + } + + @Override + public Reader getNCharacterStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getNCharacterStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final int columnIndex, final Reader x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final String columnLabel, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final int columnIndex, final InputStream inputStream, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final String columnLabel, final InputStream inputStream, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final int columnIndex, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(final int columnIndex, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final int columnIndex, final Reader x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final String columnLabel, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final int columnIndex, final InputStream inputStream) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final String columnLabel, final InputStream inputStream) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final int columnIndex, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(final int columnIndex, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T getObject(final int columnIndex, final Class type) throws SQLException { + return type.cast(getObject(columnIndex)); + } + + @Override + public T getObject(final String columnLabel, final Class type) throws SQLException { + return type.cast(getObject(columnLabel)); + } + + @Override + public T unwrap(final Class iface) throws SQLException { + if (iface.isAssignableFrom(this.getClass())) { + return iface.cast(this); + } else { + throw new SQLException("Cannot unwrap to " + iface.getName()); + } + } + + @Override + public boolean isWrapperFor(final Class iface) throws SQLException { + return iface != null && iface.isAssignableFrom(this.getClass()); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java new file mode 100644 index 000000000..bf295cb6b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java @@ -0,0 +1,180 @@ +package software.amazon.jdbc.plugin.cache; + +import java.io.Serializable; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; + +class CachedResultSetMetaData implements ResultSetMetaData, Serializable { + protected final Field[] columns; + + protected static class Field implements Serializable { + String catalog; + String className; + String label; + String name; + String typeName; + int type; + int displaySize; + int precision; + String tableName; + int scale; + String schemaName; + boolean isAutoIncrement; + boolean isCaseSensitive; + boolean isCurrency; + boolean isDefinitelyWritable; + int isNullable; + boolean isReadOnly; + boolean isSearchable; + boolean isSigned; + boolean isWritable; + + protected Field(final ResultSetMetaData srcMetadata, int column) throws SQLException { + catalog = srcMetadata.getCatalogName(column); + className = srcMetadata.getColumnClassName(column); + label = srcMetadata.getColumnLabel(column); + name = srcMetadata.getColumnName(column); + typeName = srcMetadata.getColumnTypeName(column); + type = srcMetadata.getColumnType(column); + displaySize = srcMetadata.getColumnDisplaySize(column); + precision = srcMetadata.getPrecision(column); + tableName = srcMetadata.getTableName(column); + scale = srcMetadata.getScale(column); + schemaName = srcMetadata.getSchemaName(column); + isAutoIncrement = srcMetadata.isAutoIncrement(column); + isCaseSensitive = srcMetadata.isCaseSensitive(column); + isCurrency = srcMetadata.isCurrency(column); + isDefinitelyWritable = srcMetadata.isDefinitelyWritable(column); + isNullable = srcMetadata.isNullable(column); + isReadOnly = srcMetadata.isReadOnly(column); + isSearchable = srcMetadata.isSearchable(column); + isSigned = srcMetadata.isSigned(column); + isWritable = srcMetadata.isWritable(column); + } + } + + CachedResultSetMetaData(Field[] columns) { + this.columns = columns; + } + + @Override + public int getColumnCount() throws SQLException { + return columns.length; + } + + private Field getColumns(final int column) throws SQLException { + if (column == 0 || column > columns.length) + throw new SQLException("Wrong column number: " + column); + return columns[column - 1]; + } + + @Override + public boolean isAutoIncrement(int column) throws SQLException { + return getColumns(column).isAutoIncrement; + } + + @Override + public boolean isCaseSensitive(int column) throws SQLException { + return getColumns(column).isCaseSensitive; + } + + @Override + public boolean isSearchable(int column) throws SQLException { + return getColumns(column).isSearchable; + } + + @Override + public boolean isCurrency(int column) throws SQLException { + return getColumns(column).isCurrency; + } + + @Override + public int isNullable(int column) throws SQLException { + return getColumns(column).isNullable; + } + + @Override + public boolean isSigned(int column) throws SQLException { + return getColumns(column).isSigned; + } + + @Override + public int getColumnDisplaySize(int column) throws SQLException { + return getColumns(column).displaySize; + } + + @Override + public String getColumnLabel(int column) throws SQLException { + return getColumns(column).label; + } + + @Override + public String getColumnName(int column) throws SQLException { + return getColumns(column).name; + } + + @Override + public String getSchemaName(int column) throws SQLException { + return getColumns(column).schemaName; + } + + @Override + public int getPrecision(int column) throws SQLException { + return getColumns(column).precision; + } + + @Override + public int getScale(int column) throws SQLException { + return getColumns(column).scale; + } + + @Override + public String getTableName(int column) throws SQLException { + return getColumns(column).tableName; + } + + @Override + public String getCatalogName(int column) throws SQLException { + return getColumns(column).catalog; + } + + @Override + public int getColumnType(int column) throws SQLException { + return getColumns(column).type; + } + + @Override + public String getColumnTypeName(int column) throws SQLException { + return getColumns(column).typeName; + } + + @Override + public boolean isReadOnly(int column) throws SQLException { + return getColumns(column).isReadOnly; + } + + @Override + public boolean isWritable(int column) throws SQLException { + return getColumns(column).isWritable; + } + + @Override + public boolean isDefinitelyWritable(int column) throws SQLException { + return getColumns(column).isDefinitelyWritable; + } + + @Override + public String getColumnClassName(int column) throws SQLException { + return getColumns(column).className; + } + + @Override + public T unwrap(Class iface) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + throw new UnsupportedOperationException(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java new file mode 100644 index 000000000..a49240172 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java @@ -0,0 +1,118 @@ +package software.amazon.jdbc.plugin.cache; + +import org.xml.sax.InputSource; +import org.xml.sax.XMLReader; +import org.xml.sax.helpers.XMLReaderFactory; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Result; +import javax.xml.transform.Source; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stax.StAXSource; +import javax.xml.transform.stream.StreamSource; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Reader; +import java.io.Serializable; +import java.io.StringReader; +import java.io.Writer; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.sql.SQLXML; + +public class CachedSQLXML implements SQLXML, Serializable { + private boolean freed; + private String data; + + public CachedSQLXML(String data) { + this.data = data; + this.freed = false; + } + + @Override + public void free() throws SQLException { + if (this.freed) return; + this.data = null; + this.freed = true; + } + + private void checkFreed() throws SQLException { + if (this.freed) { + throw new SQLException("This SQLXML object has already been freed."); + } + } + + @Override + public InputStream getBinaryStream() throws SQLException { + checkFreed(); + if (this.data == null) return null; + return new ByteArrayInputStream(this.data.getBytes(StandardCharsets.UTF_8)); + } + + @Override + public OutputStream setBinaryStream() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getCharacterStream() throws SQLException { + checkFreed(); + if (this.data == null) return null; + return new StringReader(this.data); + } + + @Override + public Writer setCharacterStream() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getString() throws SQLException { + checkFreed(); + return this.data; + } + + @Override + public void setString(String value) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T getSource(Class sourceClass) throws SQLException { + checkFreed(); + if (this.data == null) return null; + + try { + if (sourceClass == null || DOMSource.class.equals(sourceClass)) { + DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder(); + return (T) new DOMSource(builder.parse(new InputSource(new StringReader(data)))); + } + + if (SAXSource.class.equals(sourceClass)) { + XMLReader reader = XMLReaderFactory.createXMLReader(); + return sourceClass.cast(new SAXSource(reader, new InputSource(new StringReader(data)))); + } + + if (StreamSource.class.equals(sourceClass)) { + return sourceClass.cast(new StreamSource(new StringReader(data))); + } + + if (StAXSource.class.equals(sourceClass)) { + XMLStreamReader xsr = XMLInputFactory.newFactory().createXMLStreamReader(new StringReader(data)); + return sourceClass.cast(new StAXSource(xsr)); + } + throw new SQLException("Unsupported source class for XML data: " + sourceClass.getName()); + } catch (Exception e) { + throw new SQLException("Unable to decode XML data.", e); + } + } + + @Override + public T setResult(Class resultClass) throws SQLException { + throw new UnsupportedOperationException(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java new file mode 100644 index 000000000..ac3d505e1 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java @@ -0,0 +1,76 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; + +public final class CachedSupplier { + + private CachedSupplier() { + throw new UnsupportedOperationException("Utility class should not be instantiated"); + } + + public static Supplier memoizeWithExpiration( + Supplier delegate, long duration, TimeUnit unit) { + + Objects.requireNonNull(delegate, "delegate Supplier must not be null"); + Objects.requireNonNull(unit, "TimeUnit must not be null"); + if (duration <= 0) { + throw new IllegalArgumentException("duration must be > 0"); + } + + return new ExpiringMemoizingSupplier<>(delegate, duration, unit); + } + + private static final class ExpiringMemoizingSupplier implements Supplier { + + private final Supplier delegate; + private final long durationNanos; + private final ReentrantLock lock = new ReentrantLock(); + + private volatile T value; + private volatile long expirationNanos; // 0 means not yet initialized + + ExpiringMemoizingSupplier(Supplier delegate, long duration, TimeUnit unit) { + this.delegate = delegate; + this.durationNanos = unit.toNanos(duration); + } + + @Override + public T get() { + long now = System.nanoTime(); + + // Check if value is expired or uninitialized + if (expirationNanos == 0 || now - expirationNanos >= 0) { + lock.lock(); + try { + if (expirationNanos == 0 || now - expirationNanos >= 0) { + value = delegate.get(); + long next = now + durationNanos; + expirationNanos = (next == 0) ? 1 : next; // avoid 0 sentinel + } + } finally { + lock.unlock(); + } + } + return value; + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java new file mode 100644 index 000000000..3d4743fea --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java @@ -0,0 +1,167 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.logging.Logger; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.AbstractConnectionPlugin; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; + +public class DataLocalCacheConnectionPlugin extends AbstractConnectionPlugin { + + private static final Logger LOGGER = Logger.getLogger(DataLocalCacheConnectionPlugin.class.getName()); + + private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( + Arrays.asList( + JdbcMethod.STATEMENT_EXECUTEQUERY.methodName, + JdbcMethod.STATEMENT_EXECUTE.methodName, + JdbcMethod.PREPAREDSTATEMENT_EXECUTE.methodName, + JdbcMethod.PREPAREDSTATEMENT_EXECUTEQUERY.methodName, + JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName, + JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName + ))); + + public static final AwsWrapperProperty DATA_CACHE_TRIGGER_CONDITION = new AwsWrapperProperty( + "dataCacheTriggerCondition", "false", + "A regular expression that, if it's matched, allows the plugin to cache SQL results."); + + protected static final Map dataCache = new ConcurrentHashMap<>(); + + protected final String dataCacheTriggerCondition; + + static { + PropertyDefinition.registerPluginProperties(DataLocalCacheConnectionPlugin.class); + } + + private final TelemetryFactory telemetryFactory; + private final TelemetryCounter hitCounter; + private final TelemetryCounter missCounter; + private final TelemetryCounter totalCallsCounter; + private final TelemetryGauge cacheSizeGauge; + + public DataLocalCacheConnectionPlugin(final PluginService pluginService, final Properties props) { + this.telemetryFactory = pluginService.getTelemetryFactory(); + this.dataCacheTriggerCondition = DATA_CACHE_TRIGGER_CONDITION.getString(props); + + this.hitCounter = telemetryFactory.createCounter("dataCache.cache.hit"); + this.missCounter = telemetryFactory.createCounter("dataCache.cache.miss"); + this.totalCallsCounter = telemetryFactory.createCounter("dataCache.cache.totalCalls"); + this.cacheSizeGauge = telemetryFactory.createGauge("dataCache.cache.size", () -> (long) dataCache.size()); + } + + public static void clearCache() { + dataCache.clear(); + } + + @Override + public Set getSubscribedMethods() { + return subscribedMethods; + } + + @Override + public T execute( + final Class resultClass, + final Class exceptionClass, + final Object methodInvokeOn, + final String methodName, + final JdbcCallable jdbcMethodFunc, + final Object[] jdbcMethodArgs) + throws E { + + if (StringUtils.isNullOrEmpty(this.dataCacheTriggerCondition) || resultClass != ResultSet.class) { + return jdbcMethodFunc.call(); + } + + if (this.totalCallsCounter != null) { + this.totalCallsCounter.inc(); + } + + ResultSet result; + boolean needToCache = false; + final String sql = getQuery(jdbcMethodArgs); + + if (!StringUtils.isNullOrEmpty(sql) && sql.matches(this.dataCacheTriggerCondition)) { + result = dataCache.get(sql); + if (result == null) { + needToCache = true; + if (this.missCounter != null) { + this.missCounter.inc(); + } + LOGGER.finest( + () -> Messages.get( + "DataLocalCacheConnectionPlugin.queryResultsCached", + new Object[]{methodName, sql})); + } else { + if (this.hitCounter != null) { + this.hitCounter.inc(); + } + try { + result.beforeFirst(); + } catch (final SQLException ex) { + if (exceptionClass.isAssignableFrom(ex.getClass())) { + throw exceptionClass.cast(ex); + } + throw new RuntimeException(ex); + } + return resultClass.cast(result); + } + } + + result = (ResultSet) jdbcMethodFunc.call(); + + if (needToCache) { + final ResultSet cachedResultSet; + try { + cachedResultSet = new CachedResultSet(result); + dataCache.put(sql, cachedResultSet); + cachedResultSet.beforeFirst(); + return resultClass.cast(cachedResultSet); + } catch (final SQLException ex) { + // ignore exception + } + } + + return resultClass.cast(result); + } + + protected String getQuery(final Object[] jdbcMethodArgs) { + + // Get query from method argument + if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { + return jdbcMethodArgs[0].toString(); + } + return null; + } + +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java new file mode 100644 index 000000000..c28e03e89 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java @@ -0,0 +1,30 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import java.util.Properties; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; + +public class DataLocalCacheConnectionPluginFactory implements ConnectionPluginFactory { + + @Override + public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { + return new DataLocalCacheConnectionPlugin(pluginService, props); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java new file mode 100644 index 000000000..476c129d5 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -0,0 +1,365 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.AbstractConnectionPlugin; +import software.amazon.jdbc.states.SessionStateService; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + +public class DataRemoteCachePlugin extends AbstractConnectionPlugin { + private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); + private static final String QUERY_HINT_START_PATTERN = "/*+"; + private static final String QUERY_HINT_END_PATTERN = "*/"; + private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM("; + private static final String TELEMETRY_CACHE_LOOKUP = "jdbc-cache-lookup"; + private static final String TELEMETRY_DATABASE_QUERY = "jdbc-database-query"; + private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( + Arrays.asList(JdbcMethod.STATEMENT_EXECUTEQUERY.methodName, + JdbcMethod.STATEMENT_EXECUTE.methodName, + JdbcMethod.PREPAREDSTATEMENT_EXECUTE.methodName, + JdbcMethod.PREPAREDSTATEMENT_EXECUTEQUERY.methodName, + JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName, + JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName))); + + private int maxCacheableQuerySize; + private PluginService pluginService; + private TelemetryFactory telemetryFactory; + private TelemetryCounter cacheHitCounter; + private TelemetryCounter cacheMissCounter; + private TelemetryCounter totalQueryCounter; + private TelemetryCounter malformedHintCounter; + private TelemetryCounter cacheBypassCounter; + private CacheConnection cacheConnection; + private String dbUserName; + + private static final AwsWrapperProperty CACHE_MAX_QUERY_SIZE = + new AwsWrapperProperty( + "cacheMaxQuerySize", + "16384", + "The max query size for remote caching"); + + static { + PropertyDefinition.registerPluginProperties(DataRemoteCachePlugin.class); + } + + public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) { + try { + Class.forName("io.lettuce.core.RedisClient"); // Lettuce dependency + Class.forName("org.apache.commons.pool2.impl.GenericObjectPool"); // Object pool dependency + } catch (final ClassNotFoundException e) { + throw new RuntimeException(Messages.get("DataRemoteCachePlugin.notInClassPath", new Object[] {e.getMessage()})); + } + this.pluginService = pluginService; + this.telemetryFactory = pluginService.getTelemetryFactory(); + this.cacheHitCounter = telemetryFactory.createCounter("JdbcCachedQueryCount"); + this.cacheMissCounter = telemetryFactory.createCounter("JdbcCacheMissCount"); + this.totalQueryCounter = telemetryFactory.createCounter("JdbcCacheTotalQueryCount"); + this.malformedHintCounter = telemetryFactory.createCounter("JdbcCacheMalformedQueryHint"); + this.cacheBypassCounter = telemetryFactory.createCounter("JdbcCacheBypassCount"); + this.maxCacheableQuerySize = CACHE_MAX_QUERY_SIZE.getInteger(properties); + this.cacheConnection = new CacheConnection(properties); + this.dbUserName = PropertyDefinition.USER.getString(properties); + } + + // Used for unit testing purposes only + protected void setCacheConnection(CacheConnection conn) { + this.cacheConnection = conn; + } + + @Override + public Set getSubscribedMethods() { + return subscribedMethods; + } + + private String getCacheQueryKey(String query) { + // Check some basic session states. The important ones for caching include (but not limited to): + // schema name, username which can affect the query result from the DB in addition to the query string + try { + Connection currentConn = pluginService.getCurrentConnection(); + DatabaseMetaData metadata = currentConn.getMetaData(); + // Fetch and record the schema name if the session state doesn't currently have it + SessionStateService sessionStateService = pluginService.getSessionStateService(); + String catalog = sessionStateService.getCatalog().orElse(null); + String schema = sessionStateService.getSchema().orElse(null); + if (catalog == null && schema == null) { + // Fetch the current schema name and store it in sessionStateService + catalog = currentConn.getCatalog(); + schema = currentConn.getSchema(); + if (catalog != null) sessionStateService.setCatalog(catalog); + if (schema != null) sessionStateService.setSchema(schema); + } + + if (dbUserName == null) { + // For MySQL, metadata username is actually @. We just need the part before '@'. + dbUserName = metadata.getUserName(); + int nameIndexEnd = dbUserName.indexOf('@'); + if (nameIndexEnd > 0) { + dbUserName = dbUserName.substring(0, nameIndexEnd); + } + } + LOGGER.finest("DB driver protocol " + pluginService.getDriverProtocol() + + ", database product: " + metadata.getDatabaseProductName() + " " + metadata.getDatabaseProductVersion() + + ", catalog: " + catalog + ", schema: " + schema + ", user: " + dbUserName + + ", driver: " + metadata.getDriverName() + " " + metadata.getDriverVersion()); + // The cache key contains the schema name, user name, and the query string + String[] words = {catalog, schema, dbUserName, query}; + return String.join("_", words); + } catch (SQLException e) { + LOGGER.warning("Error getting session state: " + e.getMessage()); + return null; + } + } + + private ResultSet fetchResultSetFromCache(String queryStr) { + if (cacheConnection == null) return null; + + String cacheQueryKey = getCacheQueryKey(queryStr); + if (cacheQueryKey == null) return null; // Treat this as a cache miss + byte[] cachedResult = cacheConnection.readFromCache(cacheQueryKey); + if (cachedResult == null) return null; + // Convert result into ResultSet + try { + return CachedResultSet.deserializeFromByteArray(cachedResult); + } catch (Exception e) { + LOGGER.warning("Error de-serializing cached result: " + e.getMessage()); + return null; // Treat this as a cache miss + } + } + + /** + * Cache the given ResultSet object. + * The ResultSet object passed in would be consumed to create a CacheResultSet object. It is returned + * for consumer consumption. + */ + private ResultSet cacheResultSet(String queryStr, ResultSet rs, int expiry) throws SQLException { + // Write the resultSet into the cache as a single key + String cacheQueryKey = getCacheQueryKey(queryStr); + if (cacheQueryKey == null) return rs; // Treat this condition as un-cacheable + CachedResultSet crs = new CachedResultSet(rs); + byte[] jsonString = crs.serializeIntoByteArray(); + cacheConnection.writeToCache(cacheQueryKey, jsonString, expiry); + crs.beforeFirst(); + return crs; + } + + /** + * Determine the TTL based on an input query + * @param queryHint string. e.g. "CACHE_PARAM(ttl=100s, key=custom)" + * @return TTL in seconds to cache the query. + * null if the query is not cacheable. + */ + protected Integer getTtlForQuery(String queryHint) { + // Empty query is not cacheable + if (StringUtils.isNullOrEmpty(queryHint)) return null; + // Find CACHE_PARAM anywhere in the hint string (case insensitive) + String upperHint = queryHint.toUpperCase(); + int cacheParamStart = upperHint.indexOf(CACHE_PARAM_PATTERN); + if (cacheParamStart == -1) return null; + + // Find the matching closing parenthesis + int paramsStart = cacheParamStart + CACHE_PARAM_PATTERN.length(); + int paramsEnd = upperHint.indexOf(")", paramsStart); + if (paramsEnd == -1) return null; + + // Extract parameters between parentheses + String cacheParams = upperHint.substring(paramsStart, paramsEnd).trim(); + // Empty parameters + if (StringUtils.isNullOrEmpty(cacheParams)) { + LOGGER.warning("Empty CACHE_PARAM parameters"); + incrCounter(malformedHintCounter); + return null; + } + + // Parse comma-separated parameters + String[] params = cacheParams.split(","); + Integer ttlValue = null; + + for (String param : params) { + String[] keyValue = param.trim().split("="); + if (keyValue.length != 2) { + LOGGER.warning("Invalid caching parameter format: " + param); + incrCounter(malformedHintCounter); + return null; + } + String key = keyValue[0].trim(); + String value = keyValue[1].trim(); + + if ("TTL".equals(key)) { + if (!value.endsWith("S")) { + LOGGER.warning("TTL must end with 's': " + value); + incrCounter(malformedHintCounter); + return null; + } else{ + // Parse TTL value (e.g., "300s") + try { + ttlValue = Integer.parseInt(value.substring(0, value.length() - 1)); + // treat negative and 0 ttls as not cacheable + if (ttlValue <= 0) { + return null; + } + } catch (NumberFormatException e) { + LOGGER.warning(String.format("Invalid TTL format of %s for query %s", value, queryHint)); + incrCounter(malformedHintCounter); + return null; + } + } + } + } + return ttlValue; + } + + @Override + public T execute( + final Class resultClass, + final Class exceptionClass, + final Object methodInvokeOn, + final String methodName, + final JdbcCallable jdbcMethodFunc, + final Object[] jdbcMethodArgs) + throws E { + if (resultClass != ResultSet.class) { + return jdbcMethodFunc.call(); + } + + incrCounter(totalQueryCounter); + + ResultSet result; + boolean needToCache = false; + final String sql = getQuery(jdbcMethodArgs); + + TelemetryContext cacheContext = null; + TelemetryContext dbContext = null; + // If the query is cacheable, we try to fetch the query result from the cache. + boolean isInTransaction = pluginService.isInTransaction(); + // Get the query hint part in front of the query itself + String mainQuery = sql; // The main part of the query with the query hint prefix trimmed + int endOfQueryHint = 0; + Integer configuredQueryTtl = null; + // Queries longer than 16KB is not cacheable + if ((sql.length() < maxCacheableQuerySize) && sql.startsWith(QUERY_HINT_START_PATTERN)) { + endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN); + if (endOfQueryHint > 0) { + configuredQueryTtl = getTtlForQuery(sql.substring(QUERY_HINT_START_PATTERN.length(), endOfQueryHint).trim()); + mainQuery = sql.substring(endOfQueryHint + QUERY_HINT_END_PATTERN.length()).trim(); + } + } + + // Query result can be served from the cache if it has a configured TTL value, and it is + // not executed in a transaction as a transaction typically need to return consistent results. + if (!isInTransaction && (configuredQueryTtl != null)) { + cacheContext = telemetryFactory.openTelemetryContext( + TELEMETRY_CACHE_LOOKUP, TelemetryTraceLevel.TOP_LEVEL); + Exception cacheException = null; + try{ + result = fetchResultSetFromCache(mainQuery); + if (result == null) { + // Cache miss. Need to fetch result from the database + needToCache = true; + incrCounter(cacheMissCounter); + LOGGER.finest("Got a cache miss for SQL: " + sql); + } else { + LOGGER.finest("Got a cache hit for SQL: " + sql); + // Cache hit. Return the cached result + incrCounter(cacheHitCounter); + try { + result.beforeFirst(); + } catch (final SQLException ex) { + cacheException = ex; + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); + } + return resultClass.cast(result); + } + } finally { + if (cacheContext != null) { + if (cacheException != null) { + cacheContext.setSuccess(false); + cacheContext.setException(cacheException); + cacheContext.closeContext(); + } else if (!needToCache) { // Cache hit + cacheContext.setSuccess(true); + cacheContext.closeContext(); + } else { // Cache miss - leave context open + cacheContext.setSuccess(false); + } + } + } + } else { + incrCounter(cacheBypassCounter); + } + + dbContext = telemetryFactory.openTelemetryContext( + TELEMETRY_DATABASE_QUERY, TelemetryTraceLevel.TOP_LEVEL); + + try { + result = (ResultSet) jdbcMethodFunc.call(); + } finally { + if (dbContext != null) dbContext.closeContext(); + if (cacheContext != null) cacheContext.closeContext(); + } + + // We need to cache the query result if we got a cache miss for the query result, + // or the query is cacheable and executed inside a transaction. + if (isInTransaction && (configuredQueryTtl != null)) { + needToCache = true; + } + if (needToCache) { + try { + result = cacheResultSet(mainQuery, result, configuredQueryTtl); + } catch (final SQLException ex) { + // Log and re-throw exception + LOGGER.warning("Encountered SQLException when caching query results: " + ex.getMessage()); + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); + } + } + + return resultClass.cast(result); + } + + private void incrCounter(TelemetryCounter counter) { + if (counter == null) return; + counter.inc(); + } + + protected String getQuery(final Object[] jdbcMethodArgs) { + // Get query from method argument + if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { + return jdbcMethodArgs[0].toString().trim(); + } + return null; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java similarity index 83% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java index 555ed55cf..fb15d69c5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java @@ -14,17 +14,17 @@ * limitations under the License. */ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import java.util.Properties; import software.amazon.jdbc.ConnectionPlugin; import software.amazon.jdbc.ConnectionPluginFactory; import software.amazon.jdbc.PluginService; -public class DataCacheConnectionPluginFactory implements ConnectionPluginFactory { +public class DataRemoteCachePluginFactory implements ConnectionPluginFactory { @Override public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { - return new DataCacheConnectionPlugin(pluginService, props); + return new DataRemoteCachePlugin(pluginService, props); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtility.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtility.java new file mode 100644 index 000000000..12d96825a --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtility.java @@ -0,0 +1,125 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.iam; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.util.Objects; +import java.util.logging.Logger; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.CredentialUtils; +import software.amazon.awssdk.auth.signer.Aws4Signer; +import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.StringUtils; + +public class ElastiCacheIamTokenUtility implements IamTokenUtility { + + private static final Logger LOGGER = Logger.getLogger(ElastiCacheIamTokenUtility.class.getName()); + private static final String PARAM_ACTION = "Action"; + private static final String PARAM_USER = "User"; + private static final String ACTION_NAME = "connect"; + private static final String PARAM_RESOURCE_TYPE = "ResourceType"; + private static final String RESOURCE_TYPE_SERVERLESS_CACHE = "ServerlessCache"; + private static final String SERVICE_NAME = "elasticache"; + private static final String PROTOCOL = "http"; + private static final Duration EXPIRATION_DURATION = Duration.ofSeconds(15 * 60 - 30); + public static final String SERVERLESS_CACHE_IDENTIFIER = ".serverless."; + + private final Clock clock; + private String cacheName = null; + private final Aws4Signer signer; + + public ElastiCacheIamTokenUtility(String cacheName) { + this.cacheName = Objects.requireNonNull(cacheName, "cacheName cannot be null"); + this.clock = Clock.systemUTC(); + this.signer = Aws4Signer.create(); + } + + // For testing only + public ElastiCacheIamTokenUtility(String cacheName, Instant fixedInstant, Aws4Signer signer) { + this.cacheName = Objects.requireNonNull(cacheName, "cacheName cannot be null"); + this.clock = Clock.fixed(fixedInstant, ZoneId.of("UTC")); + this.signer = signer; + } + + @Override + public String generateAuthenticationToken( + final @NonNull AwsCredentialsProvider credentialsProvider, + final @NonNull Region region, + final @NonNull String hostname, + final int port, + final @NonNull String username) { + + boolean isServerless = isServerlessCache(hostname); + if (this.cacheName == null) { + throw new IllegalArgumentException("Cache name cannot be null for cache with IAM authentication"); + } + + SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol(PROTOCOL) // ElastiCache uses http, not https + .host(this.cacheName) + .encodedPath("/") + .putRawQueryParameter(PARAM_ACTION, ACTION_NAME) + .putRawQueryParameter(PARAM_USER, username); + + if (isServerless) { + requestBuilder.putRawQueryParameter(PARAM_RESOURCE_TYPE, RESOURCE_TYPE_SERVERLESS_CACHE); + } + + final SdkHttpFullRequest httpRequest = requestBuilder.build(); + + final Instant expirationTime = Instant.now(this.clock).plus(EXPIRATION_DURATION); + + final AwsCredentials credentials = CredentialUtils.toCredentials( + CompletableFutureUtils.joinLikeSync(credentialsProvider.resolveIdentity())); + + final Aws4PresignerParams presignRequest = Aws4PresignerParams.builder() + .signingClockOverride(this.clock) + .expirationTime(expirationTime) + .awsCredentials(credentials) + .signingName(SERVICE_NAME) + .signingRegion(region) + .build(); + + final SdkHttpFullRequest fullRequest = this.signer.presign(httpRequest, presignRequest); + final String signedUrl = fullRequest.getUri().toString(); + + // Format should be: + // Regular: /?Action=connect&User=&X-Amz-Security-Token=...&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=...&X-Amz-SignedHeaders=host&X-Amz-Expires=870&X-Amz-Credential=...&X-Amz-Signature=... + // Serverless: /?Action=connect&User=&ResourceType=ServerlessCache&X-Amz-Security-Token=...&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=...&X-Amz-SignedHeaders=host&X-Amz-Expires=870&X-Amz-Credential=...&X-Amz-Signature=... + // Note: This must be the real ElastiCache hostname, not proxy or tunnels + final String result = StringUtils.replacePrefixIgnoreCase(signedUrl, "http://", ""); + LOGGER.finest(() -> "Generated ElastiCache authentication token with expiration of " + expirationTime); + return result; + } + + private boolean isServerlessCache(String hostname) { + if (hostname == null) { + throw new IllegalArgumentException("Hostname cannot be null"); + } + return hostname.contains(SERVERLESS_CACHE_IDENTIFIER); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java index 0037bdbd9..837a90d5c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java @@ -573,7 +573,7 @@ public static Connection getConnectionFromSqlObject(final Object obj) { } } catch (final SQLException | UnsupportedOperationException e) { // Do nothing. The UnsupportedOperationException comes from ResultSets returned by - // DataCacheConnectionPlugin and will be triggered when getStatement is called. + // DataLocalCacheConnectionPlugin and will be triggered when getStatement is called. } return null; diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index c5ea7c275..94afe5a1f 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -126,8 +126,12 @@ CustomEndpointPlugin.waitingForCustomEndpointInfo=Custom endpoint info for ''{0} CustomEndpointPluginFactory.awsSdkNotInClasspath=Required dependency 'AWS Java SDK RDS v2.x' is not on the classpath. -DataCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1} +DataLocalCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1} +# Data Remote Cache Plugin +DataRemoteCachePlugin.notInClassPath=Required dependency for DataRemoteCachePlugin is not on the classpath: ''{0}'' + +# Default Connection Plugin DefaultConnectionPlugin.executingMethod=Executing method: ''{0}'' DefaultConnectionPlugin.noHostsAvailable=The default connection plugin received an empty host list from the plugin service. DefaultConnectionPlugin.unknownRoleRequested=A HostSpec with a role of HostRole.UNKNOWN was requested via getHostSpecByStrategy. The requested role must be either HostRole.WRITER or HostRole.READER diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index 3d26591e6..3e2f17251 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -51,6 +51,7 @@ dependencies { testImplementation("org.testcontainers:mariadb:1.20.4") testImplementation("org.testcontainers:junit-jupiter:1.20.4") testImplementation("org.testcontainers:toxiproxy:1.20.4") + testImplementation("org.apache.commons:commons-pool2:2.11.1") testImplementation("org.apache.poi:poi-ooxml:5.3.0") testImplementation("org.slf4j:slf4j-simple:2.0.13") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.17.1") @@ -58,6 +59,7 @@ dependencies { testImplementation("io.opentelemetry:opentelemetry-sdk:1.42.1") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.43.0") testImplementation("io.opentelemetry:opentelemetry-exporter-otlp:1.44.1") + testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 testImplementation("jakarta.persistence:jakarta.persistence-api:2.2.3") diff --git a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java index dc8f6afa3..d6b77fee5 100644 --- a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java +++ b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java @@ -39,8 +39,8 @@ import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin.CachedResultSet; +import software.amazon.jdbc.plugin.cache.CachedResultSet; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; @TestMethodOrder(MethodOrderer.MethodName.class) @ExtendWith(TestDriverProvider.class) @@ -58,20 +58,20 @@ public class DataCachePluginTests { @BeforeEach public void beforeEach() { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); } @TestTemplate public void testQueryCacheable() throws SQLException { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); final Properties props = ConnectionStringHelper.getDefaultProperties(); PropertyDefinition.CONNECT_TIMEOUT.set(props, "30000"); PropertyDefinition.SOCKET_TIMEOUT.set(props, "30000"); props.setProperty(PropertyDefinition.PLUGINS.name, "dataCache"); - props.setProperty(DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*testTable.*"); + props.setProperty(DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*testTable.*"); Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(), props); @@ -174,14 +174,14 @@ private void printTable() { @TestTemplate public void testQueryNotCacheable() throws SQLException { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); final Properties props = ConnectionStringHelper.getDefaultProperties(); PropertyDefinition.CONNECT_TIMEOUT.set(props, "30000"); PropertyDefinition.SOCKET_TIMEOUT.set(props, "30000"); props.setProperty(PropertyDefinition.PLUGINS.name, "dataCache"); props.setProperty( - DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*WRONG_EXPRESSION.*"); + DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*WRONG_EXPRESSION.*"); Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(), props); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java new file mode 100644 index 000000000..014fab5c4 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java @@ -0,0 +1,466 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import io.lettuce.core.RedisFuture; +import io.lettuce.core.RedisURI; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.async.RedisAsyncCommands; +import io.lettuce.core.api.sync.RedisCommands; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.function.BiConsumer; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +public class CacheConnectionTest { + @Mock GenericObjectPool> mockReadConnPool; + @Mock GenericObjectPool> mockWriteConnPool; + @Mock StatefulRedisConnection mockConnection; + @Mock RedisCommands mockSyncCommands; + @Mock RedisAsyncCommands mockAsyncCommands; + @Mock RedisFuture mockCacheResult; + private AutoCloseable closeable; + private CacheConnection cacheConnection; + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + Properties props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + cacheConnection = new CacheConnection(props); + cacheConnection.setConnectionPools(mockReadConnPool, mockWriteConnPool); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void testIamAuth_PropertyExtraction() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Verify all IAM fields are set correctly + assertEquals("us-west-2", getField(connection, "cacheIamRegion")); + assertEquals("myuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + } + + @Test + void testIamAuth_PropertyExtractionTraditional() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cachePassword", "password"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Verify all IAM fields are set correctly + assertEquals("myuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("password", getField(connection, "cachePassword")); + } + + @Test + void testIamAuthEnabled_WhenRegionProvided() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Use reflection to verify iamAuthEnabled is true + Field field = CacheConnection.class.getDeclaredField("iamAuthEnabled"); + field.setAccessible(true); + assertTrue((boolean) field.get(connection)); + // Verify all IAM fields are set correctly + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + } + + @Test + void testConstructor_IamAuthEnabled_MissingCacheName() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); + props.setProperty("cacheUsername", "myuser"); + // Missing cacheName property + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new CacheConnection(props) + ); + + assertTrue(exception.getMessage().contains("IAM authentication requires cache name, username, region, and hostname")); + } + + @Test + void testTraditionalAuth_WhenNoIamRegion() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheUsername", "user"); + props.setProperty("cachePassword", "pass"); + + CacheConnection connection = new CacheConnection(props); + + assertFalse((boolean) getField(connection, "iamAuthEnabled")); + assertNull(getField(connection, "credentialsProvider")); + assertEquals("user", getField(connection, "cacheUsername")); + assertEquals("pass", getField(connection, "cachePassword")); + } + + @Test + void testConstructor_NoRwAddress() { + Properties props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRo", "localhost:6379"); + + assertThrows(IllegalArgumentException.class, () -> new CacheConnection(props)); + } + + @Test + void testConstructor_IamAuthEnabled_MissingCacheUsername() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + + assertThrows(IllegalArgumentException.class, () -> new CacheConnection(props)); + } + + @Test + void testConstructor_ConflictingAuthenticationMethods() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); // IAM auth + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cachePassword", "mypassword"); // Traditional auth + props.setProperty("cacheName", "my-cache"); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new CacheConnection(props) + ); + + assertTrue(exception.getMessage().contains("Cannot specify both IAM authentication")); + } + + @Test + void testAwsCredentialsProvider_WithProfile() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + props.setProperty("awsProfile", "test-profile"); + + CacheConnection connection = new CacheConnection(props); + + // Verify the awsProfileProperties field contains the correct profile + Properties awsProfileProps = (Properties) getField(connection, "awsProfileProperties"); + assertEquals("test-profile", awsProfileProps.getProperty("awsProfile")); + + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("test.cache.amazonaws.com:6379", getField(connection, "cacheRwServerAddr")); + } + + @Test + void testAwsCredentialsProvider_WithoutProfile() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + // No awsProfile property + + CacheConnection connection = new CacheConnection(props); + + // Verify awsProfileProperties is not empty when no profile specified + Properties awsProfileProps = (Properties) getField(connection, "awsProfileProperties"); + assertNull(awsProfileProps); + + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("test.cache.amazonaws.com:6379", getField(connection, "cacheRwServerAddr")); + } + + @Test + void testBuildRedisURI_IamAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "test-cache"); + + try (MockedConstruction mockedTokenUtility = mockConstruction(ElastiCacheIamTokenUtility.class)) { + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("test-cache.cache.amazonaws.com", 6379); + + // Verify URI properties + assertNotNull(uri); + assertEquals("test-cache.cache.amazonaws.com", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertTrue(uri.isSsl()); + assertNotNull(uri.getCredentialsProvider()); + + // Trigger the credentials provider to create the token utility + uri.getCredentialsProvider().resolveCredentials().block(); + + // Verify URI properties + assertNotNull(uri); + assertEquals("test-cache.cache.amazonaws.com", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertTrue(uri.isSsl()); + assertNotNull(uri.getCredentialsProvider()); // IAM credentials provider set + + // Verify ElastiCacheIamTokenUtility was constructed with correct parameters + // Verify token utility construction + assertEquals(1, mockedTokenUtility.constructed().size()); + ElastiCacheIamTokenUtility tokenUtility = mockedTokenUtility.constructed().get(0); + verify(tokenUtility).generateAuthenticationToken( + any(AwsCredentialsProvider.class), + eq(Region.US_EAST_1), + eq("test-cache.cache.amazonaws.com"), + eq(6379), + eq("testuser") + ); + } + } + + @Test + void testBuildRedisURI_TraditionalAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheUsername", "user"); + props.setProperty("cachePassword", "pass"); + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("localhost", 6379); + + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertEquals("user", uri.getUsername()); + assertEquals("pass", new String(uri.getPassword())); + } + + @Test + void testBuildRedisURI_NoAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("localhost", 6379); + + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertNull(uri.getUsername()); + assertNull(uri.getPassword()); + } + + @Test + void test_writeToCache() throws Exception { + String key = "myQueryKey"; + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenReturn(mockCacheResult); + when(mockCacheResult.whenComplete(any(BiConsumer.class))).thenReturn(null); + cacheConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(mockConnection).async(); + verify(mockAsyncCommands).set(any(), any(), any()); + verify(mockCacheResult).whenComplete(any(BiConsumer.class)); + } + + @Test + void test_writeToCacheException() throws Exception { + String key = "myQueryKey"; + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenThrow(new RuntimeException("test exception")); + cacheConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(mockConnection).async(); + verify(mockAsyncCommands).set(any(), any(), any()); + verify(mockWriteConnPool).invalidateObject(mockConnection); + } + + @Test + void test_handleCompletedCacheWrite() throws Exception { + cacheConnection.handleCompletedCacheWrite(mockConnection, null); + verify(mockWriteConnPool).returnObject(mockConnection); + cacheConnection.handleCompletedCacheWrite(mockConnection, new RuntimeException("test")); + verify(mockWriteConnPool).invalidateObject(mockConnection); + } + + @Test + void test_readFromCache() throws Exception { + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenReturn(value); + byte[] result = cacheConnection.readFromCache("myQueryKey"); + assertEquals(value, result); + verify(mockReadConnPool).borrowObject(); + verify(mockConnection).sync(); + verify(mockSyncCommands).get(any()); + verify(mockReadConnPool).returnObject(mockConnection); + } + + @Test + void test_readFromCacheException() throws Exception { + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenThrow(new RuntimeException("test")); + assertNull(cacheConnection.readFromCache("myQueryKey")); + verify(mockReadConnPool).borrowObject(); + verify(mockConnection).sync(); + verify(mockSyncCommands).get(any()); + verify(mockReadConnPool).invalidateObject(mockConnection); + } + + @Test + void test_cacheConnectionPoolSize_default() throws Exception { + clearStaticPools(); + + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + + CacheConnection connection = new CacheConnection(props); + + // Create real pools (no network until borrow) + connection.triggerPoolInit(true); + connection.triggerPoolInit(false); + + GenericObjectPool> readPool = getStaticPool("readConnectionPool"); + GenericObjectPool> writePool = getStaticPool("writeConnectionPool"); + + assertNotNull(readPool, "read pool should be created"); + assertNotNull(writePool, "write pool should be created"); + + assertEquals(20, readPool.getMaxTotal()); + assertEquals(20, readPool.getMaxIdle()); + assertEquals(20, writePool.getMaxTotal()); + assertEquals(20, writePool.getMaxIdle()); + assertNotEquals(8, readPool.getMaxTotal()); // making sure it does not set the default values of Generic pool + assertNotEquals(8, writePool.getMaxIdle()); + } + + @Test + void test_cacheConnectionPoolSize_Initialization() throws Exception { + clearStaticPools(); + + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + props.setProperty("cacheConnectionPoolSize", "15"); + + CacheConnection connection = new CacheConnection(props); + + // Create real pools (no network until borrow) + connection.triggerPoolInit(true); + connection.triggerPoolInit(false); + + GenericObjectPool> readPool = getStaticPool("readConnectionPool"); + GenericObjectPool> writePool = getStaticPool("writeConnectionPool"); + + assertNotNull(readPool, "read pool should be created"); + assertNotNull(writePool, "write pool should be created"); + + assertEquals(15, readPool.getMaxTotal()); + assertEquals(15, readPool.getMaxIdle()); + assertEquals(15, writePool.getMaxTotal()); + assertEquals(15, writePool.getMaxIdle()); + } + + @Test + void test_cacheConnectionTimeout_Initialization() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheConnectionTimeout", "5000"); + + CacheConnection connection = new CacheConnection(props); + Duration timeout = (Duration) getField(connection, "cacheConnectionTimeout"); + assertEquals(Duration.ofMillis(5000), timeout); + } + + @Test + void test_cacheConnectionTimeout_default() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + CacheConnection connection = new CacheConnection(props); + Duration timeout = (Duration) getField(connection, "cacheConnectionTimeout"); + assertEquals(Duration.ofMillis(2000), timeout, "default should be 2000 ms"); + } + + @SuppressWarnings("unchecked") + private static GenericObjectPool> getStaticPool(String field) throws Exception { + Field f = CacheConnection.class.getDeclaredField(field); + f.setAccessible(true); + return (GenericObjectPool>) f.get(null); + } + + private static void clearStaticPools() throws Exception { + for (String fieldName : new String[]{"readConnectionPool", "writeConnectionPool"}) { + Field f = CacheConnection.class.getDeclaredField(fieldName); + f.setAccessible(true); + f.set(null, null); + } + } + + private Object getField(Object obj, String fieldName) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(obj); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java new file mode 100644 index 000000000..458ae96cc --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.Test; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public class CacheSupplierTest { + + @Test + void testMemoizeWithExpiration_ValidParameters() { + Supplier delegate = () -> "test-value"; + + Supplier cached = CachedSupplier.memoizeWithExpiration(delegate, 1, TimeUnit.SECONDS); + + assertNotNull(cached); + assertEquals("test-value", cached.get()); + } + + @Test + void testMemoizeWithExpiration_NullDelegate() { + assertThrows(NullPointerException.class, () -> + CachedSupplier.memoizeWithExpiration(null, 1, TimeUnit.SECONDS)); + } + + @Test + void testMemoizeWithExpiration_NullTimeUnit() { + Supplier delegate = () -> "test"; + + assertThrows(NullPointerException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, 1, null)); + } + + @Test + void testMemoizeWithExpiration_ZeroDuration() { + Supplier delegate = () -> "test"; + + assertThrows(IllegalArgumentException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, 0, TimeUnit.SECONDS)); + } + + @Test + void testMemoizeWithExpiration_NegativeDuration() { + Supplier delegate = () -> "test"; + + assertThrows(IllegalArgumentException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, -1, TimeUnit.SECONDS)); + } + + @Test + void testCaching_DelegateCalledOnce() { + Supplier mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("cached-value"); + + Supplier cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 1, TimeUnit.SECONDS); + + // Call multiple times quickly + assertEquals("cached-value", cached.get()); + assertEquals("cached-value", cached.get()); + assertEquals("cached-value", cached.get()); + + // Delegate should only be called once due to caching + verify(mockDelegate, times(1)).get(); + } + + @Test + void testExpiration_DelegateCalledAgainAfterExpiry() throws InterruptedException { + Supplier mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("value1", "value2"); + + Supplier cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 50, TimeUnit.MILLISECONDS); + + // First call + assertEquals("value1", cached.get()); + verify(mockDelegate, times(1)).get(); + + // Wait for expiration + Thread.sleep(100); + + // Second call after expiration + assertEquals("value2", cached.get()); + verify(mockDelegate, times(2)).get(); + } + + @Test + void testConcurrentAccess() throws InterruptedException { + Supplier mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("concurrent-value"); + + Supplier cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 5, TimeUnit.SECONDS); + + // Simulate concurrent access + Thread[] threads = new Thread[10]; + String[] results = new String[10]; + + for (int i = 0; i < 10; i++) { + final int index = i; + threads[i] = new Thread(() -> results[index] = cached.get()); + threads[i].start(); + } + + // Wait for all threads + for (Thread thread : threads) { + thread.join(); + } + + // All should get the same cached value + for (String result : results) { + assertEquals("concurrent-value", result); + } + + // Delegate should only be called once despite concurrent access + verify(mockDelegate, times(1)).get(); + } + + @Test + void testExpirationNanos_EdgeCase() { + Supplier timeSupplier = () -> System.nanoTime(); + + Supplier cached = CachedSupplier.memoizeWithExpiration(timeSupplier, 1, TimeUnit.NANOSECONDS); + + Long first = cached.get(); + Long second = cached.get(); + + // Due to very short expiration, second call might get different value + assertNotNull(first); + assertNotNull(second); + } + + @Test + void testPrivateConstructor() { + // Verify utility class has private constructor + assertThrows(Exception.class, () -> { + java.lang.reflect.Constructor constructor = + CachedSupplier.class.getDeclaredConstructor(); + constructor.setAccessible(true); + constructor.newInstance(); + }); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java new file mode 100644 index 000000000..00ac96e34 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -0,0 +1,865 @@ +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; + +import java.sql.*; +import java.sql.Date; +import java.time.*; +import java.util.*; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import java.net.URL; +import java.net.MalformedURLException; + +import java.math.BigDecimal; + +public class CachedResultSetTest { + private CachedResultSet testResultSet; + @Mock ResultSet mockResultSet; + @Mock ResultSetMetaData mockResultSetMetadata; + private AutoCloseable closeable; + private static final Calendar estCal = Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); + private final TimeZone defaultTimeZone = TimeZone.getDefault(); + + // Column values: label, name, typeName, type, displaySize, precision, tableName, + // scale, schemaName, isAutoIncrement, isCaseSensitive, isCurrency, isDefinitelyWritable, + // isNullable, isReadOnly, isSearchable, isSigned, isWritable + private static final Object [][] testColumnMetadata = { + {"fieldNull", "fieldNull", "String", Types.VARCHAR, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldInt", "fieldInt", "Integer", Types.INTEGER, 10, 2, "table", 1, "public", true, false, false, false, 0, false, true, true, true}, + {"fieldString", "fieldString", "String", Types.VARCHAR, 10, 2, "table", 1, "public", false, false, false, false, 0, false, true, false, true}, + {"fieldBoolean", "fieldBoolean", "Boolean", Types.BOOLEAN, 10, 2, "table", 1, "public", false, false, false, false, 0, false, true, false, true}, + {"fieldByte", "fieldByte", "Byte", Types.TINYINT, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldShort", "fieldShort", "Short", Types.SMALLINT, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldLong", "fieldLong", "Long", Types.BIGINT, 10, 2, "table", 1, "public", false, false, false, false, 1, false, true, false, false}, + {"fieldFloat", "fieldFloat", "Float", Types.REAL, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldDouble", "fieldDouble", "Double", Types.DOUBLE, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldBigDecimal", "fieldBigDecimal", "BigDecimal", Types.DECIMAL, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldDate", "fieldDate", "Date", Types.DATE, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldTime", "fieldTime", "Time", Types.TIME, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldDateTime", "fieldDateTime", "Timestamp", Types.TIMESTAMP, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldSqlXml", "fieldSqlXml", "SqlXml", Types.SQLXML, 100, 1, "table", 1, "public", false, false, false, false, 0, true, true, false, false} + }; + + private static final Object [][] testColumnValues = { + {null, null}, + {1, 123456}, + {"John Doe", "Tony Stark"}, + {true, false}, + {(byte)100, (byte)70}, // Letter d and F in ASCII + {(short)55, (short)135}, + {2^33L, -2^35L}, + {3.14159f, -233.14159f}, + {2345.23345d, -2344355.4543d}, + {new BigDecimal("15.33"), new BigDecimal("-12.45")}, + {Date.valueOf("2025-03-15"), Date.valueOf("1102-01-15")}, + {Time.valueOf("22:54:00"), Time.valueOf("01:10:00")}, + {Timestamp.valueOf("2025-03-15 22:54:00"), Timestamp.valueOf("1950-01-18 21:50:05")}, + {new CachedSQLXML("A"), new CachedSQLXML("Value AValue B")} + }; + + private void mockGetMetadataFields(int column, int testMetadataCol) throws SQLException { + when(mockResultSetMetadata.getCatalogName(column)).thenReturn(""); + when(mockResultSetMetadata.getColumnClassName(column)).thenReturn("MyClass" + testMetadataCol); + when(mockResultSetMetadata.getColumnLabel(column)).thenReturn((String) testColumnMetadata[testMetadataCol][0]); + when(mockResultSetMetadata.getColumnName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][1]); + when(mockResultSetMetadata.getColumnTypeName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][2]); + when(mockResultSetMetadata.getColumnType(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][3]); + when(mockResultSetMetadata.getColumnDisplaySize(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][4]); + when(mockResultSetMetadata.getPrecision(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][5]); + when(mockResultSetMetadata.getTableName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][6]); + when(mockResultSetMetadata.getScale(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][7]); + when(mockResultSetMetadata.getSchemaName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][8]); + when(mockResultSetMetadata.isAutoIncrement(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][9]); + when(mockResultSetMetadata.isCaseSensitive(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][10]); + when(mockResultSetMetadata.isCurrency(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][11]); + when(mockResultSetMetadata.isDefinitelyWritable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][12]); + when(mockResultSetMetadata.isNullable(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][13]); + when(mockResultSetMetadata.isReadOnly(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][14]); + when(mockResultSetMetadata.isSearchable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][15]); + when(mockResultSetMetadata.isSigned(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][16]); + when(mockResultSetMetadata.isWritable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][17]); + } + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")); + } + + @AfterEach + void cleanUp() { + TimeZone.setDefault(defaultTimeZone); + } + + void setUpDefaultTestResultSet() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(testColumnMetadata.length); + for (int i = 0; i < testColumnMetadata.length; i++) { + mockGetMetadataFields(1+i, i); + when(mockResultSet.getObject(1+i)).thenReturn(testColumnValues[i][0], testColumnValues[i][1]); + } + when(mockResultSet.next()).thenReturn(true, true, false); + testResultSet = new CachedResultSet(mockResultSet); + } + + private void verifyDefaultMetadata(ResultSet rs) throws SQLException { + ResultSetMetaData md = rs.getMetaData(); + for (int i = 0; i < md.getColumnCount(); i++) { + assertEquals("", md.getCatalogName(i+1)); + assertEquals("MyClass" + i, md.getColumnClassName(i+1)); + assertEquals(testColumnMetadata[i][0], md.getColumnLabel(i+1)); + assertEquals(testColumnMetadata[i][1], md.getColumnName(i+1)); + assertEquals(testColumnMetadata[i][2], md.getColumnTypeName(i+1)); + assertEquals(testColumnMetadata[i][3], md.getColumnType(i+1)); + assertEquals(testColumnMetadata[i][4], md.getColumnDisplaySize(i+1)); + assertEquals(testColumnMetadata[i][5], md.getPrecision(i+1)); + assertEquals(testColumnMetadata[i][6], md.getTableName(i+1)); + assertEquals(testColumnMetadata[i][7], md.getScale(i+1)); + assertEquals(testColumnMetadata[i][8], md.getSchemaName(i+1)); + assertEquals(testColumnMetadata[i][9], md.isAutoIncrement(i+1)); + assertEquals(testColumnMetadata[i][10], md.isCaseSensitive(i+1)); + assertEquals(testColumnMetadata[i][11], md.isCurrency(i+1)); + assertEquals(testColumnMetadata[i][12], md.isDefinitelyWritable(i+1)); + assertEquals(testColumnMetadata[i][13], md.isNullable(i+1)); + assertEquals(testColumnMetadata[i][14], md.isReadOnly(i+1)); + assertEquals(testColumnMetadata[i][15], md.isSearchable(i+1)); + assertEquals(testColumnMetadata[i][16], md.isSigned(i+1)); + assertEquals(testColumnMetadata[i][17], md.isWritable(i+1)); + } + } + + private void verifyDefaultRow(ResultSet rs, int row) throws SQLException { + assertFalse(rs.wasNull()); + assertNull(rs.getObject(1)); // fieldNull + assertEquals(1, rs.findColumn("fieldNull")); + assertTrue(rs.wasNull()); + assertEquals((int) testColumnValues[1][row], rs.getInt(2)); // fieldInt + assertFalse(rs.wasNull()); + assertEquals((int) testColumnValues[1][row], rs.getInt("fieldInt")); + assertEquals(2, rs.findColumn("fieldInt")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[2][row], rs.getString(3)); // fieldString + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[2][row], rs.getString("fieldString")); + assertEquals(3, rs.findColumn("fieldString")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[3][row], rs.getBoolean(4)); // fieldBoolean + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[3][row], rs.getBoolean("fieldBoolean")); + assertEquals(4, rs.findColumn("fieldBoolean")); + assertFalse(rs.wasNull()); + assertEquals((byte) testColumnValues[4][row], rs.getByte(5)); // fieldByte + assertFalse(rs.wasNull()); + assertEquals((byte) testColumnValues[4][row], rs.getByte("fieldByte")); + assertEquals(5, rs.findColumn("fieldByte")); + assertFalse(rs.wasNull()); + assertEquals((short) testColumnValues[5][row], rs.getShort(6)); // fieldShort + assertFalse(rs.wasNull()); + assertEquals((short) testColumnValues[5][row], rs.getShort("fieldShort")); + assertEquals(6, rs.findColumn("fieldShort")); + assertFalse(rs.wasNull()); + assertNull(rs.getObject("fieldNull")); + assertTrue(rs.wasNull()); + assertEquals((Long) testColumnValues[6][row], rs.getLong(7)); // fieldLong + assertFalse(rs.wasNull()); + assertEquals((Long) testColumnValues[6][row], rs.getLong("fieldLong")); + assertEquals(7, rs.findColumn("fieldLong")); + assertFalse(rs.wasNull()); + assertEquals((float) testColumnValues[7][row], rs.getFloat(8), 0); // fieldFloat + assertFalse(rs.wasNull()); + assertEquals((float) testColumnValues[7][row], rs.getFloat("fieldFloat"), 0); + assertEquals(8, rs.findColumn("fieldFloat")); + assertFalse(rs.wasNull()); + assertEquals((double) testColumnValues[8][row], rs.getDouble(9)); // fieldDouble + assertFalse(rs.wasNull()); + assertEquals((double) testColumnValues[8][row], rs.getDouble("fieldDouble")); + assertEquals(9, rs.findColumn("fieldDouble")); + assertFalse(rs.wasNull()); + assertEquals(0, rs.getBigDecimal(10).compareTo((BigDecimal) testColumnValues[9][row])); // fieldBigDecimal + assertFalse(rs.wasNull()); + assertEquals(0, rs.getBigDecimal("fieldBigDecimal").compareTo((BigDecimal) testColumnValues[9][row])); + assertEquals(10, rs.findColumn("fieldBigDecimal")); + assertFalse(rs.wasNull()); + assertNull(rs.getObject(1)); // fieldNull + assertTrue(rs.wasNull()); + assertEquals(testColumnValues[10][row], rs.getDate(11)); // fieldDate + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[10][row], rs.getDate("fieldDate")); + assertEquals(11, rs.findColumn("fieldDate")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[11][row], rs.getTime(12)); // fieldTime + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[11][row], rs.getTime("fieldTime")); + assertEquals(12, rs.findColumn("fieldTime")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[12][row], rs.getTimestamp(13)); // fieldDateTime + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[12][row], rs.getTimestamp("fieldDateTime")); + assertEquals(13, rs.findColumn("fieldDateTime")); + assertFalse(rs.wasNull()); + String sqlXmlString = ((SQLXML)testColumnValues[13][row]).getString(); + assertEquals(sqlXmlString, rs.getSQLXML(14).getString()); // fieldSqlXml + assertFalse(rs.wasNull()); + assertEquals(sqlXmlString, rs.getSQLXML("fieldSqlXml").getString()); + assertEquals(14, rs.findColumn("fieldSqlXml")); + assertFalse(rs.wasNull()); + verifyNonexistingField(rs); + } + + private void verifyNonexistingField(ResultSet rs) { + try { + rs.getObject("nonExistingField"); + throw new IllegalStateException("Expected an exception due to column doesn't exist"); + } catch (SQLException e) { + // Expected an exception if the column doesn't exist + } + try { + rs.findColumn("nonExistingField"); + throw new IllegalStateException("Expected an exception due to column doesn't exist"); + } catch (SQLException e) { + // Expected an exception if the column doesn't exist + } + } + + @Test + void test_basic_cached_result_set() throws Exception { + // Basic verification of the test result set + setUpDefaultTestResultSet(); + verifyDefaultMetadata(testResultSet); + assertEquals(0, testResultSet.getRow()); + assertTrue(testResultSet.next()); + assertEquals(1, testResultSet.getRow()); + verifyDefaultRow(testResultSet, 0); + assertTrue(testResultSet.next()); + assertEquals(2, testResultSet.getRow()); + verifyDefaultRow(testResultSet, 1); + assertFalse(testResultSet.next()); + assertEquals(0, testResultSet.getRow()); + assertNull(testResultSet.getWarnings()); + testResultSet.clearWarnings(); + assertNull(testResultSet.getWarnings()); + testResultSet.beforeFirst(); + // Test serialization and de-serialization of the result set + byte[] serialized_data = testResultSet.serializeIntoByteArray(); + ResultSet rs = CachedResultSet.deserializeFromByteArray(serialized_data); + verifyDefaultMetadata(rs); + assertTrue(rs.next()); + verifyDefaultRow(rs, 0); + assertTrue(rs.next()); + verifyDefaultRow(rs, 1); + assertFalse(rs.next()); + assertNull(rs.getWarnings()); + rs.relative(-10); // We should be before the start of the rows + assertTrue(rs.isBeforeFirst()); + assertEquals(0, rs.getRow()); + rs.relative(10); // We should be after the end of the rows + assertTrue(rs.isAfterLast()); + assertEquals(0, rs.getRow()); + rs.absolute(-10); // We should be before the start of the rows + assertTrue(rs.isBeforeFirst()); + assertFalse(rs.absolute(100)); // Jump to after the end of the rows + assertTrue(rs.isAfterLast()); + assertEquals(0, rs.getRow()); + assertFalse(rs.absolute(0)); // Go to the beginning of rows + assertTrue(rs.isBeforeFirst()); + assertTrue(rs.next()); // We are at first row + verifyDefaultRow(rs, 0); + rs.relative(1); // Advances to next row + verifyDefaultRow(rs, 1); + assertTrue(rs.previous()); // Go back to first row + verifyDefaultRow(rs, 0); + assertFalse(rs.previous()); + assertTrue(rs.absolute(2)); // Jump to second row + verifyDefaultRow(rs, 1); + assertTrue(rs.first()); // go to first row + verifyDefaultRow(rs, 0); + assertEquals(1, rs.getRow()); + assertTrue(rs.last()); // go to last row + verifyDefaultRow(rs, 1); + assertEquals(2, rs.getRow()); + } + + @Test + void test_get_special_bigDecimal() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 9); + when(mockResultSet.getObject(1)).thenReturn( + 12450.567, + -132.45, + "142.346", + "invalid", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + CachedResultSet rs = new CachedResultSet(mockResultSet); + + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("12450.567"))); + + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("-132.45"))); + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("142.346"))); + assertTrue(rs.next()); + try { + rs.getBigDecimal(1); + fail("Invalid value should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Value is null + assertTrue(rs.next()); + assertNull(rs.getBigDecimal(1)); + } + + @Test + void test_get_special_timestamp() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 12); + when(mockResultSet.getObject(1)).thenReturn( + 1504844311000L, + LocalDateTime.of(1981, 3, 10, 1, 10, 20), + OffsetDateTime.parse("2025-08-10T10:00:00+03:00"), + ZonedDateTime.parse("2024-07-30T10:00:00+02:00[Europe/Berlin]"), + "2015-03-15 12:50:04", + "invalidDateTime", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Timestamp from a number + assertTrue(cachedRs.next()); + assertEquals(new Timestamp(1504844311000L), cachedRs.getTimestamp(1)); + // Timestamp from LocalDateTime + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("1981-03-10 01:10:20"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("1981-03-09 22:10:20"), cachedRs.getTimestamp(1, estCal)); + // Timestamp from OffsetDateTime (containing time zone info) + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2025-08-10 00:00:00"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2025-08-10 00:00:00"), cachedRs.getTimestamp(1, estCal)); + // Timestmap from ZonedDateTime (containing time zone info) + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2024-07-30 01:00:00"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2024-07-30 01:00:00"), cachedRs.getTimestamp(1, estCal)); + // Timestamp from String + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2015-03-15 12:50:04"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2015-03-15 12:50:04"), cachedRs.getTimestamp(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getTimestamp(1); + fail("Invalid timestamp should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Timestamp is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getTimestamp(1)); + } + + @Test + void test_get_special_time() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 11); + when(mockResultSet.getObject(1)).thenReturn( + 4362000L, + LocalTime.of(10, 20, 30), + OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC), + new Timestamp(1755621000000L), // Date and time (GMT): Tuesday, August 19, 2025 4:30:00 PM + new Timestamp(1735713000000L), // Date and time (GMT): Wednesday, January 1, 2025 6:30:00 AM + new Timestamp(0L), // 1970-01-01 00:00:00 UTC (epoch) + new Timestamp(Timestamp.valueOf(LocalDateTime.now().plusYears(1).withHour(9).withMinute(30).withSecond(0).withNano(0)).getTime()), // Future Date: next year same date at 9:30 AM + "15:34:20", + "InvalidTime", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, + true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Time from a number + assertTrue(cachedRs.next()); + assertEquals(new Time(4362000L), cachedRs.getTime(1)); + // Time from LocalTime + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("10:20:30"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("07:20:30"), cachedRs.getTime(1, estCal)); + // Time from OffsetTime + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("05:15:30"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("05:15:30"), cachedRs.getTime(1, estCal)); + // Time from Timestamp + assertTrue(cachedRs.next()); + Timestamp timestampOne = new Timestamp(1755621000000L); + // Compare underlying millis + assertEquals(timestampOne.getTime(), cachedRs.getTime(1).getTime()); + // Compare logical wall-clock time + assertEquals(LocalTime.of(9, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(6, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Time from Timestamp Edge Case + assertTrue(cachedRs.next()); + Timestamp timestampTwo = new Timestamp(1735713000000L); + assertEquals(timestampTwo.getTime(), cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(22, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(19, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Epoch time of 0 + assertTrue(cachedRs.next()); + assertEquals(new Time(0), cachedRs.getTime(1)); + assertEquals(0L, cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(16, 0, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(13, 0, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Future date + assertTrue(cachedRs.next()); + Timestamp futureTimestamp = new Timestamp(Timestamp.valueOf(LocalDateTime.now().plusYears(1).withHour(9).withMinute(30).withSecond(0).withNano(0)).getTime()); + assertEquals(futureTimestamp.getTime(), cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(9, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(6, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Timestamp from String + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("15:34:20"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("15:34:20"), cachedRs.getTime(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getTime(1); + fail("Invalid time should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Time is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getTime(1)); + } + + @Test + void test_get_special_date() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 10); + when(mockResultSet.getObject(1)).thenReturn( + 1515944311000L, + -1000000000L, + LocalDate.of(2010, 10, 30), + new Timestamp(1755621000000L), // Date and time (GMT): Tuesday, August 19, 2025 4:30:00 PM + new Timestamp(1735713000000L), // Date and time (GMT): Wednesday, January 1, 2025 6:30:00 AM + new Timestamp(1755673200000L), // Date and time (GMT): Wednesday, August 20, 2025 7:00:00 AM --> PDT Aug 20 12AM + new Timestamp(1735718400000L), // Date and time (GMT): Wednesday, January 1, 2025 8:00:00 AM --> PST Jan 1 12AM + new Timestamp(0L), // 1970-01-01 00:00:00 UTC (epoch) + "2025-03-15", + "InvalidDate", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, + true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Date from a number + assertTrue(cachedRs.next()); + Date date = cachedRs.getDate(1); + assertEquals(new Date(1515944311000L), date); + assertTrue(cachedRs.next()); + assertEquals(new Date(-1000000000L), cachedRs.getDate(1)); + // Date from LocalDate + + assertTrue(cachedRs.next()); + assertEquals(Date.valueOf("2010-10-30"), cachedRs.getDate(1)); + assertEquals(Date.valueOf("2010-10-29"), cachedRs.getDate(1, estCal)); + // Date from Timestamp + assertTrue(cachedRs.next()); + Timestamp tsForDate1 = new Timestamp(1755621000000L); + assertEquals(new Date(tsForDate1.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025, 8, 19), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2025, 8, 19), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate2 = new Timestamp(1735713000000L); + assertEquals(new Date(tsForDate2.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2024, 12, 31), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2024, 12, 31), cachedRs.getDate(1, estCal).toLocalDate()); + // Date from Timestamp Edge Case + assertTrue(cachedRs.next()); + Timestamp tsForDate3 = new Timestamp(1755673200000L); + assertEquals(new Date(tsForDate3.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025,8,20), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2025,8,19), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate4 = new Timestamp(1735718400000L); + assertEquals(new Date(tsForDate4.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025,1,1), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2024,12,31), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate5 = new Timestamp(0L); + assertEquals(new Date(tsForDate5.getTime()), cachedRs.getDate(1)); + assertEquals(new Date(0L), cachedRs.getDate(1)); + assertEquals(LocalDate.of(1969,12,31), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(1969,12,31), cachedRs.getDate(1, estCal).toLocalDate()); + // Date from String + assertTrue(cachedRs.next()); + assertEquals(Date.valueOf("2025-03-15"), cachedRs.getDate(1)); + assertEquals(Date.valueOf("2025-03-15"), cachedRs.getDate(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getDate(1); + fail("Invalid date should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Date is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getDate(1)); + } + + @Test + void test_get_nstring() throws SQLException { + // Setup single column with String metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + when(mockResultSet.getObject(1)).thenReturn("test string", 123, null); + when(mockResultSet.next()).thenReturn(true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test string value - both index and label versions + assertTrue(cachedRs.next()); + assertEquals("test string", cachedRs.getNString(1)); + assertFalse(cachedRs.wasNull()); + assertEquals("test string", cachedRs.getNString("fieldString")); + assertFalse(cachedRs.wasNull()); + + // Test number conversion + assertTrue(cachedRs.next()); + assertEquals("123", cachedRs.getNString(1)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getNString(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_bytes() throws SQLException { + // Setup single column with String metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 4); + // Test data + byte[] testBytes = {1, 2, 3, 4, 5}; + when(mockResultSet.getObject(1)).thenReturn(testBytes, "not bytes", 123, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test bytes values - both index and label versions + assertTrue(cachedRs.next()); + assertArrayEquals(testBytes, cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + assertArrayEquals(testBytes, cachedRs.getBytes("fieldByte")); + assertFalse(cachedRs.wasNull()); + + // Test non-byte array input (should convert to bytes) + assertTrue(cachedRs.next()); + assertArrayEquals("not bytes".getBytes(), cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + + // Test number input (should convert to bytes) + assertTrue(cachedRs.next()); + assertArrayEquals("123".getBytes(), cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getBytes(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_boolean() throws SQLException { + // Setup single column with String metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 3); + // Test data: boolean, numbers, strings, null + when(mockResultSet.getObject(1)).thenReturn( + true, false, 0, 1, -5, "true", "false", "invalid", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual boolean values - both index and label versions + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); + assertFalse(cachedRs.wasNull()); + assertTrue(cachedRs.getBoolean("fieldBoolean")); + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); + assertFalse(cachedRs.wasNull()); + + // Test number conversions: 0 = true, non-zero = false + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // 0 → false + + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // 1 → true + + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // -5 → true + + // Test string conversions + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // "true" → true + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // "false" → false + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // "invalid" → false (parseBoolean) + + // Test null handling + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // null → false + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_URL() throws SQLException { + // Setup single column with string metadata (URLs stored as strings) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: URL object, valid URL string, invalid URL string, null + // URL object setup + URL testUrl = null; + try { + testUrl = new URL("https://example.com"); + } catch (MalformedURLException e) { + fail("Test setup failed"); + } + + when(mockResultSet.getObject(1)).thenReturn( + testUrl, "https://valid.com", "invalid-url", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual URL object - both index and label versions + assertTrue(cachedRs.next()); + assertEquals(testUrl, cachedRs.getURL(1)); + assertFalse(cachedRs.wasNull()); + assertEquals(testUrl, cachedRs.getURL("fieldString")); + + // Test valid URL string conversion + assertTrue(cachedRs.next()); + URL validURL = null; + try { + validURL = new URL("https://valid.com"); + } catch (MalformedURLException e) { + fail("Failed setting up new valid URL"); + } + assertEquals(validURL, cachedRs.getURL(1)); + assertFalse(cachedRs.wasNull()); + + // Test invalid URL string (should throw SQLException) + assertTrue(cachedRs.next()); + assertThrows(SQLException.class, () -> cachedRs.getURL(1)); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getURL(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_sql_xml() throws SQLException { + String longXml = + "\n" + + " TechCorp\n" + + "\n" + + " Intel i7\n" + + " 16GB\n" + + " 512GB SSD\n" + + "\n" + + " 1200.00\n" + + ""; + SQLXML testXml = new CachedSQLXML("PostgreSQL GuideJohn Doe"); + SQLXML testXml2 = new CachedSQLXML(longXml); + SQLXML invalidXml = new CachedSQLXML("A"); + // Setup single column with string metadata (URLs stored as strings) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 13); + when(mockResultSet.getObject(1)).thenReturn(testXml, testXml2, invalidXml, "invalid-xml", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual SQLXML objects - both index and label versions + assertTrue(cachedRs.next()); + assertEquals(testXml.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(testXml.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals(testXml2.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(testXml2.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals(invalidXml.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(invalidXml.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals("invalid-xml", cachedRs.getSQLXML(1).getString()); + assertEquals("invalid-xml", cachedRs.getSQLXML("fieldSqlXml").getString()); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertNull(cachedRs.getSQLXML(1)); + assertTrue(cachedRs.wasNull()); + assertNull(cachedRs.getSQLXML("fieldSqlXml")); + assertTrue(cachedRs.wasNull()); + + assertFalse(cachedRs.next()); + } + + + @Test + void test_get_object_with_index_and_type() throws SQLException { + // Setup single column with string metadata (mixed data types) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: string, integer, boolean, null + when(mockResultSet.getObject(1)).thenReturn("test", 123, true, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid type conversions + assertTrue(cachedRs.next()); + assertEquals("test", cachedRs.getObject(1, String.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Integer.valueOf(123), cachedRs.getObject(1, Integer.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Boolean.TRUE, cachedRs.getObject(1, Boolean.class)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getObject(1, String.class)); + assertTrue(cachedRs.wasNull()); + + // Test invalid type conversion (should throw ClassCastException) + cachedRs.beforeFirst(); + // Wraps around + assertTrue(cachedRs.next()); + assertThrows(ClassCastException.class, () -> cachedRs.getObject(1, Integer.class)); + } + + @Test + void test_get_object_with_label_and_type() throws SQLException { + // Setup single column with string metadata (mixed data types) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: string, integer, boolean, HashSet (unsupported type), null + HashSet testSet = new HashSet<>(); + testSet.add("item1"); + testSet.add("item2"); + + when(mockResultSet.getObject(1)).thenReturn("test", 123, true, testSet, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid type conversions + assertTrue(cachedRs.next()); + assertEquals("test", cachedRs.getObject("fieldString", String.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Integer.valueOf(123), cachedRs.getObject("fieldString", Integer.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Boolean.TRUE, cachedRs.getObject("fieldString", Boolean.class)); + assertFalse(cachedRs.wasNull()); + + // Test unsupported data type (HashSet) - should work with getObject() + assertTrue(cachedRs.next()); + HashSet retrievedSet = cachedRs.getObject("fieldString", HashSet.class); + assertEquals(testSet, retrievedSet); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getObject("fieldString", String.class)); + assertTrue(cachedRs.wasNull()); + + // Test invalid type conversion (should throw ClassCastException) + cachedRs.beforeFirst(); + // Wraps around + assertTrue(cachedRs.next()); + assertThrows(ClassCastException.class, () -> cachedRs.getObject(1, Integer.class)); + } + + @Test + void test_unwrap() throws SQLException { + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid unwrap to ResultSet interface + ResultSet unwrappedResultSet = cachedRs.unwrap(ResultSet.class); + assertSame(cachedRs, unwrappedResultSet); + + // Test valid unwrap to CachedResultSet class + CachedResultSet unwrappedCachedResultSet = cachedRs.unwrap(CachedResultSet.class); + assertSame(cachedRs, unwrappedCachedResultSet); + + // Test invalid unwrap attempts should throw SQLException + assertThrows(SQLException.class, () -> cachedRs.unwrap(String.class)); + assertThrows(SQLException.class, () -> cachedRs.unwrap(Integer.class)); + } + + @Test + void test_is_wrapper_for() throws SQLException { + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid wrapper checks + assertTrue(cachedRs.isWrapperFor(ResultSet.class)); + assertTrue(cachedRs.isWrapperFor(CachedResultSet.class)); + + // Test invalid wrapper checks + assertFalse(cachedRs.isWrapperFor(String.class)); + assertFalse(cachedRs.isWrapperFor(Integer.class)); + + // Test null class parameter + assertFalse(cachedRs.isWrapperFor(null)); + } +} + diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java new file mode 100644 index 000000000..7340ac1b0 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java @@ -0,0 +1,174 @@ +package software.amazon.jdbc.plugin.cache; + +import org.junit.jupiter.api.Test; +import org.w3c.dom.*; +import org.xml.sax.Attributes; +import org.xml.sax.InputSource; +import org.xml.sax.XMLReader; +import org.xml.sax.helpers.DefaultHandler; +import java.io.InputStream; +import java.io.Reader; +import java.sql.SQLException; +import java.sql.SQLXML; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Source; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stax.StAXSource; +import javax.xml.transform.stream.StreamSource; + +import static org.junit.jupiter.api.Assertions.*; + +public class CachedSQLXMLTest { + + @Test + void test_basic_XML() throws Exception { + String xml = "Value AValue B"; + SQLXML sqlxml = new CachedSQLXML(xml); + assertEquals(xml, sqlxml.getString()); + + // Test binary stream + byte[] array = new byte[100]; + InputStream stream = sqlxml.getBinaryStream(); + assertEquals(xml.length(), stream.available()); + assertTrue(stream.read(array) > 0); + assertEquals(xml, new String(array, 0, xml.length())); + stream.close(); + + // Test character stream + char[] chars = new char[100]; + Reader reader = sqlxml.getCharacterStream(); + assertTrue(reader.read(chars) > 0); + assertEquals(xml, new String(chars, 0, xml.length())); + reader.close(); + + // Test free() + sqlxml.free(); + assertThrows(SQLException.class, sqlxml::getString); + assertThrows(SQLException.class, sqlxml::getCharacterStream); + assertThrows(SQLException.class, sqlxml::getBinaryStream); + assertThrows(SQLException.class, () -> sqlxml.getSource(DOMSource.class)); + } + + private void validateDOMElement(Document document, String elementName, String elementValue) { + NodeList elements = document.getElementsByTagName(elementName); + assertEquals(1, elements.getLength()); + Element element = (Element) elements.item(0); + assertEquals(elementName, element.getNodeName()); + assertEquals(elementValue, element.getTextContent()); + } + + private void validateSimpleDocument(Document document) { + Element rootElement = document.getDocumentElement(); + assertEquals("product", rootElement.getNodeName()); + NodeList elements = document.getElementsByTagName("product"); + assertEquals(1, elements.getLength()); // product has 3 elements + elements = document.getElementsByTagName("specs"); + assertEquals(1, elements.getLength()); // specs has 3 elements + validateDOMElement(document, "manufacturer", "TechCorp"); + validateDOMElement(document, "cpu", "Intel i7"); + validateDOMElement(document, "ram", "16GB"); + validateDOMElement(document, "storage", "512GB SSD"); + validateDOMElement(document, "price", "1200.00"); + } + + static private void validateDocElements(String name, String value) { + if (name.equalsIgnoreCase("manufacturer")) { + assertEquals("TechCorp", value); + } else if (name.equalsIgnoreCase("cpu")) { + assertEquals("Intel i7", value); + } else if (name.equalsIgnoreCase("ram")) { + assertEquals("16GB", value); + } else if (name.equalsIgnoreCase("storage")) { + assertEquals("512GB SSD", value); + } else if (name.equalsIgnoreCase("price")) { + assertEquals("1200.00", value); + } + } + + static private class XmlReaderContentHandler extends DefaultHandler { + private StringBuilder currentValue; + + @Override + public void startElement(String uri, String localName, String qName, Attributes attributes) { + currentValue = new StringBuilder(); // Reset for each new element + } + + @Override + public void endElement(String uri, String localName, String qName) { + // Verify the element's value + String value = currentValue.toString().trim(); + validateDocElements(qName, value); + } + + @Override + public void characters(char[] ch, int start, int length) { + currentValue.append(ch, start, length); + } + } + + @Test + void test_getSource_XML() throws Exception { + // Test parsing a more complex XML via getSource() + String xml = " \n" + + "\n" + + " TechCorp\n\n" + + "\n" + + " Intel i7\n" + + " 16GB\n" + + " 512GB SSD\n" + + "\n" + + " 1200.00\n" + + "\n"; + SQLXML sqlxml = new CachedSQLXML(xml); + assertEquals(xml, sqlxml.getString()); + + // DOM source + DOMSource domSource = sqlxml.getSource(null); + Node node = domSource.getNode(); + assertEquals(Node.DOCUMENT_NODE, node.getNodeType()); + validateSimpleDocument((Document) node); + domSource = sqlxml.getSource(DOMSource.class); + node = domSource.getNode(); + assertEquals(Node.DOCUMENT_NODE, node.getNodeType()); + validateSimpleDocument((Document) node); + + // SAX source + SAXSource src = sqlxml.getSource(SAXSource.class); + XMLReader xmlReader = src.getXMLReader(); + xmlReader.setContentHandler(new XmlReaderContentHandler()); + xmlReader.parse(src.getInputSource()); + + // Streams source + StreamSource xmlSource = sqlxml.getSource(StreamSource.class); + DocumentBuilder db = DocumentBuilderFactory.newInstance().newDocumentBuilder(); + Document doc = db.parse(new InputSource(xmlSource.getReader())); + doc.getDocumentElement().normalize(); + validateSimpleDocument(doc); + + // StAX Source + StAXSource staxSource = sqlxml.getSource(StAXSource.class); + XMLStreamReader sReader = staxSource.getXMLStreamReader(); + String elementName = ""; + StringBuilder elementValue = new StringBuilder(); + while (sReader.hasNext()) { + int event = sReader.next(); + if (event == XMLStreamReader.START_ELEMENT) { + elementName = sReader.getLocalName(); + } else if (event == XMLStreamReader.CHARACTERS) { + elementValue.append(sReader.getText()); + } else if (event == XMLStreamReader.END_ELEMENT) { + validateDocElements(elementName, elementValue.toString().trim()); + elementName = ""; + elementValue = new StringBuilder(); + } + } + sReader.close(); // Close the reader when done + + // Invalid source class + assertThrows(SQLException.class, () -> sqlxml.getSource(Source.class)); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java similarity index 89% rename from wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java rename to wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java index 46e27337f..cd307d2bf 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.anyString; @@ -36,7 +36,7 @@ import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; -class DataCacheConnectionPluginTest { +class DataLocalCacheConnectionPluginTest { private static final Properties props = new Properties(); @@ -55,8 +55,8 @@ class DataCacheConnectionPluginTest { @BeforeEach void setUp() throws SQLException { closeable = MockitoAnnotations.openMocks(this); - props.setProperty(DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, "foo"); - DataCacheConnectionPlugin.clearCache(); + props.setProperty(DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, "foo"); + DataLocalCacheConnectionPlugin.clearCache(); when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); @@ -82,7 +82,7 @@ void cleanUp() throws Exception { void test_execute_withEmptyCache() throws SQLException { final String methodName = "Statement.executeQuery"; - final DataCacheConnectionPlugin plugin = new DataCacheConnectionPlugin(mockPluginService, props); + final DataLocalCacheConnectionPlugin plugin = new DataLocalCacheConnectionPlugin(mockPluginService, props); final ResultSet rs = plugin.execute( ResultSet.class, @@ -99,7 +99,7 @@ void test_execute_withEmptyCache() throws SQLException { void test_execute_withCache() throws Exception { final String methodName = "Statement.executeQuery"; - final DataCacheConnectionPlugin plugin = new DataCacheConnectionPlugin(mockPluginService, props); + final DataLocalCacheConnectionPlugin plugin = new DataLocalCacheConnectionPlugin(mockPluginService, props); when(mockCallable.call()).thenReturn(mockResult1, mockResult2); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java new file mode 100644 index 000000000..54b466e50 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -0,0 +1,535 @@ +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.*; +import java.util.Optional; +import java.util.Properties; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.states.SessionStateService; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + +public class DataRemoteCachePluginTest { + private Properties props; + private final String methodName = "Statement.executeQuery"; + private AutoCloseable closeable; + + private DataRemoteCachePlugin plugin; + @Mock PluginService mockPluginService; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TelemetryCounter mockCacheHitCounter; + @Mock TelemetryCounter mockCacheMissCounter; + @Mock TelemetryCounter mockTotalQueryCounter; + @Mock TelemetryCounter mockMalformedHintCounter; + @Mock TelemetryCounter mockCacheBypassCounter; + @Mock TelemetryContext mockTelemetryContext; + @Mock ResultSet mockResult1; + @Mock Statement mockStatement; + @Mock ResultSetMetaData mockMetaData; + @Mock Connection mockConnection; + @Mock SessionStateService mockSessionStateService; + @Mock DatabaseMetaData mockDbMetadata; + @Mock CacheConnection mockCacheConn; + @Mock JdbcCallable mockCallable; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter("JdbcCachedQueryCount")).thenReturn(mockCacheHitCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheMissCount")).thenReturn(mockCacheMissCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheTotalQueryCount")).thenReturn(mockTotalQueryCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheMalformedQueryHint")).thenReturn(mockMalformedHintCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheBypassCount")).thenReturn(mockCacheBypassCounter); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockResult1.getMetaData()).thenReturn(mockMetaData); + when(mockMetaData.getColumnCount()).thenReturn(1); + when(mockMetaData.getColumnLabel(1)).thenReturn("fooName"); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void test_getTTLFromQueryHint() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Null and empty query hint content are not cacheable + assertNull(plugin.getTtlForQuery(null)); + assertNull(plugin.getTtlForQuery("")); + assertNull(plugin.getTtlForQuery(" ")); + // Valid CACHE_PARAM cases - these are the hint contents after /*+ and before */ + assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s)")); + assertEquals(100, plugin.getTtlForQuery("CACHE_PARAM(ttl=100s)")); + assertEquals(35, plugin.getTtlForQuery("CACHE_PARAM(ttl=35s)")); + + // Case insensitive + assertEquals(200, plugin.getTtlForQuery("cache_param(ttl=200s)")); + assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150s)")); + assertEquals(200, plugin.getTtlForQuery("cache_param(tTl=200s)")); + assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150S)")); + assertEquals(200, plugin.getTtlForQuery("cache_param(TTL=200S)")); + + // CACHE_PARAM anywhere in hint content (mixed with other hint directives) + assertEquals(250, plugin.getTtlForQuery("INDEX(table1 idx1) CACHE_PARAM(ttl=250s)")); + assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM(ttl=200s) USE_NL(t1 t2)")); + assertEquals(180, plugin.getTtlForQuery("FIRST_ROWS(10) CACHE_PARAM(ttl=180s) PARALLEL(4)")); + assertEquals(200, plugin.getTtlForQuery("foo=bar,CACHE_PARAM(ttl=200s),baz=qux")); + + // Whitespace handling + assertEquals(400, plugin.getTtlForQuery("CACHE_PARAM( ttl=400s )")); + assertEquals(500, plugin.getTtlForQuery("CACHE_PARAM(ttl = 500s)")); + assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM( ttl = 200s , key = test )")); + + // Invalid cases - no CACHE_PARAM in hint content + assertNull(plugin.getTtlForQuery("INDEX(table1 idx1)")); + assertNull(plugin.getTtlForQuery("FIRST_ROWS(100)")); + assertNull(plugin.getTtlForQuery("cachettl=300s")); // old format + assertNull(plugin.getTtlForQuery("NO_CACHE")); + + // Missing parentheses + assertNull(plugin.getTtlForQuery("CACHE_PARAM ttl=300s")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300s")); + + // Multiple parameters (future-proofing) + assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s, key=test)")); + + // Large TTL values should work + assertEquals(999999, plugin.getTtlForQuery("CACHE_PARAM(ttl=999999s)")); + assertEquals(86400, plugin.getTtlForQuery("CACHE_PARAM(ttl=86400s)")); // 24 hours + } + + @Test + void test_getTTLFromQueryHint_MalformedHints() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Test malformed cases + assertNull(plugin.getTtlForQuery("CACHE_PARAM()")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=abc)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300)")); // missing 's' + + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(invalid_format)")); + + // Invalid TTL values (negative and zero) does not count toward malformed hints + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=0s)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-10s)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-1s)")); + + // Verify counter was incremented 8 times (5 original + 3 new) + verify(mockMalformedHintCounter, times(5)).inc(); + } + + @Test + void test_execute_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"select * from mytable where ID = 2"}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockPluginService).isInTransaction(); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); + } + + @Test + void test_execute_noCachingLongQuery() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); + } + + @Test + void test_execute_cachingMissAndHit() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()).thenReturn(Optional.of("mysql")); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("mysql"); + when(mockConnection.getSchema()).thenReturn(null); + when(mockDbMetadata.getUserName()).thenReturn("user1@1.1.1.1"); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+CACHE_PARAM(ttl=50s)*/ select * from A"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(serializedTestResultSet); + + ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); + + assertTrue(rs2.next()); + assertEquals("bar1", rs2.getString("fooName")); + assertFalse(rs2.next()); + verify(mockPluginService, times(3)).getCurrentConnection(); + verify(mockPluginService, times(2)).isInTransaction(); + verify(mockCacheConn, times(2)).readFromCache("mysql_null_user1_select * from A"); + verify(mockPluginService, times(3)).getSessionStateService(); + verify(mockSessionStateService, times(3)).getCatalog(); + verify(mockSessionStateService, times(3)).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setCatalog("mysql"); + verify(mockDbMetadata).getUserName(); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("mysql_null_user1_select * from A"), any(), eq(50)); + verify(mockTotalQueryCounter, times(2)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(1)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + // First call: Cache miss + Database call + verify(mockTelemetryFactory, times(2)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Cache context calls: 1 miss (setSuccess(false)) + 1 hit (setSuccess(true)) + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(1)).setSuccess(true); // Cache hit + // Context closure: 2 cache contexts + 1 database context = 3 total + verify(mockTelemetryContext, times(3)).closeContext(); + } + + @Test + void test_transaction_cacheQuery() throws Exception { + props.setProperty("user", "dbuser"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("postgres"); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s) */ select * from T"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getSchema(); + verify(mockSessionStateService).getCatalog(); + verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockSessionStateService).setSchema("public"); + verify(mockSessionStateService).setCatalog("postgres"); + verify(mockDbMetadata, never()).getUserName(); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("postgres_public_dbuser_select * from T"), any(), eq(300)); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_transaction_cacheQuery_multiple_query_params() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockDbMetadata.getUserName()).thenReturn("dbuser"); + when(mockConnection.getCatalog()).thenReturn(null); + when(mockConnection.getSchema()).thenReturn("mysql"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s, otherParam=abc) */ select * from T"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getCatalog(); + verify(mockSessionStateService).getSchema(); + verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockSessionStateService).setSchema("mysql"); + verify(mockDbMetadata).getUserName(); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("null_mysql_dbuser_select * from T"), any(), eq(300)); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_transaction_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockCallable.call()).thenReturn(mockResult1); + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + "Statement.execute", mockCallable, new String[]{"delete from mytable"}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_JdbcCacheBypassCount_malformed_hint() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Setup - not in transaction with malformed cache hint + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + // Query with malformed cache hint - should increment both malformed and bypass counters + String queryWithMalformedHint = "/*+ CACHE_PARAM(ttl=invalid) */ SELECT * FROM users WHERE id = 123"; + plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{queryWithMalformedHint}); + // Verify malformed counter incremented first + verify(mockMalformedHintCounter, times(1)).inc(); + // Verify bypass counter incremented (because configuredQueryTtl becomes null) + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify cache flow counters were NOT called + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_JdbcCacheBypassCount_double_bypass_prevention() throws Exception { + props.setProperty("user", "testuser"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Setup - query that meets MULTIPLE bypass conditions + when(mockPluginService.isInTransaction()).thenReturn(true); // Bypass condition #1 + when(mockCallable.call()).thenReturn(mockResult1); + + // Mock result set for caching + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("testdata"); + + // Query that is BOTH too large AND in transaction - double bypass conditions + String largeQueryInTransaction = "/*+ CACHE_PARAM(ttl=300s) */ SELECT * FROM table WHERE data = '" + + RandomStringUtils.randomAlphanumeric(16384) + "'"; // >16KB AND in transaction + + // Execute + plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{largeQueryInTransaction}); + + // Verify bypass counter incremented EXACTLY ONCE (not twice) + verify(mockCacheBypassCounter, times(1)).inc(); + + // Verify cache flow counters were NOT called + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + + // Verify malformed counter not called (hint is valid, just large query) + verify(mockMalformedHintCounter, never()).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_execute_multipleCacheHits() throws Exception { + props.setProperty("user", "user"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()).thenReturn(Optional.of("public")); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockConnection.getCatalog()).thenReturn(null); + when(mockCacheConn.readFromCache("null_public_user_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+CACHE_PARAM(ttl=50s)*/ select * from A"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("null_public_user_select * from A")).thenReturn(serializedTestResultSet); + + for (int i = 0; i < 10; i ++) { + ResultSet cur_rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); + + assertTrue(cur_rs.next()); + assertEquals("bar1", cur_rs.getString("fooName")); + assertFalse(cur_rs.next()); + } + + verify(mockPluginService, times(12)).getCurrentConnection(); + verify(mockPluginService, times(11)).isInTransaction(); + verify(mockCacheConn, times(11)).readFromCache("null_public_user_select * from A"); + verify(mockPluginService, times(12)).getSessionStateService(); + verify(mockSessionStateService, times(12)).getCatalog(); + verify(mockSessionStateService, times(12)).getSchema(); + verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockSessionStateService).setSchema("public"); + verify(mockDbMetadata, never()).getUserName(); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("null_public_user_select * from A"), any(), eq(50)); + verify(mockTotalQueryCounter, times(11)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(10)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + verify(mockTelemetryFactory, times(11)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(10)).setSuccess(true); // Cache hit + // Context closure: 2 cache contexts + 1 database context = 3 total + verify(mockTelemetryContext, times(12)).closeContext(); + } + + void compareResults(final ResultSet expected, final ResultSet actual) throws SQLException { + int i = 1; + while (expected.next() && actual.next()) { + assertEquals(expected.getObject(i), actual.getObject(i)); + i++; + } + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java new file mode 100644 index 000000000..4eeb24e19 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java @@ -0,0 +1,193 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.iam; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.signer.Aws4Signer; +import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.regions.Region; + +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.CompletableFuture; + +public class ElastiCacheIamTokenUtilityTest { + @Mock private AwsCredentialsProvider mockCredentialsProvider; + @Mock private AwsCredentials mockCredentials; + @Mock private Aws4Signer mockSigner; + @Mock private SdkHttpFullRequest mockSignedRequest; + + private AutoCloseable closeable; + private ElastiCacheIamTokenUtility tokenUtility; + private final Instant fixedInstant = Instant.parse("2025-01-01T12:00:00Z"); + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @Test + void testConstructor_WithCacheName() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache"); + assertNotNull(tokenUtility); + } + + @Test + void testConstructor_WithCacheNameAndFixedInstant() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + assertNotNull(tokenUtility); + } + + @Test + void testConstructor_NullCacheName() { + assertThrows(NullPointerException.class, () -> + new ElastiCacheIamTokenUtility(null)); + } + + @Test + void testConstructor_NullCacheNameWithInstant() { + assertThrows(NullPointerException.class, () -> + new ElastiCacheIamTokenUtility(null, fixedInstant, mockSigner)); + } + + @Test + void testGenerateAuthenticationToken_RegularCache() { + // Setup mock credentials provider to return mockCredentials + when(mockCredentialsProvider.resolveIdentity()) + .thenReturn((CompletableFuture) CompletableFuture.completedFuture(mockCredentials)); + + // Add custom presign behavior to capture and verify arguments + when(mockSigner.presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class))) + .thenAnswer(invocation -> { + SdkHttpFullRequest request = invocation.getArgument(0); + Aws4PresignerParams presignParams = invocation.getArgument(1); + + // Verify SdkHttpFullRequest + assertEquals("test-cache", request.host()); + assertEquals("/", request.encodedPath()); + assertEquals("connect", request.rawQueryParameters().get("Action").get(0)); + assertEquals("testuser", request.rawQueryParameters().get("User").get(0)); + assertFalse(request.rawQueryParameters().containsKey("ResourceType")); + + // Verify Aws4PresignerParams + assertEquals("elasticache", presignParams.signingName()); + assertEquals(Region.US_EAST_1, presignParams.signingRegion()); + assertEquals(mockCredentials, presignParams.awsCredentials()); + + Instant expectedExpiration = fixedInstant.plus(Duration.ofSeconds(15 * 60 - 30)); + assertEquals(expectedExpiration, presignParams.expirationTime().get()); + assertEquals(fixedInstant, presignParams.signingClockOverride().get().instant()); + + return mockSignedRequest; + }); + + when(mockSignedRequest.getUri()).thenReturn(java.net.URI.create("http://test-cache/result")); + + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + String token = tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, "test-cache.cache.amazonaws.com", 6379, "testuser"); + + assertEquals("test-cache/result", token); + verify(mockSigner).presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class)); + } + + @Test + void testGenerateAuthenticationToken_ServerlessCache() { + // Setup mock credentials provider to return mockCredentials + when(mockCredentialsProvider.resolveIdentity()) + .thenReturn((CompletableFuture) CompletableFuture.completedFuture(mockCredentials)); + + // Add custom presign behavior to capture and verify arguments + when(mockSigner.presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class))) + .thenAnswer(invocation -> { + SdkHttpFullRequest request = invocation.getArgument(0); + Aws4PresignerParams presignParams = invocation.getArgument(1); + + // Verify SdkHttpFullRequest + assertEquals("test-cache", request.host()); + assertEquals("/", request.encodedPath()); + assertEquals("connect", request.rawQueryParameters().get("Action").get(0)); + assertEquals("testuser", request.rawQueryParameters().get("User").get(0)); + assertEquals("ServerlessCache", request.rawQueryParameters().get("ResourceType").get(0)); + + // Verify Aws4PresignerParams + assertEquals("elasticache", presignParams.signingName()); + assertEquals(Region.US_EAST_1, presignParams.signingRegion()); + assertEquals(mockCredentials, presignParams.awsCredentials()); + + Instant expectedExpiration = fixedInstant.plus(Duration.ofSeconds(15 * 60 - 30)); + assertEquals(expectedExpiration, presignParams.expirationTime().get()); + assertEquals(fixedInstant, presignParams.signingClockOverride().get().instant()); + + return mockSignedRequest; + }); + + when(mockSignedRequest.getUri()).thenReturn(java.net.URI.create("http://test-cache.serverless.cache.amazonaws.com/result")); + + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + String token = tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, "test-cache.serverless.cache.amazonaws.com", 6379, "testuser"); + + assertEquals("test-cache.serverless.cache.amazonaws.com/result", token); + verify(mockSigner).presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class)); + } + + @Test + void testGenerateAuthenticationToken_NullCacheName() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + // Use reflection to set cacheName to null to test the validation + try { + java.lang.reflect.Field field = ElastiCacheIamTokenUtility.class.getDeclaredField("cacheName"); + field.setAccessible(true); + field.set(tokenUtility, null); + + assertThrows(IllegalArgumentException.class, () -> + tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, "test-host", 6379, "testuser")); + } catch (Exception e) { + fail("Reflection failed: " + e.getMessage()); + } + } + + @Test + void testGenerateAuthenticationToken_NullHostname() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + assertThrows(IllegalArgumentException.class, () -> + tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, null, 6379, "testuser")); + } +}