From 5a33638e698e5a16d42d3225170afc3e942a51bd Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:14:57 +0900 Subject: [PATCH 001/215] Change latest version to 0.5.x --- project/Versions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/Versions.scala b/project/Versions.scala index 29340d45a..1e86f4412 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -3,7 +3,8 @@ */ object LdbcVersions { - val latest = "0.4" + val latest = "0.5" + val v04 = "0.4" val v03 = "0.3" val v02 = "0.2" val v01 = "0.1" From 8d7149917f6757d47e303db901d9beeb72831a56 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:15:34 +0900 Subject: [PATCH 002/215] Adde 0.4.x in old version --- docs/src/main/mdoc/older-versions.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/main/mdoc/older-versions.md b/docs/src/main/mdoc/older-versions.md index 49002545f..7ba431cbe 100644 --- a/docs/src/main/mdoc/older-versions.md +++ b/docs/src/main/mdoc/older-versions.md @@ -17,13 +17,16 @@ and now also more API stability than those older releases. | Release | Date | Scala Versions | sbt plugin | Scala.js | Scala Native | Doc Sources | API (Scaladoc) | |---------|----------|----------------|------------|----------|--------------|-----------------|-----------------| +| 0.4 | Oct 2025 | 3.3.x | 1.x | | | [Browse][doc04] | [Browse][api04] | | 0.3 | May 2025 | 3.3.x | 1.x | | | [Browse][doc03] | [Browse][api03] | | 0.2 | Mar 2024 | 3.3.x | 1.x | | | [Browse][doc02] | [Browse][api02] | | 0.1 | Dec 2023 | 3.3.x | 1.x | | | | [Browse][api01] | [doc02]: https://takapi327.github.io/ldbc/0.2/ [doc03]: https://takapi327.github.io/ldbc/0.3/ +[doc04]: https://takapi327.github.io/ldbc/0.4/ +[api04]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.4.0/index.html [api03]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.3.3/index.html [api02]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.2.1/index.html [api01]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.1.1/index.html From 3ee4f3e9fac8ae2b45ef7641b6df0338e89ce3b4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:16:20 +0900 Subject: [PATCH 003/215] Added 0.5.x version --- project/LaikaSettings.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/project/LaikaSettings.scala b/project/LaikaSettings.scala index 905808007..57c8bf44d 100644 --- a/project/LaikaSettings.scala +++ b/project/LaikaSettings.scala @@ -31,9 +31,10 @@ object LaikaSettings { val v02: Version = version(LdbcVersions.v02) val v03: Version = version(LdbcVersions.v03) - val v04: Version = version(LdbcVersions.latest, "Stable") - val current: Version = v04 - val all: Seq[Version] = Seq(v04, v03, v02) + val v04: Version = version(LdbcVersions.v04, "Stable") + val v05: Version = version(LdbcVersions.latest, "Dev") + val current: Version = v05 + val all: Seq[Version] = Seq(v05, v04, v03, v02) val config: Versions = Versions .forCurrentVersion(current) From d8a804370011136e88d4c3ef30b54af56d052744 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:20:13 +0900 Subject: [PATCH 004/215] Update use version --- .../shared/src/main/scala/ldbc/connector/data/Constants.scala | 2 +- .../src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala index 64b2931e4..367749c75 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala @@ -14,4 +14,4 @@ import ldbc.connector.util.Version object Constants: val DRIVER_NAME: String = "MySQL Connector/L" - val DRIVER_VERSION: Version = Version(0, 4, 0) + val DRIVER_VERSION: Version = Version(0, 5, 0) diff --git a/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala b/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala index 1732f1e82..fe95c7cf0 100644 --- a/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala @@ -160,7 +160,7 @@ trait DatabaseMetaDataTest extends CatsEffectSuite: yield metaData.getDriverVersion() }, if prefix == "jdbc" then "mysql-connector-j-8.4.0 (Revision: 1c3f5c149e0bfe31c7fbeb24e2d260cd890972c4)" - else "ldbc-connector-0.4.0" + else "ldbc-connector-0.5.0" ) } @@ -180,7 +180,7 @@ trait DatabaseMetaDataTest extends CatsEffectSuite: for metaData <- conn.getMetaData() yield metaData.getDriverMinorVersion() }, - if prefix == "jdbc" then 4 else 4 + if prefix == "jdbc" then 4 else 5 ) } From 1f59b17fd041b215dc38effd441760027fa0a132 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:25:35 +0900 Subject: [PATCH 005/215] Delete ldbc-hikari project --- build.sbt | 18 +- .../scala/ldbc/hikari/Configuration.scala | 105 ------- .../ldbc/hikari/HikariConfigBuilder.scala | 293 ------------------ .../ldbc/hikari/HikariDataSourceBuilder.scala | 62 ---- .../src/test/resources/reference.conf | 53 ---- .../ldbc/hikari/HikariConfigBuilderTest.scala | 109 ------- .../main/scala/ldbc/sbt/Dependencies.scala | 1 - 7 files changed, 3 insertions(+), 638 deletions(-) delete mode 100644 module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala delete mode 100644 module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariConfigBuilder.scala delete mode 100644 module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariDataSourceBuilder.scala delete mode 100644 module/ldbc-hikari/src/test/resources/reference.conf delete mode 100644 module/ldbc-hikari/src/test/scala/ldbc/hikari/HikariConfigBuilderTest.scala diff --git a/build.sbt b/build.sbt index 02d01d8ae..124d4e8fd 100644 --- a/build.sbt +++ b/build.sbt @@ -145,18 +145,6 @@ lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) .nativeSettings(Test / nativeBrewFormulas += "s2n") .dependsOn(core) -lazy val hikari = LepusSbtProject("ldbc-hikari", "module/ldbc-hikari") - .settings(description := "Project to build HikariCP") - .settings( - onLoadMessage := s"${ scala.Console.RED }WARNING: This project is deprecated and will be removed in future versions.${ scala.Console.RESET }", - libraryDependencies ++= Seq( - catsEffect, - typesafeConfig, - hikariCP - ) ++ specs2 - ) - .dependsOn(dsl.jvm) - lazy val plugin = LepusSbtPluginProject("ldbc-plugin", "plugin") .settings(description := "Projects that provide sbt plug-ins") .settings((Compile / sourceGenerators) += Def.task { @@ -215,10 +203,11 @@ lazy val benchmark = (project in file("benchmark")) scala3Compiler, mysql, doobie, - slick + slick, + hikariCP, ) ) - .dependsOn(jdbcConnector.jvm, connector.jvm, queryBuilder.jvm, hikari) + .dependsOn(jdbcConnector.jvm, connector.jvm, queryBuilder.jvm) .enablePlugins(JmhPlugin, AutomateHeaderPlugin, NoPublishPlugin) lazy val http4sExample = crossProject(JVMPlatform) @@ -388,7 +377,6 @@ lazy val ldbc = tlCrossRootProject tests, docs, benchmark, - hikari, mcpDocumentServer ) .aggregate(examples *) diff --git a/module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala b/module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala deleted file mode 100644 index 9c1c852e8..000000000 --- a/module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import java.time.Duration as JavaDuration - -import scala.concurrent.duration.{ Duration, FiniteDuration, * } -import scala.jdk.CollectionConverters.* - -import com.typesafe.config.* - -case class Configuration(config: Config): - def get[A](path: String)(using loader: ConfigLoader[A]): A = - loader.load(config, path) - -object Configuration: - - def load( - classLoader: ClassLoader, - directSettings: Map[String, String] - ): Configuration = - try - val directConfig: Config = ConfigFactory.parseMap(directSettings.asJava) - val config: Config = ConfigFactory.load(classLoader, directConfig) - Configuration(config) - catch case e: ConfigException => throw new Exception(e.getMessage) - - def load(): Configuration = Configuration(ConfigFactory.load()) - -trait ConfigLoader[A]: - self => - - def load(config: Config, path: String): A - - def map[B](f: A => B): ConfigLoader[B] = (config: Config, path: String) => f(self.load(config, path)) - -object ConfigLoader: - def apply[A](f: Config => String => A): ConfigLoader[A] = - (config: Config, path: String) => f(config)(path) - - given ConfigLoader[String] = ConfigLoader(_.getString) - given ConfigLoader[Int] = ConfigLoader(_.getInt) - given ConfigLoader[Long] = ConfigLoader(_.getLong) - given ConfigLoader[Number] = ConfigLoader(_.getNumber) - given ConfigLoader[Double] = ConfigLoader(_.getDouble) - given ConfigLoader[Boolean] = ConfigLoader(_.getBoolean) - given ConfigLoader[ConfigMemorySize] = ConfigLoader(_.getMemorySize) - given ConfigLoader[FiniteDuration] = ConfigLoader(_.getDuration).map(_.toNanos.nanos) - given ConfigLoader[JavaDuration] = ConfigLoader(_.getDuration) - given ConfigLoader[Duration] = ConfigLoader(config => - path => - if config.getIsNull(path) then Duration.Inf - else config.getDuration(path).toNanos.nanos - ) - - given seqBoolean: ConfigLoader[Seq[Boolean]] = - ConfigLoader(_.getBooleanList).map(_.asScala.map(_.booleanValue).toSeq) - given seqInt: ConfigLoader[Seq[Int]] = - ConfigLoader(_.getIntList).map(_.asScala.map(_.toInt).toSeq) - given seqLong: ConfigLoader[Seq[Long]] = - ConfigLoader(_.getDoubleList).map(_.asScala.map(_.longValue).toSeq) - given seqNumber: ConfigLoader[Seq[Number]] = - ConfigLoader(_.getNumberList).map(_.asScala.toSeq) - given seqDouble: ConfigLoader[Seq[Double]] = - ConfigLoader(_.getDoubleList).map(_.asScala.map(_.doubleValue).toSeq) - given seqString: ConfigLoader[Seq[String]] = - ConfigLoader(_.getStringList).map(_.asScala.toSeq) - given seqBytes: ConfigLoader[Seq[ConfigMemorySize]] = - ConfigLoader(_.getMemorySizeList).map(_.asScala.toSeq) - given seqFinite: ConfigLoader[Seq[FiniteDuration]] = - ConfigLoader(_.getDurationList).map(_.asScala.map(_.toNanos.nanos).toSeq) - given seqJavaDuration: ConfigLoader[Seq[JavaDuration]] = - ConfigLoader(_.getDurationList).map(_.asScala.toSeq) - given seqScalaDuration: ConfigLoader[Seq[Duration]] = - ConfigLoader(_.getDurationList).map(_.asScala.map(_.toNanos.nanos).toSeq) - given seqConfig: ConfigLoader[Seq[Config]] = - ConfigLoader(_.getConfigList).map(_.asScala.toSeq) - given seqConfiguration: ConfigLoader[Seq[Configuration]] = - summon[ConfigLoader[Seq[Config]]].map(_.map(Configuration(_))) - - given ConfigLoader[Config] = ConfigLoader(_.getConfig) - given ConfigLoader[ConfigObject] = ConfigLoader(_.getObject) - given ConfigLoader[ConfigList] = ConfigLoader(_.getList) - given ConfigLoader[Configuration] = summon[ConfigLoader[Config]].map(Configuration(_)) - - given [A](using loader: ConfigLoader[A]): ConfigLoader[Option[A]] with - override def load(config: Config, path: String): Option[A] = - if config.hasPath(path) && !config.getIsNull(path) then Some(loader.load(config, path)) - else None - - given [A](using loader: ConfigLoader[A]): ConfigLoader[Map[String, A]] with - override def load(config: Config, path: String): Map[String, A] = - val obj = config.getObject(path) - val conf = obj.toConfig - obj - .keySet() - .asScala - .map { key => - key -> loader.load(conf, key) - } - .toMap diff --git a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariConfigBuilder.scala b/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariConfigBuilder.scala deleted file mode 100644 index 6769ee53c..000000000 --- a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariConfigBuilder.scala +++ /dev/null @@ -1,293 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import java.util.concurrent.{ ScheduledExecutorService, ThreadFactory, TimeUnit } -import java.util.Properties - -import javax.sql.DataSource as JDataSource - -import scala.concurrent.duration.Duration - -import com.zaxxer.hikari.metrics.MetricsTrackerFactory -import com.zaxxer.hikari.HikariConfig - -/** - * Build the Configuration of HikariCP. - */ -trait HikariConfigBuilder: - - protected val config: Configuration = Configuration.load() - - protected val path: String = "ldbc.hikari" - - /** List of keys to retrieve from conf file. */ - final private val CATALOG = "catalog" - final private val CONNECTION_TIMEOUT = "connection_timeout" - final private val IDLE_TIMEOUT = "idle_timeout" - final private val LEAK_DETECTION_THRESHOLD = "leak_detection_threshold" - final private val MAXIMUM_POOL_SIZE = "maximum_pool_size" - final private val MAX_LIFETIME = "max_lifetime" - final private val MINIMUM_IDLE = "minimum_idle" - final private val POOL_NAME = "pool_name" - final private val VALIDATION_TIMEOUT = "validation_timeout" - final private val ALLOW_POOL_SUSPENSION = "allow_pool_suspension" - final private val AUTO_COMMIT = "auto_commit" - final private val CONNECTION_INIT_SQL = "connection_init_sql" - final private val CONNECTION_TEST_QUERY = "connection_test_query" - final private val DATA_SOURCE_CLASSNAME = "data_source_classname" - final private val DATASOURCE_JNDI = "datasource_jndi" - final private val INITIALIZATION_FAIL_TIMEOUT = "initialization_fail_timeout" - final private val ISOLATE_INTERNAL_QUERIES = "isolate_internal_queries" - final private val JDBC_URL = "jdbc_url" - final private val READONLY = "readonly" - final private val REGISTER_MBEANS = "register_mbeans" - final private val SCHEMA = "schema" - final private val USERNAME = "username" - final private val PASSWORD = "password" - final private val DRIVER_CLASS_NAME = "driver_class_name" - final private val TRANSACTION_ISOLATION = "transaction_isolation" - - /** Number of application cores */ - private val maxCore: Int = Runtime.getRuntime.availableProcessors() - - /** - * Method to retrieve values matching any key from the conf file from the path configuration, with any type. - * - * @param func - * Process to get values from Configuration wrapped in Option - * @tparam T - * Type of value retrieved from conf file - */ - final private def readConfig[T](func: Configuration => Option[T]): Option[T] = - config.get[Option[Configuration]](path).flatMap(func(_)) - - /** Method to retrieve catalog information from the conf file. */ - private def getCatalog: Option[String] = - readConfig(_.get[Option[String]](CATALOG)) - - /** Method to retrieve connection timeout information from the conf file. */ - private def getConnectionTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](CONNECTION_TIMEOUT)) - - /** Method to retrieve idle timeout information from the conf file. */ - private def getIdleTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](IDLE_TIMEOUT)) - - /** Method to retrieve leak detection threshold information from the conf file. */ - private def getLeakDetectionThreshold: Option[Duration] = - readConfig(_.get[Option[Duration]](LEAK_DETECTION_THRESHOLD)) - - /** Method to retrieve maximum pool size information from the conf file. */ - private def getMaximumPoolSize: Option[Int] = - readConfig(_.get[Option[Int]](MAXIMUM_POOL_SIZE)) - - /** Method to retrieve max life time information from the conf file. */ - private def getMaxLifetime: Option[Duration] = - readConfig(_.get[Option[Duration]](MAX_LIFETIME)) - - /** Method to retrieve minimum idle information from the conf file. */ - private def getMinimumIdle: Option[Int] = - readConfig(_.get[Option[Int]](MINIMUM_IDLE)) - - /** Method to retrieve pool name information from the conf file. */ - private def getPoolName: Option[String] = - readConfig(_.get[Option[String]](POOL_NAME)) - - /** Method to retrieve validation timeout information from the conf file. */ - private def getValidationTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](VALIDATION_TIMEOUT)) - - /** Method to retrieve allow pool suspension information from the conf file. */ - private def getAllowPoolSuspension: Option[Boolean] = - readConfig(_.get[Option[Boolean]](ALLOW_POOL_SUSPENSION)) - - /** Method to retrieve auto commit information from the conf file. */ - private def getAutoCommit: Option[Boolean] = - readConfig(_.get[Option[Boolean]](AUTO_COMMIT)) - - /** Method to retrieve connection init sql information from the conf file. */ - private def getConnectionInitSql: Option[String] = - readConfig(_.get[Option[String]](CONNECTION_INIT_SQL)) - - /** Method to retrieve connection test query information from the conf file. */ - private def getConnectionTestQuery: Option[String] = - readConfig(_.get[Option[String]](CONNECTION_TEST_QUERY)) - - /** Method to retrieve data source class name information from the conf file. */ - private def getDataSourceClassname: Option[String] = - readConfig(_.get[Option[String]](DATA_SOURCE_CLASSNAME)) - - /** Method to retrieve data source jndi information from the conf file. */ - private def getDatasourceJndi: Option[String] = - readConfig(_.get[Option[String]](DATASOURCE_JNDI)) - - /** Method to retrieve initialization fail time out information from the conf file. */ - private def getInitializationFailTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](INITIALIZATION_FAIL_TIMEOUT)) - - /** Method to retrieve isolate internal queries information from the conf file. */ - private def getIsolateInternalQueries: Option[Boolean] = - readConfig(_.get[Option[Boolean]](ISOLATE_INTERNAL_QUERIES)) - - /** Method to retrieve jdbc url information from the conf file. */ - private def getJdbcUrl: Option[String] = - readConfig(_.get[Option[String]](JDBC_URL)) - - /** Method to retrieve readonly information from the conf file. */ - private def getReadonly: Option[Boolean] = - readConfig(_.get[Option[Boolean]](READONLY)) - - /** Method to retrieve register mbeans information from the conf file. */ - private def getRegisterMbeans: Option[Boolean] = - readConfig(_.get[Option[Boolean]](REGISTER_MBEANS)) - - /** Method to retrieve schema information from the conf file. */ - private def getSchema: Option[String] = - readConfig(_.get[Option[String]](SCHEMA)) - - /** Method to retrieve user name information from the conf file. */ - protected def getUserName: Option[String] = - readConfig(_.get[Option[String]](USERNAME)) - - /** Method to retrieve password information from the conf file. */ - protected def getPassword: Option[String] = - readConfig(_.get[Option[String]](PASSWORD)) - - /** Method to retrieve driver class name information from the conf file. */ - protected def getDriverClassName: Option[String] = - readConfig(_.get[Option[String]](DRIVER_CLASS_NAME)) - - /** Method to retrieve transaction isolation information from the conf file. */ - protected def getTransactionIsolation: Option[String] = - readConfig(_.get[Option[String]](TRANSACTION_ISOLATION)).map { v => - if v == "TRANSACTION_NONE" || v == "TRANSACTION_READ_UNCOMMITTED" || v == "TRANSACTION_READ_COMMITTED" || v == "TRANSACTION_REPEATABLE_READ" || v == "TRANSACTION_SERIALIZABLE" - then v - else - throw new IllegalArgumentException( - "TransactionIsolation must be TRANSACTION_NONE,TRANSACTION_READ_UNCOMMITTED,TRANSACTION_READ_COMMITTED,TRANSACTION_REPEATABLE_READ,TRANSACTION_SERIALIZABLE." - ) - } - - /** List of variables predefined as default settings. */ - val connectionTimeout: Long = getConnectionTimeout.getOrElse(Duration(30, TimeUnit.SECONDS)).toMillis - val idleTimeout: Long = getIdleTimeout.getOrElse(Duration(10, TimeUnit.MINUTES)).toMillis - val leakDetectionThreshold: Long = getLeakDetectionThreshold.getOrElse(Duration.Zero).toMillis - val maximumPoolSize: Int = getMaximumPoolSize.getOrElse(maxCore * 2) - val maxLifetime: Long = getMaxLifetime.getOrElse(Duration(30, TimeUnit.MINUTES)).toMillis - val minimumIdle: Int = getMinimumIdle.getOrElse(10) - val validationTimeout: Long = getValidationTimeout.getOrElse(Duration(5, TimeUnit.SECONDS)).toMillis - val allowPoolSuspension: Boolean = getAllowPoolSuspension.getOrElse(false) - val autoCommit: Boolean = getAutoCommit.getOrElse(true) - val initializationFailTimeout: Long = - getInitializationFailTimeout.getOrElse(Duration(1, TimeUnit.MILLISECONDS)).toMillis - val isolateInternalQueries: Boolean = getIsolateInternalQueries.getOrElse(false) - val readonly: Boolean = getReadonly.getOrElse(false) - val registerMbeans: Boolean = getRegisterMbeans.getOrElse(false) - - /** - * Method to generate HikariConfig based on DatabaseConfig and other settings. - * - * @param jDataSource - * Factories for connection to physical data sources - * @param dataSourceProperties - * Properties (name/value pairs) used to configure DataSource/java.sql.Driver - * @param healthCheckProperties - * Properties (name/value pairs) used to configure HealthCheck/java.sql.Driver - * @param healthCheckRegistry - * Set the HealthCheckRegistry that will be used for registration of health checks by HikariCP. - * @param metricRegistry - * Set a MetricRegistry instance to use for registration of metrics used by HikariCP. - * @param metricsTrackerFactory - * Set a MetricsTrackerFactory instance to use for registration of metrics used by HikariCP. - * @param scheduledExecutor - * Set the ScheduledExecutorService used for housekeeping. - * @param threadFactory - * Set the thread factory to be used to create threads. - */ - def build( - jDataSource: Option[JDataSource] = None, - dataSourceProperties: Option[Properties] = None, - healthCheckProperties: Option[Properties] = None, - healthCheckRegistry: Option[Object] = None, - metricRegistry: Option[Object] = None, - metricsTrackerFactory: Option[MetricsTrackerFactory] = None, - scheduledExecutor: Option[ScheduledExecutorService] = None, - threadFactory: Option[ThreadFactory] = None - ): HikariConfig = - - val hikariConfig = new HikariConfig() - - getCatalog foreach hikariConfig.setCatalog - hikariConfig.setConnectionTimeout(connectionTimeout) - hikariConfig.setIdleTimeout(idleTimeout) - hikariConfig.setLeakDetectionThreshold(leakDetectionThreshold) - hikariConfig.setMaximumPoolSize(maximumPoolSize) - hikariConfig.setMaxLifetime(maxLifetime) - hikariConfig.setMinimumIdle(minimumIdle) - hikariConfig.setValidationTimeout(validationTimeout) - hikariConfig.setAllowPoolSuspension(allowPoolSuspension) - hikariConfig.setAutoCommit(autoCommit) - hikariConfig.setInitializationFailTimeout(initializationFailTimeout) - hikariConfig.setIsolateInternalQueries(isolateInternalQueries) - hikariConfig.setReadOnly(readonly) - hikariConfig.setRegisterMbeans(registerMbeans) - - getPassword foreach hikariConfig.setPassword - getPoolName foreach hikariConfig.setPoolName - getUserName foreach hikariConfig.setUsername - getConnectionInitSql foreach hikariConfig.setConnectionInitSql - getConnectionTestQuery foreach hikariConfig.setConnectionTestQuery - getDataSourceClassname foreach hikariConfig.setDataSourceClassName - getDatasourceJndi foreach hikariConfig.setDataSourceJNDI - getDriverClassName foreach hikariConfig.setDriverClassName - getJdbcUrl foreach hikariConfig.setJdbcUrl - getSchema foreach hikariConfig.setSchema - getTransactionIsolation foreach hikariConfig.setTransactionIsolation - - jDataSource foreach hikariConfig.setDataSource - dataSourceProperties foreach hikariConfig.setDataSourceProperties - healthCheckProperties foreach hikariConfig.setHealthCheckProperties - healthCheckRegistry foreach hikariConfig.setHealthCheckRegistry - metricRegistry foreach hikariConfig.setMetricRegistry - metricsTrackerFactory foreach hikariConfig.setMetricsTrackerFactory - scheduledExecutor foreach hikariConfig.setScheduledExecutor - threadFactory foreach hikariConfig.setThreadFactory - - hikariConfig - -object HikariConfigBuilder: - - /** - * Methods for retrieving data from the ldbc default specified path. - * - * {{{ - * ldbc.hikari { - * jdbc_url = ... - * username = ... - * password = ... - * } - * }}} - */ - def default: HikariConfigBuilder = new HikariConfigBuilder {} - - /** - * Methods for retrieving data from a user-specified conf path. - * - * @param confPath - * Path of conf from which user-specified data is to be retrieved - * - * {{{ - * {user path} { - * jdbc_url = ... - * username = ... - * password = ... - * } - * }}} - */ - def from(confPath: String): HikariConfigBuilder = new HikariConfigBuilder: - override protected val path: String = confPath diff --git a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariDataSourceBuilder.scala b/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariDataSourceBuilder.scala deleted file mode 100644 index b4c94770f..000000000 --- a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariDataSourceBuilder.scala +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import com.zaxxer.hikari.{ HikariConfig, HikariDataSource } - -import cats.effect.* -import cats.effect.implicits.* - -/** - * A model for building a database. HikariCP construction, thread pool generation for database connection, test - * connection, etc. are performed via the method. - * - * @tparam F - * the effect type. - */ -trait HikariDataSourceBuilder[F[_]: Sync] extends HikariConfigBuilder: - - /** - * Method for generating HikariDataSource with Resource. - * - * @param factory - * Process to generate HikariDataSource - */ - private def createDataSourceResource(factory: => HikariDataSource): Resource[F, HikariDataSource] = - Resource.fromAutoCloseable(Sync[F].delay(factory)) - - /** - * Method to generate Config for HikariCP. - */ - private def buildConfig(): Resource[F, HikariConfig] = - Sync[F].delay { - val hikariConfig = build() - hikariConfig.validate() - hikariConfig - }.toResource - - /** - * Method to generate DataSource from HikariCPConfig generation. - */ - def buildDataSource(): Resource[F, HikariDataSource] = - for - hikariConfig <- buildConfig() - hikariDataSource <- createDataSourceResource(new HikariDataSource(hikariConfig)) - yield hikariDataSource - - /** - * Methods for generating DataSource from user-generated HikariCPConfig. - * - * @param hikariConfig - * User-generated HikariCP Config file - */ - def buildFromConfig(hikariConfig: HikariConfig): Resource[F, HikariDataSource] = - createDataSourceResource(new HikariDataSource(hikariConfig)) - -object HikariDataSourceBuilder: - - def default[F[_]: Sync]: HikariDataSourceBuilder[F] = new HikariDataSourceBuilder[F] {} diff --git a/module/ldbc-hikari/src/test/resources/reference.conf b/module/ldbc-hikari/src/test/resources/reference.conf deleted file mode 100644 index dcda815d6..000000000 --- a/module/ldbc-hikari/src/test/resources/reference.conf +++ /dev/null @@ -1,53 +0,0 @@ -ldbc.hikari { - catalog = "ldbc" - connection_timeout = 30 s - idle_timeout = 10 m - leak_detection_threshold = 0 - maximum_pool_size = 32 - max_lifetime = 30 m - minimum_idle = 10 - pool_name = "ldbc-pool" - validation_timeout = 5 s - allow_pool_suspension = false - auto_commit = true - connection_init_sql = "select 1" - connection_test_query = "select 1" - data_source_classname = "com.mysql.cj.jdbc.Driver" - datasource_jndi = "" - initialization_fail_timeout = 1 ms - isolate_internal_queries = false - jdbc_url = "jdbc:mysql://127.0.0.1:3306/ldbc" - readonly = false - register_mbeans = false - schema = "ldbc" - username = "ldbc" - password = "mysql" - transaction_isolation = "TRANSACTION_NONE" -} - -ldbc.hikari.failure { - catalog = "ldbc" - connection_timeout = 30 s - idle_timeout = 10 m - leak_detection_threshold = 0 - maximum_pool_size = 32 - max_lifetime = 30 m - minimum_idle = 10 - pool_name = "ldbc-pool" - validation_timeout = 5 s - allow_pool_suspension = false - auto_commit = true - connection_init_sql = "select 1" - connection_test_query = "select 1" - data_source_classname = "com.mysql.cj.jdbc.Driver" - datasource_jndi = "" - initialization_fail_timeout = 1 ms - isolate_internal_queries = false - jdbc_url = "jdbc:mysql://127.0.0.1:3306/ldbc" - readonly = false - register_mbeans = false - schema = "ldbc" - username = "ldbc" - password = "mysql" - transaction_isolation = "test" -} diff --git a/module/ldbc-hikari/src/test/scala/ldbc/hikari/HikariConfigBuilderTest.scala b/module/ldbc-hikari/src/test/scala/ldbc/hikari/HikariConfigBuilderTest.scala deleted file mode 100644 index 0efebb3db..000000000 --- a/module/ldbc-hikari/src/test/scala/ldbc/hikari/HikariConfigBuilderTest.scala +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration - -import org.specs2.mutable.Specification - -object HikariConfigBuilderTest extends Specification: - - private val hikariConfig = HikariConfigBuilder.default.build() - - "Testing the HikariConfigBuilder" should { - - "The value of the specified key can be retrieved from the conf file" in { - hikariConfig.getCatalog == "ldbc" - } - - "The value of the specified key can be retrieved from the conf file" in { - hikariConfig.getJdbcUrl == "jdbc:mysql://127.0.0.1:3306/ldbc" - } - - "The connection_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getConnectionTimeout == Duration(30, TimeUnit.SECONDS).toMillis - } - - "The idle_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getIdleTimeout == Duration(10, TimeUnit.MINUTES).toMillis - } - - "The leak_detection_threshold setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getLeakDetectionThreshold == Duration.Zero.toMillis - } - - "The maximum_pool_size setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getMaximumPoolSize == 32 - } - - "The max_lifetime setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getMaxLifetime == Duration(30, TimeUnit.MINUTES).toMillis - } - - "The minimum_idle setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getMinimumIdle == 10 - } - - "The pool_name setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getPoolName == "ldbc-pool" - } - - "The validation_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getValidationTimeout == Duration(5, TimeUnit.SECONDS).toMillis - } - - "The connection_init_sql setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getConnectionInitSql == "select 1" - } - - "The connection_test_query setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getConnectionTestQuery == "select 1" - } - - "The connection_test_query setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getDataSourceClassName == "com.mysql.cj.jdbc.Driver" - } - - "The datasource_jndi setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getDataSourceJNDI == "" - } - - "The initialization_fail_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getInitializationFailTimeout == Duration(1, TimeUnit.MILLISECONDS).toMillis - } - - "The jdbc_url setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getJdbcUrl == "jdbc:mysql://127.0.0.1:3306/ldbc" - } - - "The schema setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getSchema == "ldbc" - } - - "The username setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getUsername == "ldbc" - } - - "The password setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getPassword == "mysql" - } - - "The transaction_isolation setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getTransactionIsolation == "TRANSACTION_NONE" - } - } - -object HikariConfigBuilderFailureTest extends Specification, HikariConfigBuilder: - override protected val path: String = "ldbc.hikari.failure" - - "Testing the HikariConfigBuilderFailureTest" should { - "IllegalArgumentException exception is raised when transaction_isolation is set to a value other than expected" in { - getTransactionIsolation must throwAn[IllegalArgumentException] - } - } diff --git a/plugin/src/main/scala/ldbc/sbt/Dependencies.scala b/plugin/src/main/scala/ldbc/sbt/Dependencies.scala index 87548c3d6..f34fcd41e 100644 --- a/plugin/src/main/scala/ldbc/sbt/Dependencies.scala +++ b/plugin/src/main/scala/ldbc/sbt/Dependencies.scala @@ -17,7 +17,6 @@ trait Dependencies { val ldbcQueryBuilder = component("ldbc-query-builder") val ldbcSchema = component("ldbc-schema") val ldbcCodegen = component("ldbc-codegen") - val ldbcHikari = component("ldbc-hikari") val jdbcConnector = component("jdbc-connector") val ldbcConnector = component("ldbc-connector") } From 63c99a9bcb88747d7520bad4c0ca3ac91c476f52 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:26:06 +0900 Subject: [PATCH 006/215] Action sbt githubWorkflowGenerate --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 443e43a20..74e310990 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -151,11 +151,11 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-hikari/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-hikari/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) From 843299dd21acb39d7185c53f6cd837a20fd0721b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:26:22 +0900 Subject: [PATCH 007/215] Action sbt scalafmtSbt --- build.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 124d4e8fd..1a1e26193 100644 --- a/build.sbt +++ b/build.sbt @@ -204,7 +204,7 @@ lazy val benchmark = (project in file("benchmark")) mysql, doobie, slick, - hikariCP, + hikariCP ) ) .dependsOn(jdbcConnector.jvm, connector.jvm, queryBuilder.jvm) From fc5c4069a92d9d4ea9f18b3c23db534c889d26aa Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:28:57 +0900 Subject: [PATCH 008/215] Delete Dependencies --- build.sbt | 16 ++++++++-------- project/Dependencies.scala | 33 --------------------------------- 2 files changed, 8 insertions(+), 41 deletions(-) delete mode 100644 project/Dependencies.scala diff --git a/build.sbt b/build.sbt index 1a1e26193..412deb5d3 100644 --- a/build.sbt +++ b/build.sbt @@ -175,7 +175,7 @@ lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) .defaultSettings .jvmSettings( Test / fork := true, - libraryDependencies += mysql % Test + libraryDependencies += "com.mysql" % "mysql-connector-j" % "8.4.0" % Test ) .jvmConfigure(_ dependsOn jdbcConnector.jvm) .jsSettings( @@ -200,11 +200,11 @@ lazy val benchmark = (project in file("benchmark")) .settings(Compile / javacOptions ++= Seq("--release", java21)) .settings( libraryDependencies ++= Seq( - scala3Compiler, - mysql, - doobie, - slick, - hikariCP + "org.scala-lang" %% "scala3-compiler" % scala3, + "com.mysql" % "mysql-connector-j" % "8.4.0", + "org.tpolecat" %% "doobie-core" % "1.0.0-RC10", + "com.typesafe.slick" %% "slick" % "3.6.1", + "com.zaxxer" % "HikariCP" % "7.0.2" ) ) .dependsOn(jdbcConnector.jvm, connector.jvm, queryBuilder.jvm) @@ -231,8 +231,8 @@ lazy val hikariCPExample = crossProject(JVMPlatform) .example("hikariCP", "HikariCP example project") .settings( libraryDependencies ++= Seq( - hikariCP, - mysql + "com.zaxxer" % "HikariCP" % "7.0.2", + "com.mysql" % "mysql-connector-j" % "8.4.0" ) ) .dependsOn(jdbcConnector, dsl) diff --git a/project/Dependencies.scala b/project/Dependencies.scala deleted file mode 100644 index d203a2827..000000000 --- a/project/Dependencies.scala +++ /dev/null @@ -1,33 +0,0 @@ -/** This file is part of the ldbc. For the full copyright and license information, please view the LICENSE file that was - * distributed with this source code. - */ - -import sbt.* - -import ScalaVersions.* - -object Dependencies { - - val catsEffect = "org.typelevel" %% "cats-effect" % "3.6.3" - - val parserCombinators = "org.scala-lang.modules" %% "scala-parser-combinators" % "2.3.0" - - val mysqlVersion = "8.4.0" - val mysql = "com.mysql" % "mysql-connector-j" % mysqlVersion - - val typesafeConfig = "com.typesafe" % "config" % "1.4.5" - - val hikariCP = "com.zaxxer" % "HikariCP" % "7.0.2" - - val scala3Compiler = "org.scala-lang" %% "scala3-compiler" % scala3 - - val doobie = "org.tpolecat" %% "doobie-core" % "1.0.0-RC10" - - val slick = "com.typesafe.slick" %% "slick" % "3.6.1" - - val specs2Version = "5.6.4" - val specs2: Seq[ModuleID] = Seq( - "specs2-core", - "specs2-junit" - ).map("org.specs2" %% _ % specs2Version % Test) -} From 1d2eb49d1e991630bc9cd519ff0a240fc6def81e Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:31:31 +0900 Subject: [PATCH 009/215] Delete unused --- build.sbt | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/build.sbt b/build.sbt index 412deb5d3..42c1d9f1f 100644 --- a/build.sbt +++ b/build.sbt @@ -6,7 +6,6 @@ import com.typesafe.tools.mima.core.* import BuildSettings.* -import Dependencies.* import Implicits.* import JavaVersions.* import ProjectKeys.* @@ -174,7 +173,7 @@ lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) ) .defaultSettings .jvmSettings( - Test / fork := true, + Test / fork := true, libraryDependencies += "com.mysql" % "mysql-connector-j" % "8.4.0" % Test ) .jvmConfigure(_ dependsOn jdbcConnector.jvm) @@ -200,11 +199,11 @@ lazy val benchmark = (project in file("benchmark")) .settings(Compile / javacOptions ++= Seq("--release", java21)) .settings( libraryDependencies ++= Seq( - "org.scala-lang" %% "scala3-compiler" % scala3, - "com.mysql" % "mysql-connector-j" % "8.4.0", - "org.tpolecat" %% "doobie-core" % "1.0.0-RC10", - "com.typesafe.slick" %% "slick" % "3.6.1", - "com.zaxxer" % "HikariCP" % "7.0.2" + "org.scala-lang" %% "scala3-compiler" % scala3, + "com.mysql" % "mysql-connector-j" % "8.4.0", + "org.tpolecat" %% "doobie-core" % "1.0.0-RC10", + "com.typesafe.slick" %% "slick" % "3.6.1", + "com.zaxxer" % "HikariCP" % "7.0.2" ) ) .dependsOn(jdbcConnector.jvm, connector.jvm, queryBuilder.jvm) @@ -231,8 +230,8 @@ lazy val hikariCPExample = crossProject(JVMPlatform) .example("hikariCP", "HikariCP example project") .settings( libraryDependencies ++= Seq( - "com.zaxxer" % "HikariCP" % "7.0.2", - "com.mysql" % "mysql-connector-j" % "8.4.0" + "com.zaxxer" % "HikariCP" % "7.0.2", + "com.mysql" % "mysql-connector-j" % "8.4.0" ) ) .dependsOn(jdbcConnector, dsl) @@ -265,7 +264,7 @@ lazy val docs = (project in file("docs")) mdocVariables ++= Map( "ORGANIZATION" -> organization.value, "SCALA_VERSION" -> scalaVersion.value, - "MYSQL_VERSION" -> mysqlVersion + "MYSQL_VERSION" -> "8.4.0" ), laikaTheme := LaikaSettings.helium.value, // Modify tlSite task to run the LLM docs script after the site is generated From 53f17bff5742ca7de8c7bd43d6977ec0948e9416 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:36:27 +0900 Subject: [PATCH 010/215] Fixed DatabaseMetaDataTest --- .../shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala b/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala index fe95c7cf0..7618db39f 100644 --- a/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala @@ -2100,7 +2100,7 @@ trait DatabaseMetaDataTest extends CatsEffectSuite: for metaData <- conn.getMetaData() yield metaData.getJDBCMinorVersion() }, - if prefix == "jdbc" then 2 else 4 + if prefix == "jdbc" then 2 else 5 ) } From 08bcfbe52708d2d13b0e9ec4570b72b855dbfedb Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:42:03 +0900 Subject: [PATCH 011/215] Fixed ConnectionTest --- tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala b/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala index bbf0116f6..1c634b748 100644 --- a/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala @@ -126,7 +126,7 @@ trait ConnectionTest extends CatsEffectSuite: assertIO( datasource().getConnection.use(_.getMetaData().map(_.getDriverVersion())), if prefix == "jdbc" then "mysql-connector-j-8.4.0 (Revision: 1c3f5c149e0bfe31c7fbeb24e2d260cd890972c4)" - else "ldbc-connector-0.4.0" + else "ldbc-connector-0.5.0" ) } From fdfbe519063568b25e0d79415135afaf3806edcd Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:46:05 +0900 Subject: [PATCH 012/215] Delete ldbc-hikari in document --- README.md | 1 - codecov.yml | 2 -- docs/src/main/mdoc/README.md | 1 - docs/src/main/mdoc/en/migration-notes.md | 1 - docs/src/main/mdoc/index.md | 1 - docs/src/main/mdoc/ja/migration-notes.md | 1 - 6 files changed, 7 deletions(-) diff --git a/README.md b/README.md index 0774e5fdb..46a9566cf 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,6 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | ## Performance diff --git a/codecov.yml b/codecov.yml index 519985f37..f8a69ab4a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -16,12 +16,10 @@ coverage: - "module/ldbc-sql" - "module/ldbc-core" - "module/jdbc-connector" - - "module/ldbc-schemaspy" - "tests/shared/src/main/scala/ldbc/tests/model" - "module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/CallableStatementImpl.scala" - "module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/ServerPreparedStatement.scala" - "module/ldbc-connector/jvm/src/main/scala/ldbc/connector/SSLPlatform.scala" - "module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Initial.scala" - "module/ldbc-schema/src/main/scala/ldbc/schema/Character.scala" - - "module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala" - "module/ldbc-dsl/src/main/scala/ldbc/dsl/ResultSetConsumer.scala" diff --git a/docs/src/main/mdoc/README.md b/docs/src/main/mdoc/README.md index 516f97c50..128b61a4f 100644 --- a/docs/src/main/mdoc/README.md +++ b/docs/src/main/mdoc/README.md @@ -48,7 +48,6 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | ## Quick Start diff --git a/docs/src/main/mdoc/en/migration-notes.md b/docs/src/main/mdoc/en/migration-notes.md index 380e8dc25..1a0c96dae 100644 --- a/docs/src/main/mdoc/en/migration-notes.md +++ b/docs/src/main/mdoc/en/migration-notes.md @@ -32,7 +32,6 @@ | `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | ## 🎯 Major Changes diff --git a/docs/src/main/mdoc/index.md b/docs/src/main/mdoc/index.md index 400e4a033..6fd3e67fd 100644 --- a/docs/src/main/mdoc/index.md +++ b/docs/src/main/mdoc/index.md @@ -47,7 +47,6 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | ## Quick Start diff --git a/docs/src/main/mdoc/ja/migration-notes.md b/docs/src/main/mdoc/ja/migration-notes.md index 8624b298d..7c14e0477 100644 --- a/docs/src/main/mdoc/ja/migration-notes.md +++ b/docs/src/main/mdoc/ja/migration-notes.md @@ -32,7 +32,6 @@ | `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | ## 🎯 主要な変更点 From 2b4674f2ee8e13246ec1c41feb24d660c22376af Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:47:34 +0900 Subject: [PATCH 013/215] Delete unused --- .scalafmt.conf | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index 1764b7e8a..fa70e487b 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -104,7 +104,6 @@ rewrite { ["org\\.openjdk\\..*"], ["org\\.apache\\..*"], ["org\\.slf4j\\..*"], - ["org\\.schemaspy\\..*"], ["com\\.mysql\\..*"], ["com\\.zaxxer\\..*"], ["com\\.comcast\\..*"], @@ -117,7 +116,6 @@ rewrite { ["laika\\..*"], ["org\\.typelevel\\..*"], ["org\\.scalatest\\..*"], - ["org\\.specs2\\..*"], ["munit\\..*"], ["slick\\..*"], ["doobie\\..*"] @@ -130,8 +128,6 @@ rewrite { ["ldbc\\.codegen\\..*"], ["ldbc\\.connector\\..*"], ["jdbc\\.connector\\..*"], - ["ldbc\\.hikari\\..*"], - ["ldbc\\.schemaspy\\..*"], ["ldbc\\..*"], [".*"], ] From 9c9f895146eb40133cf5352a628dd2e6c4634ec0 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 00:54:12 +0900 Subject: [PATCH 014/215] Fixed driver version --- .../shared/src/test/scala/ldbc/connector/ConnectionTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index 06a68155a..af84753d6 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -624,7 +624,7 @@ class ConnectionTest extends FTestPlatform: assertIO( connection.use(_.getMetaData().map(_.getDriverVersion())), - "ldbc-connector-0.4.0" + "ldbc-connector-0.5.0" ) } From c820c04142f01fffde8da8beb341a99f5a2b1f2c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 17:51:01 +0900 Subject: [PATCH 015/215] Delete ConnectionProvider for jdbc --- .../jdbc/connector/ConnectionProvider.scala | 177 ------------------ .../main/scala/jdbc/connector/package.scala | 1 - .../main/scala/ldbc/connector/package.scala | 1 - 3 files changed, 179 deletions(-) delete mode 100644 module/jdbc-connector/src/main/scala/jdbc/connector/ConnectionProvider.scala diff --git a/module/jdbc-connector/src/main/scala/jdbc/connector/ConnectionProvider.scala b/module/jdbc-connector/src/main/scala/jdbc/connector/ConnectionProvider.scala deleted file mode 100644 index 29ddb8391..000000000 --- a/module/jdbc-connector/src/main/scala/jdbc/connector/ConnectionProvider.scala +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package jdbc.connector - -import java.sql.DriverManager - -import javax.sql.DataSource - -import scala.concurrent.ExecutionContext - -import cats.syntax.all.* - -import cats.effect.* -import cats.effect.std.Console - -import ldbc.sql.Connection - -import ldbc.{ Connector, Provider } -import ldbc.logging.LogHandler - -@deprecated( - "Connection creation using ConnectionProvider is now deprecated. Please use jdbc.connector.MySQLDataSource from now on. ConnectionProvider will be removed in version 0.5.x.", - "ldbc 0.4.0" -) -object ConnectionProvider: - - private case class DataSourceProvider[F[_]]( - dataSource: DataSource, - connectEC: ExecutionContext, - logHandler: Option[LogHandler[F]] - )(using ev: Async[F]) - extends Provider[F]: - override def createConnection(): Resource[F, Connection[F]] = - Resource - .fromAutoCloseable(ev.evalOn(ev.delay(dataSource.getConnection()), connectEC)) - .map(conn => ConnectionImpl(conn)) - - override def use[A](f: Connector[F] => F[A]): F[A] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - private case class JavaConnectionProvider[F[_]: Sync]( - connection: java.sql.Connection, - logHandler: Option[LogHandler[F]] - ) extends Provider[F]: - override def createConnection(): Resource[F, Connection[F]] = - Resource.pure(ConnectionImpl(connection)) - - override def use[A](f: Connector[F] => F[A]): F[A] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - class DriverProvider[F[_]: Console](using ev: Async[F]): - - private def create( - driver: String, - conn: () => java.sql.Connection, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - new Provider[F]: - override def createConnection(): Resource[F, Connection[F]] = - Resource - .fromAutoCloseable(ev.blocking { - Class.forName(driver) - conn() - }) - .map(conn => ConnectionImpl(conn)) - - override def use[A](f: Connector[F] => F[A]): F[A] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - /** Construct a new `Provider` that uses the JDBC `DriverManager` to allocate connections. - * - * @param driver - * the class name of the JDBC driver, like "com.mysql.cj.jdbc.MySQLDriver" - * @param url - * a connection URL, specific to your driver - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def apply( - driver: String, - url: String, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - create(driver, () => DriverManager.getConnection(url), logHandler) - - /** Construct a new `Provider` that uses the JDBC `DriverManager` to allocate connections. - * - * @param driver - * the class name of the JDBC driver, like "com.mysql.cj.jdbc.MySQLDriver" - * @param url - * a connection URL, specific to your driver - * @param user - * database username - * @param password - * database password - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def apply( - driver: String, - url: String, - user: String, - password: String, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - create(driver, () => DriverManager.getConnection(url, user, password), logHandler) - - /** Construct a new `Provider` that uses the JDBC `DriverManager` to allocate connections. - * - * @param driver - * the class name of the JDBC driver, like "com.mysql.cj.jdbc.MySQLDriver" - * @param url - * a connection URL, specific to your driver - * @param info - * a `Properties` containing connection information (see `DriverManager.getConnection`) - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def apply( - driver: String, - url: String, - info: java.util.Properties, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - create(driver, () => DriverManager.getConnection(url, info), logHandler) - - /** - * Construct a constructor of `Provider[F]` for some `D <: DataSource` When calling this constructor you - * should explicitly supply the effect type M e.g. ConnectionProvider.fromDataSource[IO](myDataSource, connectEC) - * - * @param dataSource - * A data source that manages connection information to MySQL. - * @param connectEC - * Execution context dedicated to database connection. - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def fromDataSource[F[_]: Console: Async]( - dataSource: DataSource, - connectEC: ExecutionContext, - logHandler: Option[LogHandler[F]] = None - ): Provider[F] = DataSourceProvider[F](dataSource, connectEC, logHandler) - - /** - * Construct a `Provider` that wraps an existing `Connection`. Closing the connection is the responsibility of - * the caller. - * - * @param connection - * a raw JDBC `Connection` to wrap - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def fromConnection[F[_]: Console: Sync]( - connection: java.sql.Connection, - logHandler: Option[LogHandler[F]] = None - ): Provider[F] = JavaConnectionProvider[F](connection, logHandler) - - /** Module of constructors for `Provider` that use the JDBC `DriverManager` to allocate connections. Note that - * `DriverManager` is unbounded and will happily allocate new connections until server resources are exhausted. It - * is usually preferable to use `DataSourceTransactor` with an underlying bounded connection pool. Blocking operations on `DriverProvider` are - * executed on an unbounded cached daemon thread pool by default, so you are also at risk of exhausting system - * threads. TL;DR this is fine for console apps but don't use it for a web application. - */ - def fromDriverManager[F[_]: Console: Async]: DriverProvider[F] = new DriverProvider[F] diff --git a/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala b/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala index 95e26409e..2ea26a1e5 100644 --- a/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala +++ b/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala @@ -9,4 +9,3 @@ package jdbc package object connector: export ldbc.Connector export ldbc.DataSource - export ldbc.Provider diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala index 081d78da4..ef58abcf9 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala @@ -9,4 +9,3 @@ package ldbc package object connector: export ldbc.Connector export ldbc.DataSource - export ldbc.Provider From 40e6fe72e48843ec4cfe1a20ae42a1bdd2ca9972 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 17:51:30 +0900 Subject: [PATCH 016/215] Delete ConnectionProvider for ldbc --- .../ldbc/connector/ConnectionProvider.scala | 522 ------------------ .../connector/ConnectionProviderTest.scala | 162 ------ 2 files changed, 684 deletions(-) delete mode 100644 module/ldbc-connector/shared/src/main/scala/ldbc/connector/ConnectionProvider.scala delete mode 100644 module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionProviderTest.scala diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/ConnectionProvider.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/ConnectionProvider.scala deleted file mode 100644 index 8208b597d..000000000 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/ConnectionProvider.scala +++ /dev/null @@ -1,522 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.connector - -import scala.concurrent.duration.Duration - -import cats.effect.* -import cats.effect.std.Console -import cats.effect.std.UUIDGen - -import fs2.hashing.Hashing -import fs2.io.net.* - -import org.typelevel.otel4s.trace.Tracer - -import ldbc.sql.DatabaseMetaData - -import ldbc.{ Connector, Provider } -import ldbc.logging.LogHandler - -@deprecated( - "Connection creation using ConnectionProvider is now deprecated. Please use ldbc.connector.MySQLDataSource from now on. ConnectionProvider will be removed in version 0.5.x.", - "ldbc 0.4.0" -) -trait ConnectionProvider[F[_], A] extends Provider[F]: - - /** - * Update the host information of the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setHost("127.0.0.1") - * }}} - * - * @param host - * Host information of the database to be connected - */ - def setHost(host: String): ConnectionProvider[F, A] - - /** - * Update the port information of the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setPort(3306) - * }}} - * - * @param port - * Port information of the database to be connected - */ - def setPort(port: Int): ConnectionProvider[F, A] - - /** - * Update the user information of the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setUser("root") - * }}} - * - * @param user - * User information of the database to be connected - */ - def setUser(user: String): ConnectionProvider[F, A] - - /** - * Update the password information of the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setPassword("password") - * }}} - * - * @param password - * Password information of the database to be connected - */ - def setPassword(password: String): ConnectionProvider[F, A] - - /** - * Update the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setDatabase("database") - * }}} - * - * @param database - * Database name to connect to - */ - def setDatabase(database: String): ConnectionProvider[F, A] - - /** - * Update the setting of whether or not to output the log of packet communications for connection processing. - * - * Default is false. - * - * {{{ - * ConnectionProvider - * ... - * .setDebug(true) - * }}} - * - * @param debug - * Whether packet communication logs are output or not - */ - def setDebug(debug: Boolean): ConnectionProvider[F, A] - - /** - * Update whether SSL communication is used. - * - * {{{ - * ConnectionProvider - * ... - * .setSSL(SSL.Trusted) - * }}} - * - * @param ssl - * SSL set value. Changes the way certificates are operated, etc. - */ - def setSSL(ssl: SSL): ConnectionProvider[F, A] - - /** - * Update socket options for TCP/UDP sockets. - * - * {{{ - * ConnectionProvider - * ... - * .addSocketOption(SocketOption.noDelay(true)) - * }}} - * - * @param socketOption - * Socket options for TCP/UDP sockets - */ - def addSocketOption(socketOption: SocketOption): ConnectionProvider[F, A] - - /** - * Update socket options for TCP/UDP sockets. - * - * {{{ - * ConnectionProvider - * ... - * .setSocketOptions(List(SocketOption.noDelay(true))) - * }}} - * - * @param socketOptions - * List of socket options for TCP/UDP sockets - */ - def setSocketOptions(socketOptions: List[SocketOption]): ConnectionProvider[F, A] - - /** - * Update the read timeout value. - * - * {{{ - * ConnectionProvider - * ... - * .setReadTimeout(Duration.Inf) - * }}} - * - * @param readTimeout - * Read timeout value - */ - def setReadTimeout(readTimeout: Duration): ConnectionProvider[F, A] - - /** - * Update the setting of whether or not to replace the public key. - * - * {{{ - * ConnectionProvider - * ... - * .setAllowPublicKeyRetrieval(true) - * }}} - * - * @param allowPublicKeyRetrieval - * Whether to replace the public key - */ - def setAllowPublicKeyRetrieval(allowPublicKeyRetrieval: Boolean): ConnectionProvider[F, A] - - /** - * Update whether the JDBC term “catalog” or “schema” is used to refer to the database in the application. - * - * {{{ - * ConnectionProvider - * ... - * .setDatabaseTerm(DatabaseMetaData.DatabaseTerm.SCHEMA) - * }}} - * - * @param databaseTerm - * The JDBC terms [[DatabaseMetaData.DatabaseTerm.CATALOG]] and [[DatabaseMetaData.DatabaseTerm.SCHEMA]] are used to refer to the database. - */ - def setDatabaseTerm(databaseTerm: DatabaseMetaData.DatabaseTerm): ConnectionProvider[F, A] - - /** - * Update handler to output execution log of processes using connections. - * - * {{{ - * ConnectionProvider - * ... - * .setLogHandler(consoleLogger) - * }}} - * - * @param handler - * Handler for outputting logs of process execution using connections. - */ - def setLogHandler(handler: LogHandler[F]): ConnectionProvider[F, A] - - /** - * Update tracers to output metrics. - * - * {{{ - * ConnectionProvider - * ... - * .setTracer(Tracer.noop[IO]) - * }}} - * - * @param tracer - * Tracer to output metrics - */ - def setTracer(tracer: Tracer[F]): ConnectionProvider[F, A] - - /** - * Update whether to use cursor fetch for large result sets. - * - * {{{ - * ConnectionProvider - * ... - * .setUseCursorFetch(true) - * }}} - * - * @param useCursorFetch - * Whether to use cursor fetch for large result sets. - */ - def setUseCursorFetch(useCursorFetch: Boolean): ConnectionProvider[F, A] - - /** - * Update whether to use server prepared statements. - * - * {{{ - * ConnectionProvider - * ... - * .setUseServerPrepStmts(true) - * }}} - * - * @param useServerPrepStmts - * Whether to use server prepared statements. - */ - def setUseServerPrepStmts(useServerPrepStmts: Boolean): ConnectionProvider[F, A] - - /** - * Add an optional process to be executed immediately after connection is established. - * - * {{{ - * val before = ??? - * - * ConnectionProvider - * ... - * .withBefore(before) - * }}} - * - * @param before - * Arbitrary processing to be performed immediately after connection is established - * @tparam B - * Value returned after the process is executed. This value can be passed to the After process. - */ - def withBefore[B](before: Connection[F] => F[B]): ConnectionProvider[F, B] - - /** - * Add any process to be performed before disconnecting the connection. - * - * {{{ - * val after = ??? - * - * ConnectionProvider - * ... - * .withAfter(after) - * }}} - * - * @param after - * Arbitrary processing to be performed before disconnecting - */ - def withAfter(after: (A, Connection[F]) => F[Unit]): ConnectionProvider[F, A] - - /** - * Add optional processing to be performed immediately after establishing a connection and before disconnecting. - * - * The order of processing is as follows. - * - * {{{ - * 1. connection establishment - * 2. before operation - * 3. Processing using connections. Any processing used primarily in the operation of an application. - * 4. after operation - * 5. Disconnection - * }}} - * - * {{{ - * val before = ??? - * val after = ??? - * - * ConnectionProvider - * ... - * .withBeforeAfter(before, after) - * }}} - * - * @param before - * Arbitrary processing to be performed immediately after connection is established - * @param after - * Arbitrary processing to be performed before disconnecting - * @tparam B - * Value returned after the process is executed. This value can be passed to the After process. - */ - def withBeforeAfter[B](before: Connection[F] => F[B], after: (B, Connection[F]) => F[Unit]): ConnectionProvider[F, B] - -object ConnectionProvider: - - val defaultSocketOptions: List[SocketOption] = - List(SocketOption.noDelay(true)) - - @annotation.nowarn - private case class Impl[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - logHandler: Option[LogHandler[F]] = None, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - tracer: Option[Tracer[F]] = None, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - before: Option[Connection[F] => F[A]] = None, - after: Option[(A, Connection[F]) => F[Unit]] = None - ) extends ConnectionProvider[F, A]: - given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) - - override def setHost(host: String): ConnectionProvider[F, A] = - this.copy(host = host) - - override def setPort(port: Int): ConnectionProvider[F, A] = - this.copy(port = port) - - override def setUser(user: String): ConnectionProvider[F, A] = - this.copy(user = user) - - override def setPassword(password: String): ConnectionProvider[F, A] = - this.copy(password = Some(password)) - - override def setDatabase(database: String): ConnectionProvider[F, A] = - this.copy(database = Some(database)) - - override def setDebug(debug: Boolean): ConnectionProvider[F, A] = - this.copy(debug = debug) - - override def setSSL(ssl: SSL): ConnectionProvider[F, A] = - this.copy(ssl = ssl) - - override def addSocketOption(socketOption: SocketOption): ConnectionProvider[F, A] = - this.copy(socketOptions = socketOptions.::(socketOption)) - - override def setSocketOptions(socketOptions: List[SocketOption]): ConnectionProvider[F, A] = - this.copy(socketOptions = socketOptions) - - override def setReadTimeout(readTimeout: Duration): ConnectionProvider[F, A] = - this.copy(readTimeout = readTimeout) - - override def setAllowPublicKeyRetrieval(allowPublicKeyRetrieval: Boolean): ConnectionProvider[F, A] = - this.copy(allowPublicKeyRetrieval = allowPublicKeyRetrieval) - - override def setDatabaseTerm(databaseTerm: DatabaseMetaData.DatabaseTerm): ConnectionProvider[F, A] = - this.copy(databaseTerm = Some(databaseTerm)) - - override def setLogHandler(handler: LogHandler[F]): ConnectionProvider[F, A] = - this.copy(logHandler = Some(handler)) - - override def setTracer(tracer: Tracer[F]): ConnectionProvider[F, A] = - this.copy(tracer = Some(tracer)) - - override def setUseCursorFetch(useCursorFetch: Boolean): ConnectionProvider[F, A] = - val useServerPrepStmts = if useCursorFetch then true else this.useServerPrepStmts - this.copy(useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts) - - override def setUseServerPrepStmts(useServerPrepStmts: Boolean): ConnectionProvider[F, A] = - this.copy(useServerPrepStmts = useServerPrepStmts) - - override def withBefore[B](before: Connection[F] => F[B]): ConnectionProvider[F, B] = - Impl( - host = host, - port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - databaseTerm = databaseTerm, - logHandler = logHandler, - tracer = tracer, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - before = Some(before), - after = None - ) - override def withAfter(after: (A, Connection[F]) => F[Unit]): ConnectionProvider[F, A] = - this.copy(after = Some(after)) - - override def withBeforeAfter[B]( - before: Connection[F] => F[B], - after: (B, Connection[F]) => F[Unit] - ): ConnectionProvider[F, B] = - Impl( - host = host, - port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - databaseTerm = databaseTerm, - tracer = tracer, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - before = Some(before), - after = Some(after) - ) - - override def createConnection(): Resource[F, Connection[F]] = - (before, after) match - case (Some(b), Some(a)) => - Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = a, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm - ) - case (Some(b), None) => - Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = (_, _) => Async[F].unit, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm - ) - case (None, _) => - Connection( - host = host, - port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm - ) - - override def use[B](f: Connector[F] => F[B]): F[B] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - @annotation.nowarn - def default[F[_]: Async: Network: Console: Hashing: UUIDGen]( - host: String, - port: Int, - user: String - ): ConnectionProvider[F, Unit] = Impl[F, Unit](host, port, user) - - @annotation.nowarn - def default[F[_]: Async: Network: Console: Hashing: UUIDGen]( - host: String, - port: Int, - user: String, - password: String, - database: String - ): ConnectionProvider[F, Unit] = - default[F](host, port, user) - .setPassword(password) - .setDatabase(database) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionProviderTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionProviderTest.scala deleted file mode 100644 index 708276607..000000000 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionProviderTest.scala +++ /dev/null @@ -1,162 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.connector - -import scala.concurrent.duration.Duration - -import cats.effect.IO - -import fs2.io.net.SocketOption - -import org.typelevel.otel4s.trace.Tracer - -import ldbc.sql.DatabaseMetaData - -@annotation.nowarn() -class ConnectionProviderTest extends FTestPlatform: - - /** - * Default connection provider for testing - */ - def defaultProvider: ConnectionProvider[IO, Unit] = - ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") - - test("ConnectionProvider#setHost") { - val provider = defaultProvider.setHost("localhost") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setPort") { - val provider = defaultProvider.setPort(3306) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setUser") { - val provider = defaultProvider.setUser("root") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setPassword") { - val provider = defaultProvider.setPassword("newpassword") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setDatabase") { - val provider = defaultProvider.setDatabase("testdb") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setDebug") { - val provider = defaultProvider.setDebug(true) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setSSL") { - val provider = defaultProvider.setSSL(SSL.Trusted) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#addSocketOption") { - val option = SocketOption.noDelay(false) - val provider = defaultProvider.addSocketOption(option) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setSocketOptions") { - val options = List(SocketOption.noDelay(false)) - val provider = defaultProvider.setSocketOptions(options) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setReadTimeout") { - val provider = defaultProvider.setReadTimeout(Duration(5000, "ms")) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setAllowPublicKeyRetrieval") { - val provider = defaultProvider.setAllowPublicKeyRetrieval(true) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setDatabaseTerm") { - val provider = defaultProvider.setDatabaseTerm(DatabaseMetaData.DatabaseTerm.SCHEMA) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setTracer") { - val tracer = Tracer.noop[IO] - val provider = defaultProvider.setTracer(tracer) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#withBefore") { - val before = (conn: Connection[IO]) => IO.unit - val provider = defaultProvider.withBefore(before) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#withAfter") { - val after = (unit: Unit, conn: Connection[IO]) => IO.unit - val provider = defaultProvider.withAfter(after) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#withBeforeAfter") { - val before = (conn: Connection[IO]) => IO.unit - val after = (unit: Unit, conn: Connection[IO]) => IO.unit - val provider = defaultProvider.withBeforeAfter(before, after) - assertNotEquals( - provider, - defaultProvider - ) - } From cbef8631f04a9189e249397120a6eb236dc4133b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 19:56:10 +0900 Subject: [PATCH 017/215] Change use ConnectionProvider to Connector --- docs/src/main/scala/01-Program.scala | 11 +++++------ docs/src/main/scala/02-Program.scala | 11 +++++------ docs/src/main/scala/03-Program.scala | 11 +++++------ docs/src/main/scala/04-Program.scala | 11 +++++------ docs/src/main/scala/05-Program.scala | 11 +++++------ docs/src/main/scala/0X-Cleanup.scala | 10 ++++------ examples/http4s/src/main/scala/Main.scala | 19 ++++++++++--------- examples/otel/src/main/scala/Main.scala | 12 ++++++------ 8 files changed, 45 insertions(+), 51 deletions(-) diff --git a/docs/src/main/scala/01-Program.scala b/docs/src/main/scala/01-Program.scala index cf607f071..b36eed56a 100644 --- a/docs/src/main/scala/01-Program.scala +++ b/docs/src/main/scala/01-Program.scala @@ -18,17 +18,16 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.readOnly(conn).map(println(_)) - } + program.readOnly(connector).map(println(_)) .unsafeRunSync() // 1 // #run diff --git a/docs/src/main/scala/02-Program.scala b/docs/src/main/scala/02-Program.scala index 20f309116..5591cfde3 100644 --- a/docs/src/main/scala/02-Program.scala +++ b/docs/src/main/scala/02-Program.scala @@ -18,17 +18,16 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.readOnly(conn).map(println(_)) - } + program.readOnly(connector).map(println(_)) .unsafeRunSync() // Some(2) // #run diff --git a/docs/src/main/scala/03-Program.scala b/docs/src/main/scala/03-Program.scala index 45ac92f9c..61e264769 100644 --- a/docs/src/main/scala/03-Program.scala +++ b/docs/src/main/scala/03-Program.scala @@ -23,17 +23,16 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.readOnly(conn).map(println(_)) - } + program.readOnly(connector).map(println(_)) .unsafeRunSync() // (List(1), Some(2), 3) // #run diff --git a/docs/src/main/scala/04-Program.scala b/docs/src/main/scala/04-Program.scala index 7e9261fee..75e4eef85 100644 --- a/docs/src/main/scala/04-Program.scala +++ b/docs/src/main/scala/04-Program.scala @@ -19,17 +19,16 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.commit(conn).map(println(_)) - } + program.commit(connector).map(println(_)) .unsafeRunSync() // 1 // #run diff --git a/docs/src/main/scala/05-Program.scala b/docs/src/main/scala/05-Program.scala index 0e3494af8..9234daad5 100644 --- a/docs/src/main/scala/05-Program.scala +++ b/docs/src/main/scala/05-Program.scala @@ -35,13 +35,12 @@ import ldbc.connector.* val program2: DBIO[(String, String, Status)] = sql"SELECT name, email, status FROM user WHERE id = 1".query[(String, String, Status)].unsafe - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) - connection - .use { conn => - program1.commit(conn) *> program2.readOnly(conn).map(println(_)) - } + def connector = Connector.fromDataSource(dataSource) + + (program1.commit(connector) *> program2.readOnly(connector).map(println(_))) .unsafeRunSync() diff --git a/docs/src/main/scala/0X-Cleanup.scala b/docs/src/main/scala/0X-Cleanup.scala index a6f3be7aa..d455c648e 100644 --- a/docs/src/main/scala/0X-Cleanup.scala +++ b/docs/src/main/scala/0X-Cleanup.scala @@ -19,15 +19,13 @@ import ldbc.connector.* // #cleanupDatabase // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - dropDatabase.commit(conn).as(println("Database dropped")) - } + dropDatabase.commit(connector).as(println("Database dropped")) .unsafeRunSync() // #run diff --git a/examples/http4s/src/main/scala/Main.scala b/examples/http4s/src/main/scala/Main.scala index 40712ea4c..262d2491c 100644 --- a/examples/http4s/src/main/scala/Main.scala +++ b/examples/http4s/src/main/scala/Main.scala @@ -48,24 +48,25 @@ object Main extends ResourceApp.Forever: private val cityTable = TableQuery[CityTable] - private def provider = - ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc", "password", "world") - .setSSL(SSL.Trusted) - - private def routes(conn: Connector[IO]): HttpRoutes[IO] = HttpRoutes.of[IO] { + private val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") + .setPassword("password") + .setDatabase("world") + .setSSL(SSL.Trusted) + private val connector = Connector.fromDataSource(dataSource) + + private def routes: HttpRoutes[IO] = HttpRoutes.of[IO] { case GET -> Root / "cities" => for - cities <- cityTable.selectAll.query.to[List].readOnly(conn) + cities <- cityTable.selectAll.query.to[List].readOnly(connector) result <- Ok(cities.asJson) yield result } override def run(args: List[String]): Resource[IO, Unit] = for - conn <- provider.createConnector() _ <- EmberServerBuilder .default[IO] - .withHttpApp(routes(conn).orNotFound) + .withHttpApp(routes.orNotFound) .build yield () diff --git a/examples/otel/src/main/scala/Main.scala b/examples/otel/src/main/scala/Main.scala index 978f7922c..321fc3cce 100644 --- a/examples/otel/src/main/scala/Main.scala +++ b/examples/otel/src/main/scala/Main.scala @@ -18,18 +18,18 @@ object Main extends IOApp.Simple: private val serviceName = "ldbc-otel-example" + private val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13307, "ldbc") + .setPassword("password") + .setDatabase("world") + private def resource: Resource[IO, Connector[IO]] = for otel <- Resource .eval(IO.delay(GlobalOpenTelemetry.get)) .evalMap(OtelJava.fromJOpenTelemetry[IO]) tracer <- Resource.eval(otel.tracerProvider.get(serviceName)) - connection <- ConnectionProvider - .default[IO]("127.0.0.1", 13307, "ldbc", "password", "world") - .setSSL(SSL.Trusted) - .setTracer(tracer) - .createConnector() - yield connection + yield Connector.fromDataSource(dataSource.setTracer(tracer)) override def run: IO[Unit] = resource.use { conn => From 210e0cb121a6f15c5db5f3c12cc7d348b8508bb9 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 30 Sep 2025 19:56:24 +0900 Subject: [PATCH 018/215] Action sbt scalafmtSbt --- docs/src/main/scala/01-Program.scala | 4 +++- docs/src/main/scala/02-Program.scala | 4 +++- docs/src/main/scala/03-Program.scala | 4 +++- docs/src/main/scala/04-Program.scala | 4 +++- docs/src/main/scala/0X-Cleanup.scala | 4 +++- examples/http4s/src/main/scala/Main.scala | 9 ++++----- examples/otel/src/main/scala/Main.scala | 2 +- 7 files changed, 20 insertions(+), 11 deletions(-) diff --git a/docs/src/main/scala/01-Program.scala b/docs/src/main/scala/01-Program.scala index b36eed56a..f26c71b02 100644 --- a/docs/src/main/scala/01-Program.scala +++ b/docs/src/main/scala/01-Program.scala @@ -27,7 +27,9 @@ import ldbc.connector.* // #connection // #run - program.readOnly(connector).map(println(_)) + program + .readOnly(connector) + .map(println(_)) .unsafeRunSync() // 1 // #run diff --git a/docs/src/main/scala/02-Program.scala b/docs/src/main/scala/02-Program.scala index 5591cfde3..75528f552 100644 --- a/docs/src/main/scala/02-Program.scala +++ b/docs/src/main/scala/02-Program.scala @@ -27,7 +27,9 @@ import ldbc.connector.* // #connection // #run - program.readOnly(connector).map(println(_)) + program + .readOnly(connector) + .map(println(_)) .unsafeRunSync() // Some(2) // #run diff --git a/docs/src/main/scala/03-Program.scala b/docs/src/main/scala/03-Program.scala index 61e264769..6584829c8 100644 --- a/docs/src/main/scala/03-Program.scala +++ b/docs/src/main/scala/03-Program.scala @@ -32,7 +32,9 @@ import ldbc.connector.* // #connection // #run - program.readOnly(connector).map(println(_)) + program + .readOnly(connector) + .map(println(_)) .unsafeRunSync() // (List(1), Some(2), 3) // #run diff --git a/docs/src/main/scala/04-Program.scala b/docs/src/main/scala/04-Program.scala index 75e4eef85..5c26ccbe4 100644 --- a/docs/src/main/scala/04-Program.scala +++ b/docs/src/main/scala/04-Program.scala @@ -28,7 +28,9 @@ import ldbc.connector.* // #connection // #run - program.commit(connector).map(println(_)) + program + .commit(connector) + .map(println(_)) .unsafeRunSync() // 1 // #run diff --git a/docs/src/main/scala/0X-Cleanup.scala b/docs/src/main/scala/0X-Cleanup.scala index d455c648e..09c879558 100644 --- a/docs/src/main/scala/0X-Cleanup.scala +++ b/docs/src/main/scala/0X-Cleanup.scala @@ -26,6 +26,8 @@ import ldbc.connector.* // #connection // #run - dropDatabase.commit(connector).as(println("Database dropped")) + dropDatabase + .commit(connector) + .as(println("Database dropped")) .unsafeRunSync() // #run diff --git a/examples/http4s/src/main/scala/Main.scala b/examples/http4s/src/main/scala/Main.scala index 262d2491c..2e9412298 100644 --- a/examples/http4s/src/main/scala/Main.scala +++ b/examples/http4s/src/main/scala/Main.scala @@ -64,9 +64,8 @@ object Main extends ResourceApp.Forever: } override def run(args: List[String]): Resource[IO, Unit] = - for - _ <- EmberServerBuilder - .default[IO] - .withHttpApp(routes.orNotFound) - .build + for _ <- EmberServerBuilder + .default[IO] + .withHttpApp(routes.orNotFound) + .build yield () diff --git a/examples/otel/src/main/scala/Main.scala b/examples/otel/src/main/scala/Main.scala index 321fc3cce..174724466 100644 --- a/examples/otel/src/main/scala/Main.scala +++ b/examples/otel/src/main/scala/Main.scala @@ -28,7 +28,7 @@ object Main extends IOApp.Simple: otel <- Resource .eval(IO.delay(GlobalOpenTelemetry.get)) .evalMap(OtelJava.fromJOpenTelemetry[IO]) - tracer <- Resource.eval(otel.tracerProvider.get(serviceName)) + tracer <- Resource.eval(otel.tracerProvider.get(serviceName)) yield Connector.fromDataSource(dataSource.setTracer(tracer)) override def run: IO[Unit] = From ee3c041fe39584322b1546ce710a090966996691 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 19:36:07 +0900 Subject: [PATCH 019/215] Create ldbc zio interop project --- build.sbt | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/build.sbt b/build.sbt index 2767a9f76..38565958d 100644 --- a/build.sbt +++ b/build.sbt @@ -155,6 +155,16 @@ lazy val plugin = LepusSbtPluginProject("ldbc-plugin", "plugin") ) }.taskValue) +lazy val zioInterop = crossProject(JVMPlatform, JSPlatform) + .crossType(CrossType.Pure) + .module("zio-interop", "Projects that provide a way to connect to the database for ZIO") + .settings( + libraryDependencies ++= Seq( + "dev.zio" %%% "zio" % "2.1.6", + "dev.zio" %%% "zio-interop-cats" % "23.1.0.5" + ) + ) + lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) .in(file("tests")) @@ -372,6 +382,7 @@ lazy val ldbc = tlCrossRootProject queryBuilder, schema, codegen, + zioInterop, plugin, tests, docs, From cdb17c89fd37c5fd3f268cfec301e167012ebce0 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 19:36:19 +0900 Subject: [PATCH 020/215] Create package object --- .../main/scala/ldbc/zio/interop/package.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala diff --git a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala new file mode 100644 index 000000000..da2dafe53 --- /dev/null +++ b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala @@ -0,0 +1,24 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.zio + +import java.util.UUID + +import cats.effect.std.UUIDGen + +import fs2.hashing.Hashing +import fs2.io.net.Network + +import zio.* +import zio.interop.catz.* + +package object interop: + implicit def consoleToZIO: cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] + implicit def uuidGenToZIO: UUIDGen[Task] = new UUIDGen[Task]: + override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) + implicit def hashingToZIO: Hashing[Task] = Hashing.forSync[Task] + implicit def networkToZIO: Network[Task] = Network.forAsync[Task] From f4bef5abf5dc5e49ba906b5a035b963caf57c40f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 19:42:14 +0900 Subject: [PATCH 021/215] Update document --- README.md | 16 ++++----- docs/src/main/mdoc/README.md | 1 + docs/src/main/mdoc/en/index.md | 1 - .../main/mdoc/en/qa/How-to-use-with-ZIO.md | 28 +++++---------- docs/src/main/mdoc/index.md | 1 + docs/src/main/mdoc/ja/index.md | 1 - .../main/mdoc/ja/qa/How-to-use-with-ZIO.md | 35 +++++++------------ 7 files changed, 31 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 46a9566cf..7b3270b45 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-zio-interop` | ✅ | ❌ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-zio-interop_3) | ## Performance @@ -228,28 +229,25 @@ val result: IO[List[User]] = ## How to use with ZIO -Although ldbc was created to run on the Cats Effect, can also be used in conjunction with ZIO by using [ZIO Interop Cats](https://github.com/zio/interop-cats). +Although ldbc was created to run on the Cats Effect, can also be used in conjunction with ZIO by using `ldbc-zio-interop`. > [!CAUTION] > Although ldbc supports three platforms, Note that ZIO Interop Cats does not currently support Scala Native. ```scala -libraryDependencies += "dev.zio" %% "zio-interop-cats" % "" +libraryDependencies += "io.github.takapi327" %% "ldbc-zio-interop" % "latest" ``` The following is sample code for using ldbc with ZIO. ```scala 3 import zio.* -import zio.interop.catz.* -object Main extends ZIOAppDefault: +import ldbc.zio.interop.* +import ldbc.connector.* +import ldbc.dsl.* - given cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] - given cats.effect.std.UUIDGen[Task] with - override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) - given fs2.hashing.Hashing[Task] = fs2.hashing.Hashing.forSync[Task] - given fs2.io.net.Network[Task] = fs2.io.net.Network.forAsync[Task] +object Main extends ZIOAppDefault: private val datasource = MySQLDataSource diff --git a/docs/src/main/mdoc/README.md b/docs/src/main/mdoc/README.md index 128b61a4f..59412cdf3 100644 --- a/docs/src/main/mdoc/README.md +++ b/docs/src/main/mdoc/README.md @@ -49,6 +49,7 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-zio-interop` | ✅ | ❌ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-zio-interop_3) | ## Quick Start diff --git a/docs/src/main/mdoc/en/index.md b/docs/src/main/mdoc/en/index.md index e03b833de..e8c7feb85 100644 --- a/docs/src/main/mdoc/en/index.md +++ b/docs/src/main/mdoc/en/index.md @@ -66,6 +66,5 @@ libraryDependencies ++= Seq( - Geometry data type support - CHECK constraint support - Support for databases other than MySQL -- ZIO module support - Test kit - etc... diff --git a/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md b/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md index 7bee3a856..646369374 100644 --- a/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md +++ b/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md @@ -5,36 +5,26 @@ laika.metadata.language = ja # Q: How to use with ZIO? -## A: For use with ZIO, use `zio-interop-cats`. +## A: For use with ZIO, use `ldbc-zio-interop`. ```scala -libraryDependencies += "dev.zio" %% "zio-interop-cats" % "" +libraryDependencies += "io.github.takapi327" %% "ldbc-zio-interop" % "latest" ``` The following is sample code for using ldbc with ZIO. ```scala 3 mdoc -import java.util.UUID - -import cats.effect.std.UUIDGen - -import fs2.hashing.Hashing -import fs2.io.net.Network - import zio.* -import zio.interop.catz.* -object Main extends ZIOAppDefault: +import ldbc.zio.interop.* +import ldbc.connector.* +import ldbc.dsl.* - given cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] - given UUIDGen[Task] with - override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) - given Hashing[Task] = Hashing.forSync[Task] - given Network[Task] = Network.forAsync[Task] +object Main extends ZIOAppDefault: - private def datasource = + private val datasource = MySQLDataSource - .build[Task]("127.0.0.1", 13306, "ldbc") + .build[Task]("127.0.0.1", 3306, "ldbc") .setPassword("password") .setDatabase("world") .setSSL(SSL.Trusted) @@ -51,7 +41,7 @@ object Main extends ZIOAppDefault: } ``` -### パフォーマンス +### Performance Performance results from the Cats Effect to ZIO conversion are shown below. diff --git a/docs/src/main/mdoc/index.md b/docs/src/main/mdoc/index.md index 6fd3e67fd..7fc4566fd 100644 --- a/docs/src/main/mdoc/index.md +++ b/docs/src/main/mdoc/index.md @@ -48,6 +48,7 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-zio-interop` | ✅ | ❌ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-zio-interop_3) | ## Quick Start diff --git a/docs/src/main/mdoc/ja/index.md b/docs/src/main/mdoc/ja/index.md index eb14b520e..dc1496fd0 100644 --- a/docs/src/main/mdoc/ja/index.md +++ b/docs/src/main/mdoc/ja/index.md @@ -66,6 +66,5 @@ libraryDependencies ++= Seq( - Geometryデータタイプのサポート - CHECK制約のサポート - MySQL以外のデータベースサポート -- ZIOモジュールのサポート - テストキット - etc... diff --git a/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md b/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md index abafba1d1..d15366141 100644 --- a/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md +++ b/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md @@ -5,39 +5,30 @@ laika.metadata.language = ja # Q: ZIOで使用する方法は? -## A: ZIOで使用する場合、`zio-interop-cats`を使用します。 +## A: ZIOで使用する場合、`ldbc-zio-interop`を使用します。 ```scala -libraryDependencies += "dev.zio" %% "zio-interop-cats" % "" +libraryDependencies += "io.github.takapi327" %% "ldbc-zio-interop" % "latest" ``` 以下は、ZIOを使用してldbcを利用するためのサンプルコードです。 ```scala 3 mdoc -import java.util.UUID - -import cats.effect.std.UUIDGen - -import fs2.hashing.Hashing -import fs2.io.net.Network - import zio.* -import zio.interop.catz.* + +import ldbc.zio.interop.* +import ldbc.connector.* +import ldbc.dsl.* object Main extends ZIOAppDefault: - given cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] - given UUIDGen[Task] with - override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) - given Hashing[Task] = Hashing.forSync[Task] - given Network[Task] = Network.forAsync[Task] - - private def datasource = MySQLDataSource - .build[Task]("127.0.0.1", 13306, "ldbc") - .setPassword("password") - .setDatabase("world") - .setSSL(SSL.Trusted) - + private val datasource = + MySQLDataSource + .build[Task]("127.0.0.1", 3306, "ldbc") + .setPassword("password") + .setDatabase("world") + .setSSL(SSL.Trusted) + private val connector = Connector.fromDataSource(datasource) override def run = From 514d10eb42f1b19a5ea936398d0de313b4b570b4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 19:42:33 +0900 Subject: [PATCH 022/215] Action sbt githubWorkflowGenerate --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 74e310990..ec7a02991 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -151,11 +151,11 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) From a5a17e97bb11ad04cb0294b6934fc17b59f1e754 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 19:42:50 +0900 Subject: [PATCH 023/215] Action sbt scalafmtAll --- .../src/main/scala/ldbc/zio/interop/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala index da2dafe53..a4d2ec663 100644 --- a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala +++ b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala @@ -18,7 +18,7 @@ import zio.interop.catz.* package object interop: implicit def consoleToZIO: cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] - implicit def uuidGenToZIO: UUIDGen[Task] = new UUIDGen[Task]: + implicit def uuidGenToZIO: UUIDGen[Task] = new UUIDGen[Task]: override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) implicit def hashingToZIO: Hashing[Task] = Hashing.forSync[Task] implicit def networkToZIO: Network[Task] = Network.forAsync[Task] From 0e557233084325d615ab0f4d6d7c832832113b19 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 20:11:25 +0900 Subject: [PATCH 024/215] Create ZIO Query test --- build.sbt | 7 +- .../main/scala/ldbc/zio/interop/package.scala | 3 +- .../scala/ldbc/zio/interop/QueryTest.scala | 554 ++++++++++++++++++ 3 files changed, 561 insertions(+), 3 deletions(-) create mode 100644 module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala diff --git a/build.sbt b/build.sbt index 38565958d..6b0c1d7c3 100644 --- a/build.sbt +++ b/build.sbt @@ -160,10 +160,13 @@ lazy val zioInterop = crossProject(JVMPlatform, JSPlatform) .module("zio-interop", "Projects that provide a way to connect to the database for ZIO") .settings( libraryDependencies ++= Seq( - "dev.zio" %%% "zio" % "2.1.6", - "dev.zio" %%% "zio-interop-cats" % "23.1.0.5" + "dev.zio" %%% "zio" % "2.1.21", + "dev.zio" %%% "zio-interop-cats" % "23.1.0.5", + "dev.zio" %%% "zio-test" % "2.1.21" % Test, + "dev.zio" %%% "zio-test-sbt" % "2.1.21" % Test, ) ) + .dependsOn(connector % "test->compile") lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) diff --git a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala index a4d2ec663..6e9cf9a2c 100644 --- a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala +++ b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala @@ -8,15 +8,16 @@ package ldbc.zio import java.util.UUID +import cats.effect.Async import cats.effect.std.UUIDGen import fs2.hashing.Hashing import fs2.io.net.Network import zio.* -import zio.interop.catz.* package object interop: + implicit def asyncToZIO: Async[Task] = zio.interop.catz.asyncInstance implicit def consoleToZIO: cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] implicit def uuidGenToZIO: UUIDGen[Task] = new UUIDGen[Task]: override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) diff --git a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala new file mode 100644 index 000000000..60a092c5d --- /dev/null +++ b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala @@ -0,0 +1,554 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.zio.interop + +import java.time.* + +import cats.Monad + +import zio.* +import zio.test.* +import zio.test.Assertion.* + +import ldbc.connector.* +import ldbc.connector.data.* + +object QueryTest extends ZIOSpecDefault: + + private val datasource = + MySQLDataSource + .build[Task]("127.0.0.1", 13306, "ldbc") + .setPassword("password") + .setDatabase("connector_test") + .setSSL(SSL.Trusted) + + def spec = suite("QueryTest")( + test("The client's PreparedStatement may use NULL as a parameter.") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `bit`, `bit_null` FROM `all_types` WHERE `bit_null` is ?") + resultSet <- statement.setNull(1, MysqlType.BIT.jdbcType) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((1.toByte, 0.toByte)))) + } + }, + + test("Client PreparedStatement should be able to retrieve BIT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `bit`, `bit_null` FROM `all_types` WHERE `bit` = ?") + resultSet <- statement.setByte(1, 1.toByte) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((1.toByte, 0.toByte)))) + } + }, + + test("Client PreparedStatement should be able to retrieve TINYINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `tinyint`, `tinyint_null` FROM `all_types` WHERE `tinyint` = ?") + resultSet <- statement.setByte(1, 127.toByte) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((127.toByte, 0.toByte)))) + } + }, + + test("Client PreparedStatement should be able to retrieve unsigned TINYINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement( + "SELECT `tinyint_unsigned`, `tinyint_unsigned_null` FROM `all_types` WHERE `tinyint_unsigned` = ?" + ) + resultSet <- statement.setShort(1, 255.toShort) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { + for + v1 <- resultSet.getShort(1) + v2 <- resultSet.getShort(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((255.toShort, 0.toShort)))) + } + }, + + test("Client PreparedStatement should be able to retrieve SMALLINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `smallint`, `smallint_null` FROM `all_types` WHERE `smallint` = ?") + resultSet <- statement.setShort(1, 32767.toShort) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { + for + v1 <- resultSet.getShort(1) + v2 <- resultSet.getShort(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((32767.toShort, 0.toShort)))) + } + }, + + test("Client PreparedStatement should be able to retrieve unsigned SMALLINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement( + "SELECT `smallint_unsigned`, `smallint_unsigned_null` FROM `all_types` WHERE `smallint_unsigned` = ?" + ) + resultSet <- statement.setInt(1, 65535) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((65535, 0)))) + } + }, + + test("Client PreparedStatement should be able to retrieve MEDIUMINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `mediumint`, `mediumint_null` FROM `all_types` WHERE `mediumint` = ?") + resultSet <- statement.setInt(1, 8388607) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((8388607, 0)))) + } + }, + + test("Client PreparedStatement should be able to retrieve INT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `int`, `int_null` FROM `all_types` WHERE `int` = ?") + resultSet <- statement.setInt(1, 2147483647) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((2147483647, 0)))) + } + }, + + test("Client PreparedStatement should be able to retrieve unsigned INT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `int_unsigned`, `int_unsigned_null` FROM `all_types` WHERE `int_unsigned` = ?" + ) + resultSet <- statement.setLong(1, 4294967295L) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { + for + v1 <- resultSet.getLong(1) + v2 <- resultSet.getLong(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((4294967295L, 0L)))) + } + }, + + test("Client PreparedStatement should be able to retrieve BIGINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `bigint`, `bigint_null` FROM `all_types` WHERE `bigint` = ?") + resultSet <- statement.setLong(1, 9223372036854775807L) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { + for + v1 <- resultSet.getLong(1) + v2 <- resultSet.getLong(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((9223372036854775807L, 0L)))) + } + }, + + test("Client PreparedStatement should be able to retrieve unsigned BIGINT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `bigint_unsigned`, `bigint_unsigned_null` FROM `all_types` WHERE `bigint_unsigned` = ?" + ) + resultSet <- statement.setString(1, "18446744073709551615") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("18446744073709551615", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve FLOAT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `float`, `float_null` FROM `all_types` WHERE `float` > ?") + resultSet <- statement.setFloat(1, 3.4f) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Float, Float)](resultSet.next()) { + for + v1 <- resultSet.getFloat(1) + v2 <- resultSet.getFloat(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((3.40282e38f, 0f)))) + } + }, + + test("Client PreparedStatement should be able to retrieve DOUBLE type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `double`, `double_null` FROM `all_types` WHERE `double` = ?") + resultSet <- statement.setDouble(1, 1.7976931348623157e308) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Double, Double)](resultSet.next()) { + for + v1 <- resultSet.getDouble(1) + v2 <- resultSet.getDouble(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((1.7976931348623157e308, 0.toDouble)))) + } + }, + + test("Client PreparedStatement should be able to retrieve DECIMAL type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `decimal`, `decimal_null` FROM `all_types` WHERE `decimal` = ?") + resultSet <- statement.setBigDecimal(1, BigDecimal.decimal(9999999.99)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (BigDecimal, BigDecimal)](resultSet.next()) { + for + v1 <- resultSet.getBigDecimal(1) + v2 <- resultSet.getBigDecimal(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((BigDecimal.decimal(9999999.99), null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve DATE type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `date`, `date_null` FROM `all_types` WHERE `date` = ?") + resultSet <- statement.setDate(1, LocalDate.of(2020, 1, 1)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalDate, LocalDate)](resultSet.next()) { + for + v1 <- resultSet.getDate(1) + v2 <- resultSet.getDate(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalDate.of(2020, 1, 1), null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve TIME type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `time`, `time_null` FROM `all_types` WHERE `time` = ?") + resultSet <- statement.setTime(1, LocalTime.of(12, 34, 56)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalTime, LocalTime)](resultSet.next()) { + for + v1 <- resultSet.getTime(1) + v2 <- resultSet.getTime(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalTime.of(12, 34, 56), null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve DATETIME type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `datetime`, `datetime_null` FROM `all_types` WHERE `datetime` = ?") + resultSet <- statement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { + for + v1 <- resultSet.getTimestamp(1) + v2 <- resultSet.getTimestamp(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalDateTime.of(2020, 1, 1, 12, 34, 56), null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve TIMESTAMP type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `timestamp`, `timestamp_null` FROM `all_types` WHERE `timestamp` = ?") + resultSet <- statement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { + for + v1 <- resultSet.getTimestamp(1) + v2 <- resultSet.getTimestamp(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalDateTime.of(2020, 1, 1, 12, 34, 56), null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve YEAR type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `year`, `year_null` FROM `all_types` WHERE `year` = ?") + resultSet <- statement.setInt(1, 2020) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((2020, 0)))) + } + }, + + test("Client PreparedStatement should be able to retrieve CHAR type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `char`, `char_null` FROM `all_types` WHERE `char` = ?") + resultSet <- statement.setString(1, "char") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("char", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve VARCHAR type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `varchar`, `varchar_null` FROM `all_types` WHERE `varchar` = ?") + resultSet <- statement.setString(1, "varchar") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("varchar", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve BINARY type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `binary`, `binary_null` FROM `all_types` WHERE `binary` = ?") + resultSet <- + statement.setBytes(1, Array[Byte](98, 105, 110, 97, 114, 121, 0, 0, 0, 0)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, Array[Byte])](resultSet.next()) { + for + v1 <- resultSet.getBytes(1) + v2 <- resultSet.getBytes(2) + yield (v1.mkString(":"), v2) + } + yield assert(result)(equalTo(List((Array[Byte](98, 105, 110, 97, 114, 121, 0, 0, 0, 0).mkString(":"), null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve VARBINARY type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `varbinary`, `varbinary_null` FROM `all_types` WHERE `varbinary` = ?") + resultSet <- statement.setString(1, "varbinary") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("varbinary", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve TINYBLOB type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `tinyblob`, `tinyblob_null` FROM `all_types` WHERE `tinyblob` = ?") + resultSet <- statement.setString(1, "tinyblob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("tinyblob", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve BLOB type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `blob`, `blob_null` FROM `all_types` WHERE `blob` = ?") + resultSet <- statement.setString(1, "blob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("blob", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve MEDIUMBLOB type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `mediumblob`, `mediumblob_null` FROM `all_types` WHERE `mediumblob` = ?" + ) + resultSet <- statement.setString(1, "mediumblob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("mediumblob", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve LONGBLOB type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `longblob`, `longblob_null` FROM `all_types` WHERE `longblob` = ?") + resultSet <- statement.setString(1, "longblob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("longblob", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve TINYTEXT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `tinytext`, `tinytext_null` FROM `all_types` WHERE `tinytext` = ?") + resultSet <- statement.setString(1, "tinytext") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("tinytext", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve TEXT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `text`, `text_null` FROM `all_types` WHERE `text` = ?") + resultSet <- statement.setString(1, "text") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("text", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve MEDIUMTEXT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `mediumtext`, `mediumtext_null` FROM `all_types` WHERE `mediumtext` = ?" + ) + resultSet <- statement.setString(1, "mediumtext") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("mediumtext", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve LONGTEXT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `longtext`, `longtext_null` FROM `all_types` WHERE `longtext` = ?") + resultSet <- statement.setString(1, "longtext") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("longtext", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve ENUM type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `enum`, `enum_null` FROM `all_types` WHERE `enum` = ?") + resultSet <- statement.setString(1, "a") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("a", null)))) + } + }, + + test("Client PreparedStatement should be able to retrieve SET type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `set`, `set_null` FROM `all_types` WHERE `set` = ?") + resultSet <- statement.setString(1, "a,b") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("a,b", null)))) + } + } + ) From 99681d52af4a3c4f4b0cedeab0e0b63da0961982 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 20:11:51 +0900 Subject: [PATCH 025/215] Action sbt scalafmtAll --- .../main/scala/ldbc/zio/interop/package.scala | 4 +- .../scala/ldbc/zio/interop/QueryTest.scala | 432 ++++++++---------- 2 files changed, 202 insertions(+), 234 deletions(-) diff --git a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala index 6e9cf9a2c..5aec7da8a 100644 --- a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala +++ b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala @@ -8,8 +8,8 @@ package ldbc.zio import java.util.UUID -import cats.effect.Async import cats.effect.std.UUIDGen +import cats.effect.Async import fs2.hashing.Hashing import fs2.io.net.Network @@ -17,7 +17,7 @@ import fs2.io.net.Network import zio.* package object interop: - implicit def asyncToZIO: Async[Task] = zio.interop.catz.asyncInstance + implicit def asyncToZIO: Async[Task] = zio.interop.catz.asyncInstance implicit def consoleToZIO: cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] implicit def uuidGenToZIO: UUIDGen[Task] = new UUIDGen[Task]: override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) diff --git a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala index 60a092c5d..14a5f4c64 100644 --- a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala +++ b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala @@ -10,13 +10,13 @@ import java.time.* import cats.Monad +import ldbc.connector.* +import ldbc.connector.data.* + import zio.* import zio.test.* import zio.test.Assertion.* -import ldbc.connector.* -import ldbc.connector.data.* - object QueryTest extends ZIOSpecDefault: private val datasource = @@ -32,47 +32,44 @@ object QueryTest extends ZIOSpecDefault: for statement <- conn.prepareStatement("SELECT `bit`, `bit_null` FROM `all_types` WHERE `bit_null` is ?") resultSet <- statement.setNull(1, MysqlType.BIT.jdbcType) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { - for - v1 <- resultSet.getByte(1) - v2 <- resultSet.getByte(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((1.toByte, 0.toByte)))) } }, - test("Client PreparedStatement should be able to retrieve BIT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `bit`, `bit_null` FROM `all_types` WHERE `bit` = ?") resultSet <- statement.setByte(1, 1.toByte) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { - for - v1 <- resultSet.getByte(1) - v2 <- resultSet.getByte(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((1.toByte, 0.toByte)))) } }, - test("Client PreparedStatement should be able to retrieve TINYINT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `tinyint`, `tinyint_null` FROM `all_types` WHERE `tinyint` = ?") resultSet <- statement.setByte(1, 127.toByte) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { - for - v1 <- resultSet.getByte(1) - v2 <- resultSet.getByte(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((127.toByte, 0.toByte)))) } }, - test("Client PreparedStatement should be able to retrieve unsigned TINYINT type records") { datasource.getConnection.use { conn => for @@ -81,32 +78,30 @@ object QueryTest extends ZIOSpecDefault: "SELECT `tinyint_unsigned`, `tinyint_unsigned_null` FROM `all_types` WHERE `tinyint_unsigned` = ?" ) resultSet <- statement.setShort(1, 255.toShort) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { - for - v1 <- resultSet.getShort(1) - v2 <- resultSet.getShort(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { + for + v1 <- resultSet.getShort(1) + v2 <- resultSet.getShort(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((255.toShort, 0.toShort)))) } }, - test("Client PreparedStatement should be able to retrieve SMALLINT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `smallint`, `smallint_null` FROM `all_types` WHERE `smallint` = ?") resultSet <- statement.setShort(1, 32767.toShort) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { - for - v1 <- resultSet.getShort(1) - v2 <- resultSet.getShort(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { + for + v1 <- resultSet.getShort(1) + v2 <- resultSet.getShort(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((32767.toShort, 0.toShort)))) } }, - test("Client PreparedStatement should be able to retrieve unsigned SMALLINT type records") { datasource.getConnection.use { conn => for @@ -115,47 +110,44 @@ object QueryTest extends ZIOSpecDefault: "SELECT `smallint_unsigned`, `smallint_unsigned_null` FROM `all_types` WHERE `smallint_unsigned` = ?" ) resultSet <- statement.setInt(1, 65535) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { - for - v1 <- resultSet.getInt(1) - v2 <- resultSet.getInt(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((65535, 0)))) } }, - test("Client PreparedStatement should be able to retrieve MEDIUMINT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `mediumint`, `mediumint_null` FROM `all_types` WHERE `mediumint` = ?") resultSet <- statement.setInt(1, 8388607) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { - for - v1 <- resultSet.getInt(1) - v2 <- resultSet.getInt(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((8388607, 0)))) } }, - test("Client PreparedStatement should be able to retrieve INT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `int`, `int_null` FROM `all_types` WHERE `int` = ?") resultSet <- statement.setInt(1, 2147483647) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { - for - v1 <- resultSet.getInt(1) - v2 <- resultSet.getInt(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((2147483647, 0)))) } }, - test("Client PreparedStatement should be able to retrieve unsigned INT type records") { datasource.getConnection.use { conn => for @@ -163,32 +155,30 @@ object QueryTest extends ZIOSpecDefault: "SELECT `int_unsigned`, `int_unsigned_null` FROM `all_types` WHERE `int_unsigned` = ?" ) resultSet <- statement.setLong(1, 4294967295L) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { - for - v1 <- resultSet.getLong(1) - v2 <- resultSet.getLong(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { + for + v1 <- resultSet.getLong(1) + v2 <- resultSet.getLong(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((4294967295L, 0L)))) } }, - test("Client PreparedStatement should be able to retrieve BIGINT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `bigint`, `bigint_null` FROM `all_types` WHERE `bigint` = ?") resultSet <- statement.setLong(1, 9223372036854775807L) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { - for - v1 <- resultSet.getLong(1) - v2 <- resultSet.getLong(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { + for + v1 <- resultSet.getLong(1) + v2 <- resultSet.getLong(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((9223372036854775807L, 0L)))) } }, - test("Client PreparedStatement should be able to retrieve unsigned BIGINT type records") { datasource.getConnection.use { conn => for @@ -196,171 +186,160 @@ object QueryTest extends ZIOSpecDefault: "SELECT `bigint_unsigned`, `bigint_unsigned_null` FROM `all_types` WHERE `bigint_unsigned` = ?" ) resultSet <- statement.setString(1, "18446744073709551615") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("18446744073709551615", null)))) } }, - test("Client PreparedStatement should be able to retrieve FLOAT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `float`, `float_null` FROM `all_types` WHERE `float` > ?") resultSet <- statement.setFloat(1, 3.4f) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Float, Float)](resultSet.next()) { - for - v1 <- resultSet.getFloat(1) - v2 <- resultSet.getFloat(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Float, Float)](resultSet.next()) { + for + v1 <- resultSet.getFloat(1) + v2 <- resultSet.getFloat(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((3.40282e38f, 0f)))) } }, - test("Client PreparedStatement should be able to retrieve DOUBLE type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `double`, `double_null` FROM `all_types` WHERE `double` = ?") resultSet <- statement.setDouble(1, 1.7976931348623157e308) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Double, Double)](resultSet.next()) { - for - v1 <- resultSet.getDouble(1) - v2 <- resultSet.getDouble(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Double, Double)](resultSet.next()) { + for + v1 <- resultSet.getDouble(1) + v2 <- resultSet.getDouble(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((1.7976931348623157e308, 0.toDouble)))) } }, - test("Client PreparedStatement should be able to retrieve DECIMAL type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `decimal`, `decimal_null` FROM `all_types` WHERE `decimal` = ?") resultSet <- statement.setBigDecimal(1, BigDecimal.decimal(9999999.99)) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (BigDecimal, BigDecimal)](resultSet.next()) { - for - v1 <- resultSet.getBigDecimal(1) - v2 <- resultSet.getBigDecimal(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (BigDecimal, BigDecimal)](resultSet.next()) { + for + v1 <- resultSet.getBigDecimal(1) + v2 <- resultSet.getBigDecimal(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((BigDecimal.decimal(9999999.99), null)))) } }, - test("Client PreparedStatement should be able to retrieve DATE type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `date`, `date_null` FROM `all_types` WHERE `date` = ?") resultSet <- statement.setDate(1, LocalDate.of(2020, 1, 1)) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (LocalDate, LocalDate)](resultSet.next()) { - for - v1 <- resultSet.getDate(1) - v2 <- resultSet.getDate(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (LocalDate, LocalDate)](resultSet.next()) { + for + v1 <- resultSet.getDate(1) + v2 <- resultSet.getDate(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((LocalDate.of(2020, 1, 1), null)))) } }, - test("Client PreparedStatement should be able to retrieve TIME type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `time`, `time_null` FROM `all_types` WHERE `time` = ?") resultSet <- statement.setTime(1, LocalTime.of(12, 34, 56)) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (LocalTime, LocalTime)](resultSet.next()) { - for - v1 <- resultSet.getTime(1) - v2 <- resultSet.getTime(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (LocalTime, LocalTime)](resultSet.next()) { + for + v1 <- resultSet.getTime(1) + v2 <- resultSet.getTime(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((LocalTime.of(12, 34, 56), null)))) } }, - test("Client PreparedStatement should be able to retrieve DATETIME type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `datetime`, `datetime_null` FROM `all_types` WHERE `datetime` = ?") resultSet <- statement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { - for - v1 <- resultSet.getTimestamp(1) - v2 <- resultSet.getTimestamp(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { + for + v1 <- resultSet.getTimestamp(1) + v2 <- resultSet.getTimestamp(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((LocalDateTime.of(2020, 1, 1, 12, 34, 56), null)))) } }, - test("Client PreparedStatement should be able to retrieve TIMESTAMP type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `timestamp`, `timestamp_null` FROM `all_types` WHERE `timestamp` = ?") resultSet <- statement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { - for - v1 <- resultSet.getTimestamp(1) - v2 <- resultSet.getTimestamp(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { + for + v1 <- resultSet.getTimestamp(1) + v2 <- resultSet.getTimestamp(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((LocalDateTime.of(2020, 1, 1, 12, 34, 56), null)))) } }, - test("Client PreparedStatement should be able to retrieve YEAR type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `year`, `year_null` FROM `all_types` WHERE `year` = ?") resultSet <- statement.setInt(1, 2020) *> statement.executeQuery() - result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { - for - v1 <- resultSet.getInt(1) - v2 <- resultSet.getInt(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } yield assert(result)(equalTo(List((2020, 0)))) } }, - test("Client PreparedStatement should be able to retrieve CHAR type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `char`, `char_null` FROM `all_types` WHERE `char` = ?") resultSet <- statement.setString(1, "char") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("char", null)))) } }, - test("Client PreparedStatement should be able to retrieve VARCHAR type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `varchar`, `varchar_null` FROM `all_types` WHERE `varchar` = ?") resultSet <- statement.setString(1, "varchar") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("varchar", null)))) } }, - test("Client PreparedStatement should be able to retrieve BINARY type records") { datasource.getConnection.use { conn => for @@ -369,62 +348,58 @@ object QueryTest extends ZIOSpecDefault: resultSet <- statement.setBytes(1, Array[Byte](98, 105, 110, 97, 114, 121, 0, 0, 0, 0)) *> statement.executeQuery() result <- Monad[Task].whileM[List, (String, Array[Byte])](resultSet.next()) { - for - v1 <- resultSet.getBytes(1) - v2 <- resultSet.getBytes(2) - yield (v1.mkString(":"), v2) - } + for + v1 <- resultSet.getBytes(1) + v2 <- resultSet.getBytes(2) + yield (v1.mkString(":"), v2) + } yield assert(result)(equalTo(List((Array[Byte](98, 105, 110, 97, 114, 121, 0, 0, 0, 0).mkString(":"), null)))) } }, - test("Client PreparedStatement should be able to retrieve VARBINARY type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `varbinary`, `varbinary_null` FROM `all_types` WHERE `varbinary` = ?") resultSet <- statement.setString(1, "varbinary") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("varbinary", null)))) } }, - test("Client PreparedStatement should be able to retrieve TINYBLOB type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `tinyblob`, `tinyblob_null` FROM `all_types` WHERE `tinyblob` = ?") resultSet <- statement.setString(1, "tinyblob") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("tinyblob", null)))) } }, - test("Client PreparedStatement should be able to retrieve BLOB type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `blob`, `blob_null` FROM `all_types` WHERE `blob` = ?") resultSet <- statement.setString(1, "blob") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("blob", null)))) } }, - test("Client PreparedStatement should be able to retrieve MEDIUMBLOB type records") { datasource.getConnection.use { conn => for @@ -432,63 +407,59 @@ object QueryTest extends ZIOSpecDefault: "SELECT `mediumblob`, `mediumblob_null` FROM `all_types` WHERE `mediumblob` = ?" ) resultSet <- statement.setString(1, "mediumblob") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("mediumblob", null)))) } }, - test("Client PreparedStatement should be able to retrieve LONGBLOB type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `longblob`, `longblob_null` FROM `all_types` WHERE `longblob` = ?") resultSet <- statement.setString(1, "longblob") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("longblob", null)))) } }, - test("Client PreparedStatement should be able to retrieve TINYTEXT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `tinytext`, `tinytext_null` FROM `all_types` WHERE `tinytext` = ?") resultSet <- statement.setString(1, "tinytext") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("tinytext", null)))) } }, - test("Client PreparedStatement should be able to retrieve TEXT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `text`, `text_null` FROM `all_types` WHERE `text` = ?") resultSet <- statement.setString(1, "text") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("text", null)))) } }, - test("Client PreparedStatement should be able to retrieve MEDIUMTEXT type records") { datasource.getConnection.use { conn => for @@ -496,58 +467,55 @@ object QueryTest extends ZIOSpecDefault: "SELECT `mediumtext`, `mediumtext_null` FROM `all_types` WHERE `mediumtext` = ?" ) resultSet <- statement.setString(1, "mediumtext") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("mediumtext", null)))) } }, - test("Client PreparedStatement should be able to retrieve LONGTEXT type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `longtext`, `longtext_null` FROM `all_types` WHERE `longtext` = ?") resultSet <- statement.setString(1, "longtext") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("longtext", null)))) } }, - test("Client PreparedStatement should be able to retrieve ENUM type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `enum`, `enum_null` FROM `all_types` WHERE `enum` = ?") resultSet <- statement.setString(1, "a") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("a", null)))) } }, - test("Client PreparedStatement should be able to retrieve SET type records") { datasource.getConnection.use { conn => for statement <- conn.prepareStatement("SELECT `set`, `set_null` FROM `all_types` WHERE `set` = ?") resultSet <- statement.setString(1, "a,b") *> statement.executeQuery() - result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { - for - v1 <- resultSet.getString(1) - v2 <- resultSet.getString(2) - yield (v1, v2) - } + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } yield assert(result)(equalTo(List(("a,b", null)))) } } From 645ccf5e3522a2e40530aaa6624a0d24f2472826 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 20:12:04 +0900 Subject: [PATCH 026/215] Action sbt scalafmtSbt --- build.sbt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build.sbt b/build.sbt index 6b0c1d7c3..e48df3720 100644 --- a/build.sbt +++ b/build.sbt @@ -162,8 +162,8 @@ lazy val zioInterop = crossProject(JVMPlatform, JSPlatform) libraryDependencies ++= Seq( "dev.zio" %%% "zio" % "2.1.21", "dev.zio" %%% "zio-interop-cats" % "23.1.0.5", - "dev.zio" %%% "zio-test" % "2.1.21" % Test, - "dev.zio" %%% "zio-test-sbt" % "2.1.21" % Test, + "dev.zio" %%% "zio-test" % "2.1.21" % Test, + "dev.zio" %%% "zio-test-sbt" % "2.1.21" % Test ) ) .dependsOn(connector % "test->compile") From 53d31402976ea90d78858de2f6bd6e3e3922f15f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 20:20:58 +0900 Subject: [PATCH 027/215] Create ZIO Update test --- .../scala/ldbc/zio/interop/UpdateTest.scala | 282 ++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala diff --git a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala new file mode 100644 index 000000000..3df2f21b0 --- /dev/null +++ b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala @@ -0,0 +1,282 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.zio.interop + +import java.time.* + +import ldbc.connector.* +import ldbc.connector.data.* + +import zio.* +import zio.test.* +import zio.test.Assertion.* + +object UpdateTest extends ZIOSpecDefault: + + private val datasource = + MySQLDataSource + .build[Task]("127.0.0.1", 13306, "ldbc") + .setPassword("password") + .setDatabase("connector_test") + .setSSL(SSL.Trusted) + + def spec = suite("UpdateTest")( + test("Boolean values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_boolean_table`(`c1` BOOLEAN NOT NULL, `c2` BOOLEAN NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_boolean_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setBoolean(1, true) *> preparedStatement + .setNull(2, MysqlType.BOOLEAN.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_boolean_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("Byte values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE `client_statement_byte_table`(`c1` BIT NOT NULL, `c2` BIT NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_byte_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setByte(1, 1.toByte) *> preparedStatement + .setNull(2, MysqlType.BIT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_byte_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("Short values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_short_table`(`c1` TINYINT NOT NULL, `c2` TINYINT NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_short_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setShort(1, 1.toShort) *> preparedStatement + .setNull(2, MysqlType.TINYINT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_short_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("Int values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_int_table`(`c1` SMALLINT NOT NULL, `c2` SMALLINT NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_int_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setInt(1, 1) *> preparedStatement.setNull( + 2, + MysqlType.SMALLINT.jdbcType + ) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_int_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("Long values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_long_table`(`c1` BIGINT NOT NULL, `c2` BIGINT NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_long_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setLong(1, Long.MaxValue) *> preparedStatement + .setNull(2, MysqlType.BIGINT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_long_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("BigInt values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- + statement.executeUpdate( + "CREATE TABLE `client_statement_bigint_table`(`c1` BIGINT unsigned NOT NULL, `c2` BIGINT unsigned NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_bigint_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setString(1, "18446744073709551615") *> preparedStatement + .setNull(2, MysqlType.BIGINT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_bigint_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("Float values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- + statement.executeUpdate("CREATE TABLE `client_statement_float_table`(`c1` FLOAT NOT NULL, `c2` FLOAT NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_float_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setFloat(1, 1.1f) *> preparedStatement + .setNull(2, MysqlType.FLOAT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_float_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("Double values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_double_table`(`c1` DOUBLE NOT NULL, `c2` DOUBLE NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_double_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setDouble(1, 1.1) *> preparedStatement + .setNull(2, MysqlType.DOUBLE.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_double_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("BigDecimal values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_bigdecimal_table`(`c1` DECIMAL NOT NULL, `c2` DECIMAL NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_bigdecimal_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setBigDecimal(1, BigDecimal.decimal(1.1)) *> preparedStatement + .setNull(2, MysqlType.DECIMAL.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_bigdecimal_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("String values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_string_table`(`c1` VARCHAR(255) NOT NULL, `c2` VARCHAR(255) NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_string_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setString(1, "test") *> preparedStatement + .setNull(2, MysqlType.VARCHAR.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_string_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("Array[Byte] values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_bytes_table`(`c1` BINARY(10) NOT NULL, `c2` BINARY NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_bytes_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setBytes(1, Array[Byte](98, 105, 110, 97, 114, 121)) *> preparedStatement + .setNull(2, MysqlType.BINARY.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_bytes_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("java.time.LocalTime values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE `client_statement_time_table`(`c1` TIME NOT NULL, `c2` TIME NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_time_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setTime(1, LocalTime.of(12, 34, 56)) *> preparedStatement.setNull( + 2, + MysqlType.TIME.jdbcType + ) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_time_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("java.time.LocalDate values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE `client_statement_date_table`(`c1` DATE NOT NULL, `c2` DATE NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_date_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setDate(1, LocalDate.of(2020, 1, 1)) *> preparedStatement.setNull( + 2, + MysqlType.DATE.jdbcType + ) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_date_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("java.time.LocalDateTime values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_datetime_table`(`c1` TIMESTAMP NOT NULL, `c2` TIMESTAMP NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_datetime_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> preparedStatement + .setNull(2, MysqlType.TIMESTAMP.jdbcType) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_datetime_table`") + yield assert(count)(equalTo(1)) + } + }, + + test("java.time.Year values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE `client_statement_year_table`(`c1` YEAR NOT NULL, `c2` YEAR NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_year_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setInt(1, 2020) *> preparedStatement + .setNull(2, MysqlType.YEAR.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_year_table`") + yield assert(count)(equalTo(1)) + } + } + ) \ No newline at end of file From ed5fa1fb7999d63b22655abd81717111afad5e5d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 20:21:14 +0900 Subject: [PATCH 028/215] Action sbt scalafmtAll --- .../test/scala/ldbc/zio/interop/UpdateTest.scala | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala index 3df2f21b0..3b571ea32 100644 --- a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala +++ b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala @@ -41,7 +41,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("Byte values can be set as parameters") { datasource.getConnection.use { conn => for @@ -56,7 +55,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("Short values can be set as parameters") { datasource.getConnection.use { conn => for @@ -73,7 +71,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("Int values can be set as parameters") { datasource.getConnection.use { conn => for @@ -92,7 +89,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("Long values can be set as parameters") { datasource.getConnection.use { conn => for @@ -110,7 +106,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("BigInt values can be set as parameters") { datasource.getConnection.use { conn => for @@ -128,7 +123,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("Float values can be set as parameters") { datasource.getConnection.use { conn => for @@ -144,7 +138,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("Double values can be set as parameters") { datasource.getConnection.use { conn => for @@ -161,7 +154,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("BigDecimal values can be set as parameters") { datasource.getConnection.use { conn => for @@ -178,7 +170,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("String values can be set as parameters") { datasource.getConnection.use { conn => for @@ -195,7 +186,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("Array[Byte] values can be set as parameters") { datasource.getConnection.use { conn => for @@ -213,7 +203,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("java.time.LocalTime values can be set as parameters") { datasource.getConnection.use { conn => for @@ -230,7 +219,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("java.time.LocalDate values can be set as parameters") { datasource.getConnection.use { conn => for @@ -247,7 +235,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("java.time.LocalDateTime values can be set as parameters") { datasource.getConnection.use { conn => for @@ -263,7 +250,6 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } }, - test("java.time.Year values can be set as parameters") { datasource.getConnection.use { conn => for @@ -279,4 +265,4 @@ object UpdateTest extends ZIOSpecDefault: yield assert(count)(equalTo(1)) } } - ) \ No newline at end of file + ) From ea7c43155d41fc130dcb0bd8a3257a1279a5d56c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 5 Oct 2025 21:18:53 +0900 Subject: [PATCH 029/215] Added scalaJSLinkerConfig --- build.sbt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/build.sbt b/build.sbt index e48df3720..7b6f94f68 100644 --- a/build.sbt +++ b/build.sbt @@ -166,6 +166,9 @@ lazy val zioInterop = crossProject(JVMPlatform, JSPlatform) "dev.zio" %%% "zio-test-sbt" % "2.1.21" % Test ) ) + .jsSettings( + Test / scalaJSLinkerConfig ~= (_.withModuleKind(ModuleKind.CommonJSModule)) + ) .dependsOn(connector % "test->compile") lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) From 4d9641e021a4ef91cf4415e9dd1dc34f048f1be6 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 6 Oct 2025 00:37:22 +0900 Subject: [PATCH 030/215] Create zio example project --- build.sbt | 15 +++++++- examples/zio/src/main/scala/Main.scala | 50 ++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 examples/zio/src/main/scala/Main.scala diff --git a/build.sbt b/build.sbt index 7b6f94f68..e08aea6ed 100644 --- a/build.sbt +++ b/build.sbt @@ -272,6 +272,18 @@ lazy val otelExample = crossProject(JVMPlatform) ) .dependsOn(connector, dsl) +lazy val zioExample = crossProject(JVMPlatform) + .crossType(CrossType.Pure) + .withoutSuffixFor(JVMPlatform) + .example("zio", "ZIO example project") + .settings( + libraryDependencies ++= Seq( + "dev.zio" %% "zio-http" % "3.5.1", + "dev.zio" %% "zio-json" % "0.7.44" + ) + ) + .dependsOn(connector, dsl, zioInterop) + lazy val docs = (project in file("docs")) .settings( description := "Documentation for ldbc", @@ -372,7 +384,8 @@ lazy val mcpDocumentServer = crossProject(JSPlatform) lazy val examples = Seq( http4sExample, hikariCPExample, - otelExample + otelExample, + zioExample ) lazy val ldbc = tlCrossRootProject diff --git a/examples/zio/src/main/scala/Main.scala b/examples/zio/src/main/scala/Main.scala new file mode 100644 index 000000000..b98988323 --- /dev/null +++ b/examples/zio/src/main/scala/Main.scala @@ -0,0 +1,50 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +import zio.* +import zio.http.* +import zio.json.* +import zio.interop.catz.* + +import ldbc.connector.* +import ldbc.zio.interop.* +import ldbc.dsl.* + +object Main extends ZIOAppDefault: + + private val poolConfig = MySQLConfig.default + .setHost("localhost") + .setPort(13306) + .setUser("ldbc") + .setPassword("password") + .setDatabase("world") + .setSSL(SSL.Trusted) + .setMinConnections(5) + .setMaxConnections(10) + + private val connectorLayer: ZLayer[Any, Throwable, Connector[Task]] = + ZLayer.scoped { + MySQLDataSource.pooling[Task](poolConfig).map(ds => Connector.fromDataSource[Task](ds)).toScopedZIO + } + + private val routes = Routes( + Method.GET / Root -> handler(Response.text("Hello, World!")), + Method.GET / Root / "countries" -> handler { + for + connector <- ZIO.service[Connector[Task]] + countries <- sql"SELECT Name FROM country".query[String].to[List].readOnly(connector) + yield Response.json(countries.toJson) + }.catchAll { error => + handler(Response.json(Map("error" -> error.getMessage).toJson)) + }, + ) + + override def run = + Server.serve(routes) + .provide( + Server.default, + connectorLayer + ) From be803c22f8b83c0ff988cfef8fa2f988262c2f8b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 6 Oct 2025 00:37:57 +0900 Subject: [PATCH 031/215] Action sbt scalafmtAll --- examples/zio/src/main/scala/Main.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/zio/src/main/scala/Main.scala b/examples/zio/src/main/scala/Main.scala index b98988323..b00613760 100644 --- a/examples/zio/src/main/scala/Main.scala +++ b/examples/zio/src/main/scala/Main.scala @@ -4,14 +4,16 @@ * For more information see LICENSE or https://opensource.org/licenses/MIT */ -import zio.* -import zio.http.* -import zio.json.* -import zio.interop.catz.* +import ldbc.dsl.* import ldbc.connector.* + import ldbc.zio.interop.* -import ldbc.dsl.* + +import zio.* +import zio.http.* +import zio.interop.catz.* +import zio.json.* object Main extends ZIOAppDefault: @@ -31,7 +33,7 @@ object Main extends ZIOAppDefault: } private val routes = Routes( - Method.GET / Root -> handler(Response.text("Hello, World!")), + Method.GET / Root -> handler(Response.text("Hello, World!")), Method.GET / Root / "countries" -> handler { for connector <- ZIO.service[Connector[Task]] @@ -39,11 +41,12 @@ object Main extends ZIOAppDefault: yield Response.json(countries.toJson) }.catchAll { error => handler(Response.json(Map("error" -> error.getMessage).toJson)) - }, + } ) override def run = - Server.serve(routes) + Server + .serve(routes) .provide( Server.default, connectorLayer From be5e098f82468bf58ec8ee6ecfa689f566cab5b1 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 6 Oct 2025 00:38:09 +0900 Subject: [PATCH 032/215] Action sbt githubWorkflowGenerate --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ec7a02991..34523d500 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -390,7 +390,7 @@ jobs: - name: Submit Dependencies uses: scalacenter/sbt-dependency-submission@v2 with: - modules-ignore: ldbcjs_3 ldbcjs_3 otel_3 mcp-ldbc-document-server_sjs1_3 docs_3 docs_3 ldbcnative_3 ldbcnative_3 ldbcjvm_3 ldbcjvm_3 hikaricp_3 tests_sjs1_3 tests_sjs1_3 http4s_3 tests_3 tests_3 benchmark_3 benchmark_3 tests_native0.4_3 tests_native0.4_3 + modules-ignore: ldbcjs_3 ldbcjs_3 otel_3 mcp-ldbc-document-server_sjs1_3 docs_3 docs_3 zio_3 ldbcnative_3 ldbcnative_3 ldbcjvm_3 ldbcjvm_3 hikaricp_3 tests_sjs1_3 tests_sjs1_3 http4s_3 tests_3 tests_3 benchmark_3 benchmark_3 tests_native0.4_3 tests_native0.4_3 configs-ignore: test scala-tool scala-doc-tool test-internal validate-steward: From ed9a80af20766c72c9a7e732be39db664755b09f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 6 Oct 2025 23:01:40 +0900 Subject: [PATCH 033/215] Added ignore codecov.yml --- codecov.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/codecov.yml b/codecov.yml index f8a69ab4a..cd19d211e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -23,3 +23,4 @@ coverage: - "module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Initial.scala" - "module/ldbc-schema/src/main/scala/ldbc/schema/Character.scala" - "module/ldbc-dsl/src/main/scala/ldbc/dsl/ResultSetConsumer.scala" + - "module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala" From 292c0c19fa459253d07db39d4f156209e4be5897 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 11 Oct 2025 20:08:17 +0900 Subject: [PATCH 034/215] Delete unused --- .../src/main/scala/ldbc/dsl/Mysql.scala | 2 +- .../src/main/scala/ldbc/dsl/Query.scala | 3 +- .../src/main/scala/ldbc/dsl/package.scala | 42 ------------------- 3 files changed, 2 insertions(+), 45 deletions(-) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala index c63ff3c53..87657e11e 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala @@ -21,7 +21,7 @@ import ldbc.dsl.codec.Decoder * @param params * statement has '?' that the statement has. */ -case class Mysql(statement: String, params: List[Parameter.Dynamic]) extends SQL, ParamBinder: +case class Mysql(statement: String, params: List[Parameter.Dynamic]) extends SQL: @targetName("combine") override def ++(sql: SQL): Mysql = diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala index 2da1b066d..1e9593e57 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala @@ -68,8 +68,7 @@ object Query: statement: String, params: List[Parameter.Dynamic], decoder: Decoder[T] - ) extends Query[T], - ParamBinder: + ) extends Query[T]: override def to[G[_]](using factory: FactoryCompat[T, G[T]]): DBIO[G[T]] = DBIO.queryTo(statement, params, decoder, factory) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala index 2c2ec27f0..23604917e 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala @@ -6,14 +6,6 @@ package ldbc -import java.time.* - -import cats.syntax.all.* -import cats.MonadThrow - -import ldbc.sql.PreparedStatement - -import ldbc.dsl.codec.Encoder import ldbc.dsl.syntax.* /** @@ -28,37 +20,3 @@ import ldbc.dsl.syntax.* package object dsl extends HelperFunctionsSyntax: export ldbc.DBIO - - private[ldbc] trait ParamBinder: - protected def paramBind[F[_]: MonadThrow]( - prepareStatement: PreparedStatement[F], - params: List[Parameter.Dynamic] - ): F[Unit] = - val encoded = params.foldLeft(MonadThrow[F].pure(List.empty[Encoder.Supported])) { - case (acc, param) => - for - acc$ <- acc - value <- param match - case Parameter.Dynamic.Success(value) => MonadThrow[F].pure(value) - case Parameter.Dynamic.Failure(errors) => - MonadThrow[F].raiseError(new IllegalArgumentException(errors.mkString(", "))) - yield acc$ :+ value - } - encoded.flatMap(_.zipWithIndex.foldLeft(MonadThrow[F].unit) { - case (acc, (value, index)) => - acc *> (value match - case value: Boolean => prepareStatement.setBoolean(index + 1, value) - case value: Byte => prepareStatement.setByte(index + 1, value) - case value: Short => prepareStatement.setShort(index + 1, value) - case value: Int => prepareStatement.setInt(index + 1, value) - case value: Long => prepareStatement.setLong(index + 1, value) - case value: Float => prepareStatement.setFloat(index + 1, value) - case value: Double => prepareStatement.setDouble(index + 1, value) - case value: BigDecimal => prepareStatement.setBigDecimal(index + 1, value) - case value: String => prepareStatement.setString(index + 1, value) - case value: Array[Byte] => prepareStatement.setBytes(index + 1, value) - case value: LocalDate => prepareStatement.setDate(index + 1, value) - case value: LocalTime => prepareStatement.setTime(index + 1, value) - case value: LocalDateTime => prepareStatement.setTimestamp(index + 1, value) - case None => prepareStatement.setNull(index + 1, ldbc.sql.Types.NULL)) - }) From d9a169bfa0aa517b39dcb5782add71b79b54b361 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 11 Oct 2025 20:11:20 +0900 Subject: [PATCH 035/215] Delete unused --- project/BuildSettings.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/project/BuildSettings.scala b/project/BuildSettings.scala index 0b498e202..9141e4955 100644 --- a/project/BuildSettings.scala +++ b/project/BuildSettings.scala @@ -65,17 +65,6 @@ object BuildSettings { ) ) - /** A project that runs in the sbt runtime. */ - object LepusSbtProject { - def apply(name: String, dir: String): Project = - Project(name, file(dir)) - .settings(scalaVersion := scala3) - .settings(scalacOptions ++= additionalSettings) - .settings(scalacOptions --= removeSettings) - .settings(commonSettings) - .enablePlugins(AutomateHeaderPlugin) - } - /** A project that is an sbt plugin. */ object LepusSbtPluginProject { def apply(name: String, dir: String): Project = From 0406c2abb64370f5d3fd4bfe17092e8b6d726dc4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 30 Oct 2025 00:00:57 +0900 Subject: [PATCH 036/215] Change active to mima binary check --- build.sbt | 1 - 1 file changed, 1 deletion(-) diff --git a/build.sbt b/build.sbt index 8bd792f37..6abe2870a 100644 --- a/build.sbt +++ b/build.sbt @@ -30,7 +30,6 @@ ThisBuild / githubWorkflowBuildPostamble += dockerStop ThisBuild / githubWorkflowTargetBranches := Seq("**") ThisBuild / githubWorkflowPublishTargetBranches := Seq(RefPredicate.StartsWith(Ref.Tag("v"))) ThisBuild / tlSitePublishBranch := None -ThisBuild / tlCiMimaBinaryIssueCheck := false lazy val sql = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Pure) From 2b6bdaa940e5de6bff1c82d0dd8e128f9c6d01dc Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 30 Oct 2025 00:01:05 +0900 Subject: [PATCH 037/215] Action sbt githubWorkflowGenerate --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 19aed2f2e..e643c17a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -142,6 +142,10 @@ jobs: - name: Test run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' test + - name: Check binary compatibility + if: matrix.java == 'corretto@11' && matrix.os == 'ubuntu-22.04' + run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' mimaReportBinaryIssues + - name: Generate API documentation if: matrix.java == 'corretto@11' && matrix.os == 'ubuntu-22.04' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' doc From 68da1f575f93b0dc18a22f0b6ee2da19e5d60f8d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 3 Nov 2025 19:12:34 +0900 Subject: [PATCH 038/215] Create execute{Query|Update} function --- .../src/main/scala/ldbc/free/KleisliInterpreter.scala | 2 ++ .../ldbc-core/src/main/scala/ldbc/free/StatementIO.scala | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala index d8ff38a01..a1d15bd16 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala @@ -140,6 +140,8 @@ class KleisliInterpreter[F[_]: Sync](logHandler: LogHandler[F]) extends Interpre override def onCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]): Kleisli[F, Statement[F], A] = outer.onCancel(this)(fa, fin) + override def executeQuery(sql: String): Kleisli[F, Statement[F], ResultSet[?]] = primitive[Statement[F], ResultSet[F]](_.executeQuery(sql)).asInstanceOf[Kleisli[F, Statement[F], ResultSet[?]]] + override def executeUpdate(sql: String): Kleisli[F, Statement[F], Int] = primitive(_.executeUpdate(sql)) override def addBatch(sql: String): Kleisli[F, Statement[F], Unit] = primitive(_.addBatch(sql)) override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) diff --git a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala index 026e1f820..d933cd928 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala @@ -42,6 +42,10 @@ object StatementOp: final case class OnCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]) extends StatementOp[A]: override def visit[F[_]](v: StatementOp.Visitor[F]): F[A] = v.onCancel(fa, fin) + final case class ExecuteQuery(sql: String) extends StatementOp[ResultSet[?]]: + override def visit[F[_]](v: StatementOp.Visitor[F]): F[ResultSet[?]] = v.executeQuery(sql) + final case class ExecuteUpdate(sql: String) extends StatementOp[Int]: + override def visit[F[_]](v: StatementOp.Visitor[F]): F[Int] = v.executeUpdate(sql) final case class AddBatch[A](sql: String) extends StatementOp[Unit]: override def visit[F[_]](v: StatementOp.Visitor[F]): F[Unit] = v.addBatch(sql) final case class ExecuteBatch[A]() extends StatementOp[Array[Int]]: @@ -67,6 +71,8 @@ object StatementOp: def canceled: F[Unit] def onCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]): F[A] + def executeQuery(sql: String): F[ResultSet[?]] + def executeUpdate(sql: String): F[Int] def addBatch(sql: String): F[Unit] def executeBatch(): F[Array[Int]] @@ -95,5 +101,7 @@ object StatementIO: def capturePoll[M[_]](mpoll: Poll[M]): Poll[StatementIO] = new Poll[StatementIO]: override def apply[A](fa: StatementIO[A]): StatementIO[A] = Free.liftF[StatementOp, A](StatementOp.Poll1(mpoll, fa)) + def executeQuery(sql: String): StatementIO[ResultSet[?]] = Free.liftF[StatementOp, ResultSet[?]](StatementOp.ExecuteQuery(sql)) + def executeUpdate(sql: String): StatementIO[Int] = Free.liftF[StatementOp, Int](StatementOp.ExecuteUpdate(sql)) def addBatch(sql: String): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.AddBatch(sql)) def executeBatch(): StatementIO[Array[Int]] = Free.liftF[StatementOp, Array[Int]](StatementOp.ExecuteBatch()) From 24e704d324217898b433fbd3af1a126b4b06e228 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 3 Nov 2025 19:13:50 +0900 Subject: [PATCH 039/215] Create updateRaw & updateRaws method --- .../src/main/scala/ldbc/dsl/DBIO.scala | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala index e43ffe5a8..a531e0289 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala @@ -222,6 +222,26 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + def updateRaw(statement: String): DBIO[Int] = + (for + stmt <- ConnectionIO.createStatement() + result <- ConnectionIO.embed(stmt, StatementIO.executeUpdate(statement)) + yield result).onError { ex => + ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) + } <* + ConnectionIO.performLogging(LogEvent.Success(statement, List.empty)) + + def updateRaws(statement: String): DBIO[Array[Int]] = + (for + stmt <- ConnectionIO.createStatement() + statements = statement.trim.split(";").toList + _ <- ConnectionIO.embed(stmt, statements.map(statement => StatementIO.addBatch(statement)).sequence) + result <- ConnectionIO.embed(stmt, StatementIO.executeBatch()) + yield result).onError { ex => + ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) + } <* + ConnectionIO.performLogging(LogEvent.Success(statement, List.empty)) + def returning[A]( statement: String, params: List[Parameter.Dynamic], From 697233cc75a876f40abe076d1636c89ff1f75a36 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 3 Nov 2025 19:43:13 +0900 Subject: [PATCH 040/215] Added scaladoc comment --- .../src/main/scala/ldbc/dsl/DBIO.scala | 455 +++++++++++++++++- 1 file changed, 454 insertions(+), 1 deletion(-) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala index a531e0289..c9a2e0302 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala @@ -27,8 +27,47 @@ import ldbc.* import ldbc.free.* import ldbc.logging.LogEvent +/** + * DBIO (Database I/O) provides a set of combinators for building type-safe database operations + * using the Free monad pattern. It abstracts over JDBC operations and provides a functional + * interface for executing SQL queries with proper resource management and error handling. + * + * The DBIO type is a type alias for `Free[ConnectionOp, A]`, which represents a computation + * that produces a value of type `A` when run against a database connection. + * + * @example + * {{{ + * import ldbc.dsl.* + * + * val query: DBIO[List[User]] = + * sql"SELECT * FROM users WHERE age > \${18}".query[User].to[List] + * + * val result: F[List[User]] = query.run(connector) + * }}} + */ object DBIO: + /** + * Binds dynamic parameters to a PreparedStatement. + * + * This method is the core mechanism for safely binding values to SQL prepared statements. + * It processes a list of dynamic parameters, extracting their encoded values and binding + * them to the appropriate positions in the PreparedStatement. + * + * The binding process: + * 1. Validates all parameters, collecting any encoding failures + * 2. Maps each encoded value to the correct PreparedStatement setter method + * 3. Handles null values by calling setNull with the appropriate SQL type + * + * @param params List of dynamic parameters to bind. Each parameter contains either + * a successfully encoded value or encoding error information + * @return A PreparedStatementIO action that performs the binding operations + * @throws IllegalArgumentException if any parameter contains encoding failures + * + * @note JDBC uses 1-based indexing for parameters, so the index is incremented by 1 + * @note The method uses pattern matching to determine the appropriate setter method + * based on the runtime type of each encoded value + */ private def paramBind(params: List[Parameter.Dynamic]): PreparedStatementIO[Unit] = val encoded = params.foldLeft(PreparedStatementIO.pure(List.empty[Encoder.Supported])) { case (acc, param) => @@ -63,6 +102,21 @@ object DBIO: } yield () + /** + * Reads a single row from a ResultSet and decodes it using the provided decoder. + * + * This method expects exactly one row in the ResultSet. It will: + * - Advance to the first row using ResultSet.next() + * - Decode the row using the provided decoder + * - Return an error if no rows are found or if decoding fails + * + * @param statement The SQL statement being executed (used for error messages) + * @param decoder The decoder to convert the row data into type T + * @tparam T The type to decode the row into + * @return A ResultSetIO action that produces a value of type T + * @throws UnexpectedContinuation if the ResultSet is empty + * @throws DecodeFailureException if the decoder fails to decode the row + */ private def unique[T]( statement: String, decoder: Decoder[T] @@ -77,6 +131,21 @@ object DBIO: case false => ResultSetIO.raiseError(new UnexpectedContinuation("Expected ResultSet to have at least one row.")) } + /** + * Reads all rows from a ResultSet and collects them into a collection of type G[T]. + * + * This method iterates through all rows in the ResultSet, decoding each row + * and accumulating the results into a collection using the provided factory. + * It continues until ResultSet.next() returns false. + * + * @param statement The SQL statement being executed (used for error messages) + * @param decoder The decoder to convert each row into type T + * @param factoryCompat A factory for building the target collection type G[T] + * @tparam G The collection type constructor (e.g., List, Vector, Set) + * @tparam T The element type to decode each row into + * @return A ResultSetIO action that produces a collection of decoded values + * @throws DecodeFailureException if the decoder fails on any row + */ private def whileM[G[_], T]( statement: String, decoder: Decoder[T], @@ -102,6 +171,19 @@ object DBIO: loop(builder).map(_.result()) + /** + * Reads all rows from a ResultSet and returns them as a NonEmptyList. + * + * This method ensures that at least one row exists in the ResultSet. + * If the ResultSet is empty, it raises an UnexpectedEnd error. + * + * @param statement The SQL statement being executed (used for error messages) + * @param decoder The decoder to convert each row into type A + * @tparam A The type to decode each row into + * @return A ResultSetIO action that produces a NonEmptyList of decoded values + * @throws UnexpectedEnd if the ResultSet contains no rows + * @throws DecodeFailureException if the decoder fails on any row + */ private def nel[A]( statement: String, decoder: Decoder[A] @@ -111,6 +193,35 @@ object DBIO: else ResultSetIO.pure(NonEmptyList.fromListUnsafe(results)) } + /** + * Executes a SQL query that returns exactly one row and decodes it to type A. + * + * This method: + * 1. Prepares a statement with the given SQL + * 2. Binds the provided parameters + * 3. Executes the query + * 4. Reads exactly one row from the result + * 5. Closes the prepared statement + * 6. Logs the operation (success or failure) + * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert the result row to type A + * @tparam A The result type + * @return A DBIO action that produces a single value of type A + * @throws UnexpectedContinuation if the ResultSet is empty + * @throws DecodeFailureException if decoding fails + * + * @example + * {{{ + * val userId = 42 + * val query = DBIO.queryA( + * "SELECT name FROM users WHERE id = ?", + * List(Parameter.Dynamic.Success(userId)), + * Decoder[String] + * ) + * }}} + */ def queryA[A]( statement: String, params: List[Parameter.Dynamic], @@ -129,6 +240,31 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL query and collects all results into a collection of type G[A]. + * + * This method allows you to specify the target collection type through the + * factory parameter. Common usage includes collecting to List, Vector, Set, etc. + * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert each result row to type A + * @param factory Factory for building the target collection type + * @tparam G The collection type constructor (e.g., List, Vector, Set) + * @tparam A The element type + * @return A DBIO action that produces a collection of decoded values + * @throws DecodeFailureException if decoding fails for any row + * + * @example + * {{{ + * val query = DBIO.queryTo[List, String]( + * "SELECT name FROM users WHERE age > ?", + * List(Parameter.Dynamic.Success(18)), + * Decoder[String], + * summon[FactoryCompat[String, List[String]]] + * ) + * }}} + */ def queryTo[G[_], A]( statement: String, params: List[Parameter.Dynamic], @@ -148,6 +284,33 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL query that returns at most one row as an Option[A]. + * + * This method is useful when a query might return zero or one row. + * It returns: + * - Some(a) if exactly one row is found and successfully decoded + * - None if no rows are found + * - An error if more than one row is found + * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert the result row to type A + * @tparam A The result type + * @return A DBIO action that produces an Option[A] + * @throws UnexpectedContinuation if more than one row is found + * @throws DecodeFailureException if decoding fails + * + * @example + * {{{ + * val query = DBIO.queryOption( + * "SELECT name FROM users WHERE email = ?", + * List(Parameter.Dynamic.Success("user@example.com")), + * Decoder[String] + * ) + * // Returns Some(name) if user exists, None otherwise + * }}} + */ def queryOption[A]( statement: String, params: List[Parameter.Dynamic], @@ -188,6 +351,31 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL query that returns at least one row as a NonEmptyList[A]. + * + * This method ensures that the query returns at least one result. + * It's useful when you need compile-time guarantees that your result + * set is not empty. + * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert each result row to type A + * @tparam A The result type + * @return A DBIO action that produces a NonEmptyList[A] + * @throws UnexpectedEnd if no rows are found + * @throws DecodeFailureException if decoding fails for any row + * + * @example + * {{{ + * val query = DBIO.queryNel( + * "SELECT id FROM users WHERE role = ?", + * List(Parameter.Dynamic.Success("admin")), + * Decoder[Int] + * ) + * // Fails at runtime if no admin users exist + * }}} + */ def queryNel[A]( statement: String, params: List[Parameter.Dynamic], @@ -206,6 +394,28 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL update statement (INSERT, UPDATE, DELETE) with parameters. + * + * This method prepares the statement, binds parameters, executes the update, + * and returns the number of affected rows. + * + * @param statement The SQL update statement to execute + * @param params Dynamic parameters to bind to the statement + * @return A DBIO action that produces the number of affected rows + * + * @example + * {{{ + * val update = DBIO.update( + * "UPDATE users SET name = ? WHERE id = ?", + * List( + * Parameter.Dynamic.Success("New Name"), + * Parameter.Dynamic.Success(42) + * ) + * ) + * // Returns the number of updated rows + * }}} + */ def update( statement: String, params: List[Parameter.Dynamic] @@ -222,6 +432,25 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a raw SQL update statement without parameters. + * + * This method uses a Statement instead of PreparedStatement and is suitable + * for DDL operations or updates that don't require parameter binding. + * + * @param statement The SQL statement to execute + * @return A DBIO action that produces the number of affected rows + * + * @example + * {{{ + * val create = DBIO.updateRaw( + * "CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY, name VARCHAR(255))" + * ) + * }}} + * + * @note Use parameterized queries (update method) when dealing with user input + * to prevent SQL injection attacks + */ def updateRaw(statement: String): DBIO[Int] = (for stmt <- ConnectionIO.createStatement() @@ -231,6 +460,29 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, List.empty)) + /** + * Executes multiple SQL statements separated by semicolons as a batch. + * + * This method splits the input by semicolons and executes each statement + * as part of a batch operation. It's useful for executing multiple DDL + * or DML statements in a single round trip to the database. + * + * @param statement Multiple SQL statements separated by semicolons + * @return A DBIO action that produces an array of update counts + * + * @example + * {{{ + * val batch = DBIO.updateRaws(""" + * CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(255)); + * CREATE INDEX idx_name ON users(name); + * INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + * """) + * // Returns an array with the result of each statement + * }}} + * + * @note Each statement is executed independently; failure of one doesn't + * necessarily rollback others unless in a transaction + */ def updateRaws(statement: String): DBIO[Array[Int]] = (for stmt <- ConnectionIO.createStatement() @@ -242,6 +494,34 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, List.empty)) + /** + * Executes an INSERT statement and returns generated keys. + * + * This method is typically used for INSERT statements where you want to + * retrieve auto-generated keys (like auto-increment IDs). The statement + * is prepared with RETURN_GENERATED_KEYS flag. + * + * @param statement The INSERT statement to execute + * @param params Dynamic parameters to bind to the statement + * @param decoder Decoder to convert the generated key to type A + * @tparam A The type of the generated key + * @return A DBIO action that produces the generated key + * @throws UnexpectedContinuation if no keys are generated + * @throws DecodeFailureException if decoding the key fails + * + * @example + * {{{ + * val insert = DBIO.returning[Long]( + * "INSERT INTO users (name, email) VALUES (?, ?)", + * List( + * Parameter.Dynamic.Success("Alice"), + * Parameter.Dynamic.Success("alice@example.com") + * ), + * Decoder[Long] // For auto-increment ID + * ) + * // Returns the generated user ID + * }}} + */ def returning[A]( statement: String, params: List[Parameter.Dynamic], @@ -261,6 +541,27 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a batch of SQL statements. + * + * This method executes multiple statements as a batch operation, + * which is more efficient than executing them individually. + * Unlike updateRaws, this takes a list of complete statements. + * + * @param statements List of SQL statements to execute + * @return A DBIO action that produces an array of update counts + * + * @example + * {{{ + * val statements = List( + * "INSERT INTO users (name) VALUES ('Alice')", + * "INSERT INTO users (name) VALUES ('Bob')", + * "UPDATE users SET active = true WHERE name = 'Alice'" + * ) + * val batch = DBIO.sequence(statements) + * // Returns update counts for each statement + * }}} + */ def sequence( statements: List[String] ): DBIO[Array[Int]] = @@ -273,6 +574,42 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statements.mkString("\n"), List.empty)) + /** + * Creates a streaming query that processes results incrementally. + * + * This method is ideal for processing large result sets without loading + * all data into memory at once. It uses JDBC fetch size to control + * how many rows are fetched from the database in each round trip. + * + * The stream is resource-safe and will automatically close the + * PreparedStatement when the stream terminates (normally or via error). + * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert each result row to type A + * @param fetchSize Number of rows to fetch in each batch from the database + * @tparam A The result type + * @return An fs2.Stream that emits decoded values + * + * @example + * {{{ + * val stream = DBIO.stream( + * "SELECT * FROM large_table WHERE category = ?", + * List(Parameter.Dynamic.Success("books")), + * Decoder[Product], + * fetchSize = 1000 + * ) + * + * // Process results as they arrive + * stream + * .evalMap(product => processProduct(product)) + * .compile + * .drain + * .run(connector) + * }}} + * + * @note The fetch size is a hint to the JDBC driver and may be ignored + */ def stream[A]( statement: String, params: List[Parameter.Dynamic], @@ -317,12 +654,56 @@ object DBIO: } <* fs2.Stream.eval(ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value)))) + /** + * Combines multiple DBIO actions into a single action that produces a list of results. + * + * This is equivalent to the traverse operation, executing each DBIO in sequence + * and collecting all results. + * + * @param dbios Variable number of DBIO actions to sequence + * @tparam A The result type of each DBIO + * @return A DBIO action that produces a List[A] containing all results + * + * @example + * {{{ + * val combined = DBIO.sequence( + * sql"SELECT COUNT(*) FROM users".query[Int].to[Option], + * sql"SELECT COUNT(*) FROM products".query[Int].to[Option] + * ) + * // Returns List(userCount, productCount) + * }}} + */ def sequence[A](dbios: DBIO[A]*): DBIO[List[A]] = dbios.toList.sequence - def pure[A](value: A): DBIO[A] = Free.pure(value) + /** + * Lifts a pure value into the DBIO context. + * + * @param value The value to lift + * @tparam A The type of the value + * @return A DBIO action that produces the given value + */ + def pure[A](value: A): DBIO[A] = Free.pure(value) + + /** + * Creates a failed DBIO action with the given error. + * + * @param e The error to raise + * @tparam A The expected result type (never produced) + * @return A DBIO action that fails with the given error + */ def raiseError[A](e: Throwable): DBIO[A] = ConnectionIO.raiseError(e) + /** + * Provides a Sync type class instance for DBIO. + * + * This enables DBIO to be used with any Cats Effect operations that + * require a Sync constraint, including error handling, resource management, + * and cancellation support. + * + * The implementation delegates to the underlying ConnectionIO Free monad + * operations, ensuring proper effect handling throughout the DBIO lifecycle. + */ implicit val syncDBIO: Sync[DBIO] = new Sync[DBIO]: val monad = Free.catsFreeMonadForFree[ConnectionOp] @@ -342,16 +723,67 @@ object DBIO: override def canceled: DBIO[Unit] = ConnectionIO.canceled override def onCancel[A](fa: DBIO[A], fin: DBIO[Unit]): DBIO[A] = ConnectionIO.onCancel(fa, fin) + /** + * Extension methods for DBIO values that provide various execution strategies. + * + * This class is made available through an implicit conversion and provides + * methods to run DBIO actions with different transaction semantics. + * + * @param dbio The DBIO action to execute + * @tparam A The result type of the DBIO action + */ private[ldbc] class Ops[A](dbio: DBIO[A]): + /** + * Runs the DBIO action using the provided connector. + * + * This is the basic execution method that runs the action + * with whatever transaction settings are currently active. + * + * @param connector The database connector to use + * @tparam F The effect type (e.g., IO, Task) + * @return The result wrapped in the effect type F + */ def run[F[_]](connector: Connector[F]): F[A] = connector.run(dbio) + /** + * Runs the DBIO action in read-only mode. + * + * This ensures that the connection is set to read-only before executing + * the action and restored afterwards. Useful for ensuring queries don't + * accidentally modify data. + * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + */ def readOnly[F[_]](connector: Connector[F]): F[A] = connector.run(ConnectionIO.setReadOnly(true) *> dbio <* ConnectionIO.setReadOnly(false)) + /** + * Runs the DBIO action with auto-commit enabled. + * + * Each statement is automatically committed after execution. + * This is typically used for single statement operations that + * should be immediately persisted. + * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + */ def commit[F[_]](connector: Connector[F]): F[A] = connector.run(ConnectionIO.setReadOnly(false) *> ConnectionIO.setAutoCommit(true) *> dbio) + /** + * Runs the DBIO action and rolls back all changes at the end. + * + * This is useful for testing or dry-run scenarios where you want + * to execute operations but not persist any changes. + * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + */ def rollback[F[_]](connector: Connector[F]): F[A] = connector.run( ConnectionIO.setReadOnly(false) *> @@ -361,6 +793,27 @@ object DBIO: ConnectionIO.setAutoCommit(true) ) + /** + * Runs the DBIO action within a transaction. + * + * The action is executed with auto-commit disabled. If the action + * completes successfully, changes are committed. If an error occurs, + * all changes are rolled back. + * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + * + * @example + * {{{ + * val transfer = for { + * _ <- sql"UPDATE accounts SET balance = balance - 100 WHERE id = 1".update + * _ <- sql"UPDATE accounts SET balance = balance + 100 WHERE id = 2".update + * } yield () + * + * transfer.transaction(connector) // Both updates succeed or both fail + * }}} + */ def transaction[F[_]](connector: Connector[F]): F[A] = connector.run( (ConnectionIO.setReadOnly(false) *> ConnectionIO.setAutoCommit(false) *> dbio) From b4370fe406f3b270c4ea5cac10e99beb18352580 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 3 Nov 2025 19:44:55 +0900 Subject: [PATCH 041/215] Action sbt scalafmtAll --- .../main/scala/ldbc/free/KleisliInterpreter.scala | 9 +++++---- .../src/main/scala/ldbc/free/StatementIO.scala | 15 ++++++++------- .../ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala | 10 +++++----- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala index a1d15bd16..5265e6da6 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala @@ -140,10 +140,11 @@ class KleisliInterpreter[F[_]: Sync](logHandler: LogHandler[F]) extends Interpre override def onCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]): Kleisli[F, Statement[F], A] = outer.onCancel(this)(fa, fin) - override def executeQuery(sql: String): Kleisli[F, Statement[F], ResultSet[?]] = primitive[Statement[F], ResultSet[F]](_.executeQuery(sql)).asInstanceOf[Kleisli[F, Statement[F], ResultSet[?]]] - override def executeUpdate(sql: String): Kleisli[F, Statement[F], Int] = primitive(_.executeUpdate(sql)) - override def addBatch(sql: String): Kleisli[F, Statement[F], Unit] = primitive(_.addBatch(sql)) - override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) + override def executeQuery(sql: String): Kleisli[F, Statement[F], ResultSet[?]] = + primitive[Statement[F], ResultSet[F]](_.executeQuery(sql)).asInstanceOf[Kleisli[F, Statement[F], ResultSet[?]]] + override def executeUpdate(sql: String): Kleisli[F, Statement[F], Int] = primitive(_.executeUpdate(sql)) + override def addBatch(sql: String): Kleisli[F, Statement[F], Unit] = primitive(_.addBatch(sql)) + override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) trait PreparedStatementInterpreter extends PreparedStatementOp.Visitor[[A] =>> Kleisli[F, PreparedStatement[F], A]]: diff --git a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala index d933cd928..e7ba71cce 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala @@ -71,10 +71,10 @@ object StatementOp: def canceled: F[Unit] def onCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]): F[A] - def executeQuery(sql: String): F[ResultSet[?]] + def executeQuery(sql: String): F[ResultSet[?]] def executeUpdate(sql: String): F[Int] - def addBatch(sql: String): F[Unit] - def executeBatch(): F[Array[Int]] + def addBatch(sql: String): F[Unit] + def executeBatch(): F[Array[Int]] type StatementIO[A] = Free[StatementOp, A] @@ -101,7 +101,8 @@ object StatementIO: def capturePoll[M[_]](mpoll: Poll[M]): Poll[StatementIO] = new Poll[StatementIO]: override def apply[A](fa: StatementIO[A]): StatementIO[A] = Free.liftF[StatementOp, A](StatementOp.Poll1(mpoll, fa)) - def executeQuery(sql: String): StatementIO[ResultSet[?]] = Free.liftF[StatementOp, ResultSet[?]](StatementOp.ExecuteQuery(sql)) - def executeUpdate(sql: String): StatementIO[Int] = Free.liftF[StatementOp, Int](StatementOp.ExecuteUpdate(sql)) - def addBatch(sql: String): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.AddBatch(sql)) - def executeBatch(): StatementIO[Array[Int]] = Free.liftF[StatementOp, Array[Int]](StatementOp.ExecuteBatch()) + def executeQuery(sql: String): StatementIO[ResultSet[?]] = + Free.liftF[StatementOp, ResultSet[?]](StatementOp.ExecuteQuery(sql)) + def executeUpdate(sql: String): StatementIO[Int] = Free.liftF[StatementOp, Int](StatementOp.ExecuteUpdate(sql)) + def addBatch(sql: String): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.AddBatch(sql)) + def executeBatch(): StatementIO[Array[Int]] = Free.liftF[StatementOp, Array[Int]](StatementOp.ExecuteBatch()) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala index c9a2e0302..951aa5892 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala @@ -453,8 +453,8 @@ object DBIO: */ def updateRaw(statement: String): DBIO[Int] = (for - stmt <- ConnectionIO.createStatement() - result <- ConnectionIO.embed(stmt, StatementIO.executeUpdate(statement)) + stmt <- ConnectionIO.createStatement() + result <- ConnectionIO.embed(stmt, StatementIO.executeUpdate(statement)) yield result).onError { ex => ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) } <* @@ -487,8 +487,8 @@ object DBIO: (for stmt <- ConnectionIO.createStatement() statements = statement.trim.split(";").toList - _ <- ConnectionIO.embed(stmt, statements.map(statement => StatementIO.addBatch(statement)).sequence) - result <- ConnectionIO.embed(stmt, StatementIO.executeBatch()) + _ <- ConnectionIO.embed(stmt, statements.map(statement => StatementIO.addBatch(statement)).sequence) + result <- ConnectionIO.embed(stmt, StatementIO.executeBatch()) yield result).onError { ex => ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) } <* @@ -684,7 +684,7 @@ object DBIO: * @return A DBIO action that produces the given value */ def pure[A](value: A): DBIO[A] = Free.pure(value) - + /** * Creates a failed DBIO action with the given error. * From 9494969dad3d131841bc88f72bea0f0cc2d10e14 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 3 Nov 2025 20:07:08 +0900 Subject: [PATCH 042/215] Create updateRaw & updateRaws test --- .../src/test/scala/ldbc/tests/DBIOTest.scala | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala index 9d57a9ba6..ce6e90e15 100644 --- a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala @@ -102,3 +102,26 @@ class DBIOTest extends CatsEffectSuite: program.attempt.readOnly(connector).map(_.isLeft) ) } + + test("DBIO#updateRaw") { + assertIO( + (for + r1 <- DBIO.updateRaw("CREATE DATABASE `dbio`;").commit(connector) + r2 <- DBIO.updateRaw("DROP DATABASE `dbio`;").commit(connector) + yield List(r1, r2)), + List(1, 0) + ) + } + + test("DBIO#updateRaws") { + assertIO( + (for + results <- DBIO.updateRaws( + """ + |CREATE DATABASE `dbio`; + |DROP DATABASE `dbio`; + |""".stripMargin).commit(connector) + yield results.toList), + List(1, 0) + ) + } From 84ff12e30c6c1a25d49363000df866acfac8bd3c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 3 Nov 2025 20:09:17 +0900 Subject: [PATCH 043/215] Action sbt scalafmtAll --- .../shared/src/test/scala/ldbc/tests/DBIOTest.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala index ce6e90e15..9b86a5b6f 100644 --- a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala @@ -114,14 +114,12 @@ class DBIOTest extends CatsEffectSuite: } test("DBIO#updateRaws") { + val sql = """ + |CREATE DATABASE `dbio`; + |DROP DATABASE `dbio`; + |""".stripMargin assertIO( - (for - results <- DBIO.updateRaws( - """ - |CREATE DATABASE `dbio`; - |DROP DATABASE `dbio`; - |""".stripMargin).commit(connector) - yield results.toList), + DBIO.updateRaws(sql).commit(connector).map(_.toList), List(1, 0) ) } From a57c6c5bba6c6bcef5c96086dfda8a93c80b98e6 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 3 Nov 2025 20:20:00 +0900 Subject: [PATCH 044/215] Added Exception error test --- .../src/test/scala/ldbc/tests/DBIOTest.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala index 9b86a5b6f..f327c244a 100644 --- a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala @@ -15,6 +15,7 @@ import munit.CatsEffectSuite import ldbc.dsl.* import ldbc.connector.* +import ldbc.connector.exception.* import ldbc.Connector @@ -113,6 +114,12 @@ class DBIOTest extends CatsEffectSuite: ) } + test("DBIO#updateRaw#Exception") { + interceptIO[SQLSyntaxErrorException]( + DBIO.updateRaw("CREATE `dbio`;").commit(connector) + ) + } + test("DBIO#updateRaws") { val sql = """ |CREATE DATABASE `dbio`; @@ -123,3 +130,13 @@ class DBIOTest extends CatsEffectSuite: List(1, 0) ) } + + test("DBIO#updateRaws#Exception") { + val sql = """ + |CREATE DATABASE `dbio` + |DROP DATABASE `dbio` + |""".stripMargin + interceptIO[BatchUpdateException]( + DBIO.updateRaws(sql).commit(connector).map(_.toList) + ) + } From c3e829df3174c3cdf38707a24a66984aac8eb281 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 4 Nov 2025 22:57:26 +0900 Subject: [PATCH 045/215] Added Close Vistor --- .../src/main/scala/ldbc/free/KleisliInterpreter.scala | 3 ++- module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala index 5265e6da6..a5a024c0f 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala @@ -144,7 +144,8 @@ class KleisliInterpreter[F[_]: Sync](logHandler: LogHandler[F]) extends Interpre primitive[Statement[F], ResultSet[F]](_.executeQuery(sql)).asInstanceOf[Kleisli[F, Statement[F], ResultSet[?]]] override def executeUpdate(sql: String): Kleisli[F, Statement[F], Int] = primitive(_.executeUpdate(sql)) override def addBatch(sql: String): Kleisli[F, Statement[F], Unit] = primitive(_.addBatch(sql)) - override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) + override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) + override def close(): Kleisli[F, Statement[F], Unit] = primitive(_.close()) trait PreparedStatementInterpreter extends PreparedStatementOp.Visitor[[A] =>> Kleisli[F, PreparedStatement[F], A]]: diff --git a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala index e7ba71cce..e080aa5de 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala @@ -50,6 +50,8 @@ object StatementOp: override def visit[F[_]](v: StatementOp.Visitor[F]): F[Unit] = v.addBatch(sql) final case class ExecuteBatch[A]() extends StatementOp[Array[Int]]: override def visit[F[_]](v: StatementOp.Visitor[F]): F[Array[Int]] = v.executeBatch() + final case class Close() extends StatementOp[Unit]: + override def visit[F[_]](v: StatementOp.Visitor[F]): F[Unit] = v.close() given Embeddable[StatementOp, Statement[?]] = new Embeddable[StatementOp, Statement[?]]: @@ -75,6 +77,7 @@ object StatementOp: def executeUpdate(sql: String): F[Int] def addBatch(sql: String): F[Unit] def executeBatch(): F[Array[Int]] + def close(): F[Unit] type StatementIO[A] = Free[StatementOp, A] @@ -106,3 +109,4 @@ object StatementIO: def executeUpdate(sql: String): StatementIO[Int] = Free.liftF[StatementOp, Int](StatementOp.ExecuteUpdate(sql)) def addBatch(sql: String): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.AddBatch(sql)) def executeBatch(): StatementIO[Array[Int]] = Free.liftF[StatementOp, Array[Int]](StatementOp.ExecuteBatch()) + def close(): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.Close()) From 675aad2d10bf9ec288bbefa96d7fb02284073601 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 4 Nov 2025 22:57:39 +0900 Subject: [PATCH 046/215] Use Statement CLose --- module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala index 951aa5892..7dcd197a5 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala @@ -455,6 +455,7 @@ object DBIO: (for stmt <- ConnectionIO.createStatement() result <- ConnectionIO.embed(stmt, StatementIO.executeUpdate(statement)) + _ <- ConnectionIO.embed(stmt, StatementIO.close()) yield result).onError { ex => ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) } <* @@ -489,6 +490,7 @@ object DBIO: statements = statement.trim.split(";").toList _ <- ConnectionIO.embed(stmt, statements.map(statement => StatementIO.addBatch(statement)).sequence) result <- ConnectionIO.embed(stmt, StatementIO.executeBatch()) + _ <- ConnectionIO.embed(stmt, StatementIO.close()) yield result).onError { ex => ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) } <* @@ -569,6 +571,7 @@ object DBIO: statement <- ConnectionIO.createStatement() _ <- ConnectionIO.embed(statement, statements.map(statement => StatementIO.addBatch(statement)).sequence) result <- ConnectionIO.embed(statement, StatementIO.executeBatch()) + _ <- ConnectionIO.embed(statement, StatementIO.close()) yield result).onError { ex => ConnectionIO.performLogging(LogEvent.ProcessingFailure(statements.mkString("\n"), List.empty, ex)) } <* From 1ed4af86f54f3f7376476eecd12e1ad6002d155d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 4 Nov 2025 22:58:06 +0900 Subject: [PATCH 047/215] Action sbt scalafmtAll --- .../src/main/scala/ldbc/free/KleisliInterpreter.scala | 4 ++-- module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala | 4 ++-- module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala index a5a024c0f..7633c85b8 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala @@ -144,8 +144,8 @@ class KleisliInterpreter[F[_]: Sync](logHandler: LogHandler[F]) extends Interpre primitive[Statement[F], ResultSet[F]](_.executeQuery(sql)).asInstanceOf[Kleisli[F, Statement[F], ResultSet[?]]] override def executeUpdate(sql: String): Kleisli[F, Statement[F], Int] = primitive(_.executeUpdate(sql)) override def addBatch(sql: String): Kleisli[F, Statement[F], Unit] = primitive(_.addBatch(sql)) - override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) - override def close(): Kleisli[F, Statement[F], Unit] = primitive(_.close()) + override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) + override def close(): Kleisli[F, Statement[F], Unit] = primitive(_.close()) trait PreparedStatementInterpreter extends PreparedStatementOp.Visitor[[A] =>> Kleisli[F, PreparedStatement[F], A]]: diff --git a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala index e080aa5de..8a94b340b 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala @@ -77,7 +77,7 @@ object StatementOp: def executeUpdate(sql: String): F[Int] def addBatch(sql: String): F[Unit] def executeBatch(): F[Array[Int]] - def close(): F[Unit] + def close(): F[Unit] type StatementIO[A] = Free[StatementOp, A] @@ -109,4 +109,4 @@ object StatementIO: def executeUpdate(sql: String): StatementIO[Int] = Free.liftF[StatementOp, Int](StatementOp.ExecuteUpdate(sql)) def addBatch(sql: String): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.AddBatch(sql)) def executeBatch(): StatementIO[Array[Int]] = Free.liftF[StatementOp, Array[Int]](StatementOp.ExecuteBatch()) - def close(): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.Close()) + def close(): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.Close()) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala index 7dcd197a5..69b6db5ee 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala @@ -571,7 +571,7 @@ object DBIO: statement <- ConnectionIO.createStatement() _ <- ConnectionIO.embed(statement, statements.map(statement => StatementIO.addBatch(statement)).sequence) result <- ConnectionIO.embed(statement, StatementIO.executeBatch()) - _ <- ConnectionIO.embed(statement, StatementIO.close()) + _ <- ConnectionIO.embed(statement, StatementIO.close()) yield result).onError { ex => ConnectionIO.performLogging(LogEvent.ProcessingFailure(statements.mkString("\n"), List.empty, ex)) } <* From e48d0e9b4e096a66e3289d29b705b2f928ccd2e3 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Tue, 4 Nov 2025 23:33:07 +0900 Subject: [PATCH 048/215] Fixed scaladoc --- module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala index 69b6db5ee..74c92baae 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala @@ -40,7 +40,7 @@ import ldbc.logging.LogEvent * import ldbc.dsl.* * * val query: DBIO[List[User]] = - * sql"SELECT * FROM users WHERE age > \${18}".query[User].to[List] + * sql"SELECT * FROM users WHERE age > ${ 18 }".query[User].to[List] * * val result: F[List[User]] = query.run(connector) * }}} From ade6da13b2d2574d01b85ce3bbdbb3423217ea27 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 15 Nov 2025 18:27:21 +0900 Subject: [PATCH 049/215] Create awsAuthenticationPlugin sbt project --- build.sbt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/build.sbt b/build.sbt index ab915ec2a..4898cf6fb 100644 --- a/build.sbt +++ b/build.sbt @@ -143,6 +143,15 @@ lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) .nativeSettings(Test / nativeBrewFormulas += "s2n") .dependsOn(core) +lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlatform) + .crossType(CrossType.Full) + .module("aws-authentication-plugin", "") + .jsSettings( + Test / scalaJSLinkerConfig ~= (_.withModuleKind(ModuleKind.CommonJSModule)) + ) + .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) + .nativeSettings(Test / nativeBrewFormulas += "s2n") + lazy val plugin = LepusSbtPluginProject("ldbc-plugin", "plugin") .settings(description := "Projects that provide sbt plug-ins") .settings((Compile / sourceGenerators) += Def.task { @@ -401,6 +410,7 @@ lazy val ldbc = tlCrossRootProject schema, codegen, zioInterop, + awsAuthenticationPlugin, plugin, tests, docs, From 629767834ec2f6f6b94c8db2db53a8b071b6830d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 15 Nov 2025 18:27:37 +0900 Subject: [PATCH 050/215] Create Identity --- .../scala/ldbc/amazon/identity/Identity.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala new file mode 100644 index 000000000..1997ac110 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala @@ -0,0 +1,24 @@ +package ldbc.amazon.identity + +import java.time.Instant + +/** + * Interface to represent who is using the SDK, i.e., the identity of the caller, used for authentication. + * + * Examples include [[AwsCredentialsIdentity]] and [[TokenIdentity]]. + */ +trait Identity: + + /** + * The time after which this identity will no longer be valid. If this is empty, + * an expiration time is not known (but the identity may still expire at some + * time in the future). + */ + def expirationTime(): Option[Instant] + + /** + * The source that resolved this identity, normally an identity provider. Note that + * this string value would be set by an identity provider implementation and is + * intended to be used for for tracking purposes. Avoid building logic on its value. + */ + def providerName(): Option[String] From 802addcedd5501c40ba4d19bd024b529c8c96441 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 15 Nov 2025 19:14:38 +0900 Subject: [PATCH 051/215] Create AwsCredentialsIdentity --- .../identity/AwsCredentialsIdentity.scala | 36 ++++++++++++++++ .../scala/ldbc/amazon/identity/Identity.scala | 4 +- .../DefaultAwsCredentialsIdentity.scala | 42 +++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala new file mode 100644 index 000000000..15f7733e2 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala @@ -0,0 +1,36 @@ +package ldbc.amazon.identity + +import ldbc.amazon.identity.internal.DefaultAwsCredentialsIdentity + +trait AwsCredentialsIdentity extends Identity: + + /** + * Retrieve the AWS access key, used to identify the user interacting with services. + */ + def accessKeyId: String + + /** + * Retrieve the AWS secret access key, used to authenticate the user interacting with services. + */ + def secretAccessKey: String + + /** + * Retrieve the AWS account id associated with this credentials identity, if found. + */ + def accountId: Option[String] + +object AwsCredentialsIdentity: + + /** + * Constructs a new credentials object, with the specified AWS access key and AWS secret key. + * + * @param accessKeyId The AWS access key, used to identify the user interacting with services. + * @param secretAccessKey The AWS secret access key, used to authenticate the user interacting with services. + */ + def create(accessKeyId: String, secretAccessKey: String): AwsCredentialsIdentity = DefaultAwsCredentialsIdentity( + accessKeyId = accessKeyId, + secretAccessKey = secretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala index 1997ac110..c5e8d415a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala @@ -14,11 +14,11 @@ trait Identity: * an expiration time is not known (but the identity may still expire at some * time in the future). */ - def expirationTime(): Option[Instant] + def expirationTime: Option[Instant] /** * The source that resolved this identity, normally an identity provider. Note that * this string value would be set by an identity provider implementation and is * intended to be used for for tracking purposes. Avoid building logic on its value. */ - def providerName(): Option[String] + def providerName: Option[String] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala new file mode 100644 index 000000000..8bef8cd7b --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -0,0 +1,42 @@ +package ldbc.amazon.identity.internal + +import java.time.Instant +import java.util.Objects + +import ldbc.amazon.identity.AwsCredentialsIdentity + +final case class DefaultAwsCredentialsIdentity( + accessKeyId: String, + secretAccessKey: String, + accountId: Option[String], + expirationTime: Option[Instant], + providerName: Option[String] +) extends AwsCredentialsIdentity: + + override def toString: String = + val builder = new StringBuilder() + builder.append("AwsCredentialsIdentity(") + builder.append(s"accessKeyId=$accessKeyId") + providerName.foreach(v => builder.append(s"providerName=$v")) + accountId.foreach(v => builder.append(s"accountId=$v")) + + builder.result() + + override def equals(obj: Any): Boolean = + if this == obj then + true + else if obj == null || getClass != obj.getClass then + false + else obj match + case that: AwsCredentialsIdentity => + Objects.equals(accessKeyId, that.accessKeyId) && + Objects.equals(secretAccessKey, that.secretAccessKey) && + Objects.equals(accountId, that.accountId) + case _ => false + + override def hashCode(): Int = + var hashCode = 1 + hashCode = 31 * hashCode + Objects.hashCode(accessKeyId) + hashCode = 31 * hashCode + Objects.hashCode(secretAccessKey) + hashCode = 31 * hashCode + Objects.hashCode(accountId) + hashCode From 70d0c960681ef5cfc3abd3d6743b4f74e8a79b40 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 15 Nov 2025 19:21:14 +0900 Subject: [PATCH 052/215] Create AwsCredentials --- .../scala/ldbc/amazon/identity/AwsCredentials.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala new file mode 100644 index 000000000..823af5998 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala @@ -0,0 +1,11 @@ +package ldbc.amazon.identity + +/** + * Provides access to the AWS credentials used for accessing services: AWS access key ID and secret access key. These + * credentials are used to securely sign requests to services (e.g., AWS services) that use them for authentication. + * + *

For more details on AWS access keys, see: + * + * https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html#access-keys-and-secret-access-keys

+ */ +trait AwsCredentials extends AwsCredentialsIdentity From d619e7c0d4f32c81e4ebcf921be069867d9aed67 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 15 Nov 2025 19:21:27 +0900 Subject: [PATCH 053/215] Create AwsCredentialsProvider --- .../amazon/identity/AwsCredentialsIdentity.scala | 8 ++++++++ .../amazon/identity/AwsCredentialsProvider.scala | 15 +++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala index 15f7733e2..b1b72d290 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala @@ -2,6 +2,14 @@ package ldbc.amazon.identity import ldbc.amazon.identity.internal.DefaultAwsCredentialsIdentity +/** + * Provides access to the AWS credentials used for accessing services: AWS access key ID and secret access key. These + * credentials are used to securely sign requests to services (e.g., AWS services) that use them for authentication. + * + *

For more details on AWS access keys, see: + * + * https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html#access-keys-and-secret-access-keys

+ */ trait AwsCredentialsIdentity extends Identity: /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala new file mode 100644 index 000000000..c65029c03 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala @@ -0,0 +1,15 @@ +package ldbc.amazon.identity + +trait AwsCredentialsProvider: + + /** + * Returns [[AwsCredentials]] that can be used to authorize an AWS request. Each implementation of AWSCredentialsProvider + * can choose its own strategy for loading credentials. For example, an implementation might load credentials from an existing + * key management system, or load new credentials when credentials are rotated. + * + *

If an error occurs during the loading of credentials or credentials could not be found, a runtime exception will be + * raised.

+ * + * @return AwsCredentials which the caller can use to authorize an AWS request. + */ + def resolveCredentials(): AwsCredentials From d6b4549967704f08b88b1b70ded07ff4802326c0 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 20 Nov 2025 00:18:45 +0900 Subject: [PATCH 054/215] Create SystemSetting --- .../ldbc/amazon/util/SystemSetting.scala | 228 ++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala new file mode 100644 index 000000000..eb0c1566d --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala @@ -0,0 +1,228 @@ +package ldbc.amazon.util + +trait SystemSetting + +enum SdkSystemSetting(val systemProperty: String, val defaultValue: Option[String]): + /** + * Configure the AWS access key ID. + * + * This value will not be ignored if the [[AWS_SECRET_ACCESS_KEY]] is not specified. + */ + case AWS_ACCESS_KEY_ID extends SdkSystemSetting("aws.accessKeyId", None) + + /** + * Configure the AWS secret access key. + * + * This value will not be ignored if the [[AWS_ACCESS_KEY_ID]] is not specified. + */ + case AWS_SECRET_ACCESS_KEY extends SdkSystemSetting("aws.secretAccessKey", None) + + /** + * Configure the AWS session token. + */ + case AWS_SESSION_TOKEN extends SdkSystemSetting("aws.sessionToken", None) + + /** + * Configure the AWS account id associated with credentials supplied through system properties. + */ + case AWS_ACCOUNT_ID extends SdkSystemSetting("aws.accountId", None) + + /** + * Configure the AWS web identity token file path. + */ + case AWS_WEB_IDENTITY_TOKEN_FILE extends SdkSystemSetting("aws.webIdentityTokenFile", None) + + /** + * Configure the AWS role arn. + */ + case AWS_ROLE_ARN extends SdkSystemSetting("aws.roleArn", None) + + /** + * Configure the session name for a role. + */ + case AWS_ROLE_SESSION_NAME extends SdkSystemSetting("aws.roleSessionName", None) + + /** + * Configure the default region. + */ + case AWS_REGION extends SdkSystemSetting("aws.region", None) + + /** + * Whether to load information such as credentials, regions from EC2 Metadata instance service. + */ + case AWS_EC2_METADATA_DISABLED extends SdkSystemSetting("aws.disableEc2Metadata", Some("false")) + + /** + * Whether to disable fallback to insecure EC2 Metadata instance service v1 on errors or timeouts. + */ + case AWS_EC2_METADATA_V1_DISABLED extends SdkSystemSetting("aws.disableEc2MetadataV1", None) + + /** + * The EC2 instance metadata service endpoint. + * + * This allows a service running in EC2 to automatically load its credentials and region without needing to configure them + * in the SdkClientBuilder. + */ + case AWS_EC2_METADATA_SERVICE_ENDPOINT extends SdkSystemSetting("aws.ec2MetadataServiceEndpoint", Some("http://169.254.169.254")) + + case AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE extends SdkSystemSetting("aws.ec2MetadataServiceEndpointMode", Some("IPv4")) + + /** + * The number of seconds (either as an integer or double) before a connection to the instance + * metadata service should time out. This value is applied to both the socket connect and read timeouts. + * + * The timeout can be configured using the system property "aws.ec2MetadataServiceTimeout". If not set, + * a default timeout is used. This setting is crucial for ensuring timely responses from the instance + * metadata service in environments with varying network conditions. + */ + case AWS_METADATA_SERVICE_TIMEOUT extends SdkSystemSetting("aws.ec2MetadataServiceTimeout", Some("1")) + + /** + * The elastic container metadata service endpoint that should be called by the ContainerCredentialsProvider + * when loading data from the container metadata service. + * + * This allows a service running in an elastic container to automatically load its credentials without needing to configure + * them in the SdkClientBuilder. + * + * This is not used if the [[AWS_CONTAINER_CREDENTIALS_RELATIVE_URI]] is not specified. + */ + case AWS_CONTAINER_SERVICE_ENDPOINT extends SdkSystemSetting("aws.containerServiceEndpoint", Some("http://169.254.170.2")) + + /** + * The elastic container metadata service path that should be called by the ContainerCredentialsProvider when + * loading credentials form the container metadata service. If this is not specified, credentials will not be automatically + * loaded from the container metadata service. + * + * @see #AWS_CONTAINER_SERVICE_ENDPOINT + */ + case AWS_CONTAINER_CREDENTIALS_RELATIVE_URI extends SdkSystemSetting("aws.containerCredentialsPath", None) + + /** + * The full URI path to a localhost metadata service to be used. + */ + case AWS_CONTAINER_CREDENTIALS_FULL_URI extends SdkSystemSetting("aws.containerCredentialsFullUri", None) + + /** + * An authorization token to pass to a container metadata service, only used when [[AWS_CONTAINER_CREDENTIALS_FULL_URI]] + * is specified. + * + * @see #AWS_CONTAINER_CREDENTIALS_FULL_URI + */ + case AWS_CONTAINER_AUTHORIZATION_TOKEN extends SdkSystemSetting("aws.containerAuthorizationToken", None) + + /** + * The absolute file path containing the authorization token in plain text to pass to a container metadata + * service, only used when [[AWS_CONTAINER_CREDENTIALS_FULL_URI]] is specified. + * + * @see #AWS_CONTAINER_CREDENTIALS_FULL_URI + */ + case AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE extends SdkSystemSetting("aws.containerAuthorizationTokenFile", None) + + /** + * Explicitly identify the default synchronous HTTP implementation the SDK will use. Useful + * when there are multiple implementations on the classpath or as a performance optimization + * since implementation discovery requires classpath scanning. + */ + case SYNC_HTTP_SERVICE_IMPL extends SdkSystemSetting("software.amazon.awssdk.http.service.impl", None) + + /** + * Explicitly identify the default Async HTTP implementation the SDK will use. Useful + * when there are multiple implementations on the classpath or as a performance optimization + * since implementation discovery requires classpath scanning. + */ + case ASYNC_HTTP_SERVICE_IMPL extends SdkSystemSetting("software.amazon.awssdk.http.async.service.impl", None) + + /** + * Whether CBOR optimization should automatically be used if its support is found on the classpath and the service supports + * CBOR-formatted JSON. + */ + case CBOR_ENABLED extends SdkSystemSetting("aws.cborEnabled", Some("true")) + + /** + * Whether binary ION representation optimization should automatically be used if the service supports ION. + */ + case BINARY_ION_ENABLED extends SdkSystemSetting("aws.binaryIonEnabled", Some("true")) + + /** + * The execution environment of the SDK user. This is automatically set in certain environments by the underlying AWS service. + * For example, AWS Lambda will automatically specify a runtime indicating that the SDK is being used within Lambda. + */ + case AWS_EXECUTION_ENV extends SdkSystemSetting("aws.executionEnvironment", None) + + /** + * Whether endpoint discovery should be enabled. + */ + case AWS_ENDPOINT_DISCOVERY_ENABLED extends SdkSystemSetting("aws.endpointDiscoveryEnabled", None) + + /** + * Which [[RetryMode]] to use for the default [[RetryPolicy]], when one is not specified at the client level. + */ + case AWS_RETRY_MODE extends SdkSystemSetting("aws.retryMode", None) + + /** + * Which DefaultsMode to use, case insensitive + */ + case AWS_DEFAULTS_MODE extends SdkSystemSetting("aws.defaultsMode", None) + + /** + * Which AccountIdEndpointMode to use, case insensitive + */ + case AWS_ACCOUNT_ID_ENDPOINT_MODE extends SdkSystemSetting("aws.accountIdEndpointMode", None) + + /** + * Defines whether dualstack endpoints should be resolved during default endpoint resolution instead of non-dualstack + * endpoints. + */ + case AWS_USE_DUALSTACK_ENDPOINT extends SdkSystemSetting("aws.useDualstackEndpoint", None) + + /** + * Defines whether fips endpoints should be resolved during default endpoint resolution instead of non-fips endpoints. + */ + case AWS_USE_FIPS_ENDPOINT extends SdkSystemSetting("aws.useFipsEndpoint", None) + + /** + * Whether request compression is disabled for operations marked with the RequestCompression trait. The default value is + * false, i.e., request compression is enabled. + */ + case AWS_DISABLE_REQUEST_COMPRESSION extends SdkSystemSetting("aws.disableRequestCompression", None) + + /** + * Defines the minimum compression size in bytes, inclusive, for a request to be compressed. The default value is 10_240. + * The value must be non-negative and no greater than 10_485_760. + */ + case AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES extends SdkSystemSetting("aws.requestMinCompressionSizeBytes", None) + + /** + * Defines a file path from which partition metadata should be loaded. If this isn't specified, the partition + * metadata deployed with the SDK client will be used instead. + */ + case AWS_PARTITIONS_FILE extends SdkSystemSetting("aws.partitionsFile", None) + + /** + * The request checksum calculation setting. The default value is WHEN_SUPPORTED. + */ + case AWS_REQUEST_CHECKSUM_CALCULATION extends SdkSystemSetting("aws.requestChecksumCalculation", None) + + /** + * The response checksum validation setting. The default value is WHEN_SUPPORTED. + */ + case AWS_RESPONSE_CHECKSUM_VALIDATION extends SdkSystemSetting("aws.responseChecksumValidation", None) + + /** + * Configure an optional identification value to be appended to the user agent header. + * The value should be less than 50 characters in length and is None by default. + */ + case AWS_SDK_UA_APP_ID extends SdkSystemSetting("sdk.ua.appId", None) + + /** + * Configure the SIGV4A signing region set. + * This is a non-empty, comma-delimited list of AWS region names used during signing. + */ + case AWS_SIGV4A_SIGNING_REGION_SET extends SdkSystemSetting("aws.sigv4a.signing.region.set", None) + + + /** + * Configure the preferred auth scheme to use. + * This is a comma-delimited list of AWS auth scheme names used during signing. + */ + case AWS_AUTH_SCHEME_PREFERENCE extends SdkSystemSetting("aws.authSchemePreference", None) From 547ad79059f725d2c77c212e50f65d21efb21ece Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 20 Nov 2025 00:18:59 +0900 Subject: [PATCH 055/215] Create BusinessMetricFeatureId --- .../useragent/BusinessMetricFeatureId.scala | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala new file mode 100644 index 000000000..ead0cd5db --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -0,0 +1,57 @@ +package ldbc.amazon.useragent + +/** + * An enum class representing a short form of identity providers to record in the UA string. + * + * Unimplemented metrics: I,K + * Unsupported metrics (these will never be added): A,H + */ +enum BusinessMetricFeatureId(val code: String): + case WAITER extends BusinessMetricFeatureId("B") + case PAGINATOR extends BusinessMetricFeatureId("C") + case RETRY_MODE_LEGACY extends BusinessMetricFeatureId("D") + case RETRY_MODE_STANDARD extends BusinessMetricFeatureId("E") + case RETRY_MODE_ADAPTIVE extends BusinessMetricFeatureId("F") + case S3_TRANSFER extends BusinessMetricFeatureId("G") + case GZIP_REQUEST_COMPRESSION extends BusinessMetricFeatureId("L") + case PROTOCOL_RPC_V2_CBOR extends BusinessMetricFeatureId("M") + case ENDPOINT_OVERRIDE extends BusinessMetricFeatureId("N") + case S3_EXPRESS_BUCKET extends BusinessMetricFeatureId("J") + case ACCOUNT_ID_MODE_PREFERRED extends BusinessMetricFeatureId("P") + case ACCOUNT_ID_MODE_DISABLED extends BusinessMetricFeatureId("Q") + case ACCOUNT_ID_MODE_REQUIRED extends BusinessMetricFeatureId("R") + case SIGV4A_SIGNING extends BusinessMetricFeatureId("S") + case RESOLVED_ACCOUNT_ID extends BusinessMetricFeatureId("T") + case FLEXIBLE_CHECKSUMS_REQ_CRC32 extends BusinessMetricFeatureId("U") + case FLEXIBLE_CHECKSUMS_REQ_CRC32C extends BusinessMetricFeatureId("V") + case FLEXIBLE_CHECKSUMS_REQ_CRC64 extends BusinessMetricFeatureId("W") + case FLEXIBLE_CHECKSUMS_REQ_SHA1 extends BusinessMetricFeatureId("X") + case FLEXIBLE_CHECKSUMS_REQ_SHA256 extends BusinessMetricFeatureId("Y") + case FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED extends BusinessMetricFeatureId("Z") + case FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED extends BusinessMetricFeatureId("a") + case FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED extends BusinessMetricFeatureId("b") + case FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED extends BusinessMetricFeatureId("c") + case DDB_MAPPER extends BusinessMetricFeatureId("d") + case BEARER_SERVICE_ENV_VARS extends BusinessMetricFeatureId("3") + case CREDENTIALS_CODE extends BusinessMetricFeatureId("e") + case CREDENTIALS_JVM_SYSTEM_PROPERTIES extends BusinessMetricFeatureId("f") + case CREDENTIALS_ENV_VARS extends BusinessMetricFeatureId("g") + case CREDENTIALS_ENV_VARS_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("h") + case CREDENTIALS_STS_ASSUME_ROLE extends BusinessMetricFeatureId("i") + case CREDENTIALS_STS_ASSUME_ROLE_SAML extends BusinessMetricFeatureId("j") + case CREDENTIALS_STS_ASSUME_ROLE_WEB_ID extends BusinessMetricFeatureId("k") + case CREDENTIALS_STS_FEDERATION_TOKEN extends BusinessMetricFeatureId("l") + case CREDENTIALS_STS_SESSION_TOKEN extends BusinessMetricFeatureId("m") + case CREDENTIALS_PROFILE extends BusinessMetricFeatureId("n") + case CREDENTIALS_PROFILE_SOURCE_PROFILE extends BusinessMetricFeatureId("o") + case CREDENTIALS_PROFILE_NAMED_PROVIDER extends BusinessMetricFeatureId("p") + case CREDENTIALS_PROFILE_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("q") + case CREDENTIALS_PROFILE_SSO extends BusinessMetricFeatureId("r") + case CREDENTIALS_SSO extends BusinessMetricFeatureId("s") + case CREDENTIALS_PROFILE_SSO_LEGACY extends BusinessMetricFeatureId("t") + case CREDENTIALS_SSO_LEGACY extends BusinessMetricFeatureId("u") + case CREDENTIALS_PROFILE_PROCESS extends BusinessMetricFeatureId("v") + case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") + case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") + case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") + case UNKNOWN extends BusinessMetricFeatureId("Unknown") From 5233d3ae292f1292c025456a8c91690af12d76ce Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 20 Nov 2025 00:19:13 +0900 Subject: [PATCH 056/215] Create SdkClientException --- .../scala/ldbc/amazon/exception/SdkClientException.scala | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala new file mode 100644 index 000000000..580420741 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala @@ -0,0 +1,5 @@ +package ldbc.amazon.exception + +class SdkClientException(message: String) extends RuntimeException: + + final override def getMessage: String = message From ae6f0473f9a566dc99a6c3e3e19ca99ef0b6f2b3 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 20 Nov 2025 00:19:27 +0900 Subject: [PATCH 057/215] Create AwsBasicCredentials --- .../credentials/AwsBasicCredentials.scala | 14 ++++++ .../credentials/AwsSessionCredentials.scala | 15 ++++++ .../SystemPropertyCredentialsProvider.scala | 17 +++++++ .../SystemSettingsCredentialsProvider.scala | 47 +++++++++++++++++++ 4 files changed, 93 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala new file mode 100644 index 000000000..52bd6363b --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala @@ -0,0 +1,14 @@ +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import ldbc.amazon.identity.AwsCredentials + +final case class AwsBasicCredentials( + accessKeyId: String, + secretAccessKey: String, + validateCredentials: Boolean, + providerName: Option[String], + accountId: Option[String], + expirationTime: Option[Instant], + ) extends AwsCredentials diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala new file mode 100644 index 000000000..1ad3ef064 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala @@ -0,0 +1,15 @@ +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import ldbc.amazon.identity.AwsCredentials + +final case class AwsSessionCredentials( + accessKeyId: String, + secretAccessKey: String, + sessionToken: String, + validateCredentials: Boolean, + providerName: Option[String], + accountId: Option[String], + expirationTime: Option[Instant], + ) extends AwsCredentials diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala new file mode 100644 index 000000000..c4b03ce7a --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala @@ -0,0 +1,17 @@ +package ldbc.amazon.auth.credentials + +import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider +import ldbc.amazon.util.SdkSystemSetting +import ldbc.amazon.useragent.BusinessMetricFeatureId + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from the aws.accessKeyId, aws.secretAccessKey and + * aws.sessionToken system properties. + */ +final class SystemPropertyCredentialsProvider extends SystemSettingsCredentialsProvider: + + // Customers should be able to specify a credentials provider that only looks at the system properties, + // but not the environment variables. For that reason, we're only checking the system properties here. + override def loadSetting(setting: SdkSystemSetting): Option[String] = Option(System.getProperty(setting.systemProperty)) + + override def provider: String = BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala new file mode 100644 index 000000000..3c6d15c8a --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala @@ -0,0 +1,47 @@ +package ldbc.amazon.auth.credentials.internal + +import ldbc.amazon.identity.* +import ldbc.amazon.util.SdkSystemSetting +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.auth.credentials.* + +trait SystemSettingsCredentialsProvider extends AwsCredentialsProvider: + + override def resolveCredentials(): Either[SdkClientException, AwsCredentials] = + val accessKeyOpt = loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map(_.trim) + val secretKeyOpt = loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY).map(_.trim) + val sessionTokenOpt = loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN).map(_.trim) + val accountId = loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID).map(_.trim) + + for + accessKey <- accessKeyOpt match { + case Some(value) if value.nonEmpty => Right(value) + case _ => Left(new SdkClientException(s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${SdkSystemSetting.AWS_ACCESS_KEY_ID}) or system property (${SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty}).")) + } + secretKey <- secretKeyOpt match { + case Some(value) if value.isEmpty => Right(value) + case _ => Left(new SdkClientException(s"Unable to load credentials from system settings. Secret key must be specified either via environment variable (${SdkSystemSetting.AWS_SECRET_ACCESS_KEY}) or system property (${SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty}).")) + } + yield sessionTokenOpt match { + case None => AwsBasicCredentials( + accessKeyId = accessKey, + secretAccessKey = secretKey, + validateCredentials = false, + providerName = Some(provider), + accountId = accountId, + expirationTime = None, + ) + case Some(sessionToken) => AwsSessionCredentials( + accessKeyId = accessKey, + secretAccessKey = secretKey, + sessionToken = sessionToken, + validateCredentials = false, + providerName = Some(provider), + accountId = accountId, + expirationTime = None, + ) + } + + def loadSetting(setting: SdkSystemSetting): Option[String] + + def provider: String From c2f73e0049f0d21397f08291d466a3cb697f0e6a Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 20 Nov 2025 00:19:50 +0900 Subject: [PATCH 058/215] Change return type --- .../scala/ldbc/amazon/identity/AwsCredentialsProvider.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala index c65029c03..3cd5bf4ce 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala @@ -1,5 +1,7 @@ package ldbc.amazon.identity +import ldbc.amazon.exception.SdkClientException + trait AwsCredentialsProvider: /** @@ -12,4 +14,4 @@ trait AwsCredentialsProvider: * * @return AwsCredentials which the caller can use to authorize an AWS request. */ - def resolveCredentials(): AwsCredentials + def resolveCredentials(): Either[SdkClientException, AwsCredentials] From 13e942d2aa9a35f3a784d6738c62347ec8286485 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 20 Nov 2025 00:20:30 +0900 Subject: [PATCH 059/215] Added licence header --- .../ldbc/amazon/auth/credentials/AwsBasicCredentials.scala | 6 ++++++ .../amazon/auth/credentials/AwsSessionCredentials.scala | 6 ++++++ .../credentials/SystemPropertyCredentialsProvider.scala | 6 ++++++ .../internal/SystemSettingsCredentialsProvider.scala | 6 ++++++ .../scala/ldbc/amazon/exception/SdkClientException.scala | 6 ++++++ .../main/scala/ldbc/amazon/identity/AwsCredentials.scala | 6 ++++++ .../scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala | 6 ++++++ .../scala/ldbc/amazon/identity/AwsCredentialsProvider.scala | 6 ++++++ .../src/main/scala/ldbc/amazon/identity/Identity.scala | 6 ++++++ .../identity/internal/DefaultAwsCredentialsIdentity.scala | 6 ++++++ .../ldbc/amazon/useragent/BusinessMetricFeatureId.scala | 6 ++++++ .../src/main/scala/ldbc/amazon/util/SystemSetting.scala | 6 ++++++ 12 files changed, 72 insertions(+) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala index 52bd6363b..843b16a83 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.auth.credentials import java.time.Instant diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala index 1ad3ef064..ae1f46478 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.auth.credentials import java.time.Instant diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala index c4b03ce7a..cc22ac989 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.auth.credentials import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala index 3c6d15c8a..fcc7f12fd 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.auth.credentials.internal import ldbc.amazon.identity.* diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala index 580420741..a38ea56dc 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.exception class SdkClientException(message: String) extends RuntimeException: diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala index 823af5998..3b55f0187 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.identity /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala index b1b72d290..0819fe7df 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.identity import ldbc.amazon.identity.internal.DefaultAwsCredentialsIdentity diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala index 3cd5bf4ce..51c278c73 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.identity import ldbc.amazon.exception.SdkClientException diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala index c5e8d415a..efc4b7e47 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.identity import java.time.Instant diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala index 8bef8cd7b..bcdbc3404 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.identity.internal import java.time.Instant diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala index ead0cd5db..60ca8e68a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.useragent /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala index eb0c1566d..0d86f554c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.util trait SystemSetting From 0242b4afa446741f68c4d13fd8e7a6a7a6bf2461 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 20 Nov 2025 00:23:53 +0900 Subject: [PATCH 060/215] Action sbt scalafmtAll --- .../credentials/AwsBasicCredentials.scala | 14 +-- .../credentials/AwsSessionCredentials.scala | 16 ++-- .../SystemPropertyCredentialsProvider.scala | 6 +- .../SystemSettingsCredentialsProvider.scala | 68 ++++++++------ .../identity/AwsCredentialsIdentity.scala | 8 +- .../DefaultAwsCredentialsIdentity.scala | 23 +++-- .../useragent/BusinessMetricFeatureId.scala | 90 +++++++++---------- .../ldbc/amazon/util/SystemSetting.scala | 10 ++- 8 files changed, 125 insertions(+), 110 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala index 843b16a83..8e0566e37 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala @@ -11,10 +11,10 @@ import java.time.Instant import ldbc.amazon.identity.AwsCredentials final case class AwsBasicCredentials( - accessKeyId: String, - secretAccessKey: String, - validateCredentials: Boolean, - providerName: Option[String], - accountId: Option[String], - expirationTime: Option[Instant], - ) extends AwsCredentials + accessKeyId: String, + secretAccessKey: String, + validateCredentials: Boolean, + providerName: Option[String], + accountId: Option[String], + expirationTime: Option[Instant] +) extends AwsCredentials diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala index ae1f46478..5e9e587cd 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala @@ -11,11 +11,11 @@ import java.time.Instant import ldbc.amazon.identity.AwsCredentials final case class AwsSessionCredentials( - accessKeyId: String, - secretAccessKey: String, - sessionToken: String, - validateCredentials: Boolean, - providerName: Option[String], - accountId: Option[String], - expirationTime: Option[Instant], - ) extends AwsCredentials + accessKeyId: String, + secretAccessKey: String, + sessionToken: String, + validateCredentials: Boolean, + providerName: Option[String], + accountId: Option[String], + expirationTime: Option[Instant] +) extends AwsCredentials diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala index cc22ac989..6e7b67921 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala @@ -7,8 +7,8 @@ package ldbc.amazon.auth.credentials import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider -import ldbc.amazon.util.SdkSystemSetting import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SdkSystemSetting /** * [[AwsCredentialsProvider]] implementation that loads credentials from the aws.accessKeyId, aws.secretAccessKey and @@ -18,6 +18,8 @@ final class SystemPropertyCredentialsProvider extends SystemSettingsCredentialsP // Customers should be able to specify a credentials provider that only looks at the system properties, // but not the environment variables. For that reason, we're only checking the system properties here. - override def loadSetting(setting: SdkSystemSetting): Option[String] = Option(System.getProperty(setting.systemProperty)) + override def loadSetting(setting: SdkSystemSetting): Option[String] = Option( + System.getProperty(setting.systemProperty) + ) override def provider: String = BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala index fcc7f12fd..8a84d46fe 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala @@ -6,46 +6,58 @@ package ldbc.amazon.auth.credentials.internal +import ldbc.amazon.auth.credentials.* +import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.util.SdkSystemSetting -import ldbc.amazon.exception.SdkClientException -import ldbc.amazon.auth.credentials.* trait SystemSettingsCredentialsProvider extends AwsCredentialsProvider: override def resolveCredentials(): Either[SdkClientException, AwsCredentials] = - val accessKeyOpt = loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map(_.trim) - val secretKeyOpt = loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY).map(_.trim) + val accessKeyOpt = loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map(_.trim) + val secretKeyOpt = loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY).map(_.trim) val sessionTokenOpt = loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN).map(_.trim) - val accountId = loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID).map(_.trim) + val accountId = loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID).map(_.trim) for accessKey <- accessKeyOpt match { - case Some(value) if value.nonEmpty => Right(value) - case _ => Left(new SdkClientException(s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${SdkSystemSetting.AWS_ACCESS_KEY_ID}) or system property (${SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty}).")) - } + case Some(value) if value.nonEmpty => Right(value) + case _ => + Left( + new SdkClientException( + s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${ SdkSystemSetting.AWS_ACCESS_KEY_ID }) or system property (${ SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty })." + ) + ) + } secretKey <- secretKeyOpt match { - case Some(value) if value.isEmpty => Right(value) - case _ => Left(new SdkClientException(s"Unable to load credentials from system settings. Secret key must be specified either via environment variable (${SdkSystemSetting.AWS_SECRET_ACCESS_KEY}) or system property (${SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty}).")) - } + case Some(value) if value.isEmpty => Right(value) + case _ => + Left( + new SdkClientException( + s"Unable to load credentials from system settings. Secret key must be specified either via environment variable (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY }) or system property (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty })." + ) + ) + } yield sessionTokenOpt match { - case None => AwsBasicCredentials( - accessKeyId = accessKey, - secretAccessKey = secretKey, - validateCredentials = false, - providerName = Some(provider), - accountId = accountId, - expirationTime = None, - ) - case Some(sessionToken) => AwsSessionCredentials( - accessKeyId = accessKey, - secretAccessKey = secretKey, - sessionToken = sessionToken, - validateCredentials = false, - providerName = Some(provider), - accountId = accountId, - expirationTime = None, - ) + case None => + AwsBasicCredentials( + accessKeyId = accessKey, + secretAccessKey = secretKey, + validateCredentials = false, + providerName = Some(provider), + accountId = accountId, + expirationTime = None + ) + case Some(sessionToken) => + AwsSessionCredentials( + accessKeyId = accessKey, + secretAccessKey = secretKey, + sessionToken = sessionToken, + validateCredentials = false, + providerName = Some(provider), + accountId = accountId, + expirationTime = None + ) } def loadSetting(setting: SdkSystemSetting): Option[String] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala index 0819fe7df..ecef0fca3 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala @@ -42,9 +42,9 @@ object AwsCredentialsIdentity: * @param secretAccessKey The AWS secret access key, used to authenticate the user interacting with services. */ def create(accessKeyId: String, secretAccessKey: String): AwsCredentialsIdentity = DefaultAwsCredentialsIdentity( - accessKeyId = accessKeyId, + accessKeyId = accessKeyId, secretAccessKey = secretAccessKey, - accountId = None, - expirationTime = None, - providerName = None + accountId = None, + expirationTime = None, + providerName = None ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala index bcdbc3404..b7f981083 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -12,11 +12,11 @@ import java.util.Objects import ldbc.amazon.identity.AwsCredentialsIdentity final case class DefaultAwsCredentialsIdentity( - accessKeyId: String, + accessKeyId: String, secretAccessKey: String, - accountId: Option[String], - expirationTime: Option[Instant], - providerName: Option[String] + accountId: Option[String], + expirationTime: Option[Instant], + providerName: Option[String] ) extends AwsCredentialsIdentity: override def toString: String = @@ -29,16 +29,15 @@ final case class DefaultAwsCredentialsIdentity( builder.result() override def equals(obj: Any): Boolean = - if this == obj then - true - else if obj == null || getClass != obj.getClass then - false - else obj match - case that: AwsCredentialsIdentity => - Objects.equals(accessKeyId, that.accessKeyId) && + if this == obj then true + else if obj == null || getClass != obj.getClass then false + else + obj match + case that: AwsCredentialsIdentity => + Objects.equals(accessKeyId, that.accessKeyId) && Objects.equals(secretAccessKey, that.secretAccessKey) && Objects.equals(accountId, that.accountId) - case _ => false + case _ => false override def hashCode(): Int = var hashCode = 1 diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala index 60ca8e68a..b245509a8 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -13,51 +13,51 @@ package ldbc.amazon.useragent * Unsupported metrics (these will never be added): A,H */ enum BusinessMetricFeatureId(val code: String): - case WAITER extends BusinessMetricFeatureId("B") - case PAGINATOR extends BusinessMetricFeatureId("C") - case RETRY_MODE_LEGACY extends BusinessMetricFeatureId("D") - case RETRY_MODE_STANDARD extends BusinessMetricFeatureId("E") - case RETRY_MODE_ADAPTIVE extends BusinessMetricFeatureId("F") - case S3_TRANSFER extends BusinessMetricFeatureId("G") - case GZIP_REQUEST_COMPRESSION extends BusinessMetricFeatureId("L") - case PROTOCOL_RPC_V2_CBOR extends BusinessMetricFeatureId("M") - case ENDPOINT_OVERRIDE extends BusinessMetricFeatureId("N") - case S3_EXPRESS_BUCKET extends BusinessMetricFeatureId("J") - case ACCOUNT_ID_MODE_PREFERRED extends BusinessMetricFeatureId("P") - case ACCOUNT_ID_MODE_DISABLED extends BusinessMetricFeatureId("Q") - case ACCOUNT_ID_MODE_REQUIRED extends BusinessMetricFeatureId("R") - case SIGV4A_SIGNING extends BusinessMetricFeatureId("S") - case RESOLVED_ACCOUNT_ID extends BusinessMetricFeatureId("T") - case FLEXIBLE_CHECKSUMS_REQ_CRC32 extends BusinessMetricFeatureId("U") - case FLEXIBLE_CHECKSUMS_REQ_CRC32C extends BusinessMetricFeatureId("V") - case FLEXIBLE_CHECKSUMS_REQ_CRC64 extends BusinessMetricFeatureId("W") - case FLEXIBLE_CHECKSUMS_REQ_SHA1 extends BusinessMetricFeatureId("X") - case FLEXIBLE_CHECKSUMS_REQ_SHA256 extends BusinessMetricFeatureId("Y") + case WAITER extends BusinessMetricFeatureId("B") + case PAGINATOR extends BusinessMetricFeatureId("C") + case RETRY_MODE_LEGACY extends BusinessMetricFeatureId("D") + case RETRY_MODE_STANDARD extends BusinessMetricFeatureId("E") + case RETRY_MODE_ADAPTIVE extends BusinessMetricFeatureId("F") + case S3_TRANSFER extends BusinessMetricFeatureId("G") + case GZIP_REQUEST_COMPRESSION extends BusinessMetricFeatureId("L") + case PROTOCOL_RPC_V2_CBOR extends BusinessMetricFeatureId("M") + case ENDPOINT_OVERRIDE extends BusinessMetricFeatureId("N") + case S3_EXPRESS_BUCKET extends BusinessMetricFeatureId("J") + case ACCOUNT_ID_MODE_PREFERRED extends BusinessMetricFeatureId("P") + case ACCOUNT_ID_MODE_DISABLED extends BusinessMetricFeatureId("Q") + case ACCOUNT_ID_MODE_REQUIRED extends BusinessMetricFeatureId("R") + case SIGV4A_SIGNING extends BusinessMetricFeatureId("S") + case RESOLVED_ACCOUNT_ID extends BusinessMetricFeatureId("T") + case FLEXIBLE_CHECKSUMS_REQ_CRC32 extends BusinessMetricFeatureId("U") + case FLEXIBLE_CHECKSUMS_REQ_CRC32C extends BusinessMetricFeatureId("V") + case FLEXIBLE_CHECKSUMS_REQ_CRC64 extends BusinessMetricFeatureId("W") + case FLEXIBLE_CHECKSUMS_REQ_SHA1 extends BusinessMetricFeatureId("X") + case FLEXIBLE_CHECKSUMS_REQ_SHA256 extends BusinessMetricFeatureId("Y") case FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED extends BusinessMetricFeatureId("Z") - case FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED extends BusinessMetricFeatureId("a") + case FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED extends BusinessMetricFeatureId("a") case FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED extends BusinessMetricFeatureId("b") - case FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED extends BusinessMetricFeatureId("c") - case DDB_MAPPER extends BusinessMetricFeatureId("d") - case BEARER_SERVICE_ENV_VARS extends BusinessMetricFeatureId("3") - case CREDENTIALS_CODE extends BusinessMetricFeatureId("e") - case CREDENTIALS_JVM_SYSTEM_PROPERTIES extends BusinessMetricFeatureId("f") - case CREDENTIALS_ENV_VARS extends BusinessMetricFeatureId("g") + case FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED extends BusinessMetricFeatureId("c") + case DDB_MAPPER extends BusinessMetricFeatureId("d") + case BEARER_SERVICE_ENV_VARS extends BusinessMetricFeatureId("3") + case CREDENTIALS_CODE extends BusinessMetricFeatureId("e") + case CREDENTIALS_JVM_SYSTEM_PROPERTIES extends BusinessMetricFeatureId("f") + case CREDENTIALS_ENV_VARS extends BusinessMetricFeatureId("g") case CREDENTIALS_ENV_VARS_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("h") - case CREDENTIALS_STS_ASSUME_ROLE extends BusinessMetricFeatureId("i") - case CREDENTIALS_STS_ASSUME_ROLE_SAML extends BusinessMetricFeatureId("j") - case CREDENTIALS_STS_ASSUME_ROLE_WEB_ID extends BusinessMetricFeatureId("k") - case CREDENTIALS_STS_FEDERATION_TOKEN extends BusinessMetricFeatureId("l") - case CREDENTIALS_STS_SESSION_TOKEN extends BusinessMetricFeatureId("m") - case CREDENTIALS_PROFILE extends BusinessMetricFeatureId("n") - case CREDENTIALS_PROFILE_SOURCE_PROFILE extends BusinessMetricFeatureId("o") - case CREDENTIALS_PROFILE_NAMED_PROVIDER extends BusinessMetricFeatureId("p") - case CREDENTIALS_PROFILE_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("q") - case CREDENTIALS_PROFILE_SSO extends BusinessMetricFeatureId("r") - case CREDENTIALS_SSO extends BusinessMetricFeatureId("s") - case CREDENTIALS_PROFILE_SSO_LEGACY extends BusinessMetricFeatureId("t") - case CREDENTIALS_SSO_LEGACY extends BusinessMetricFeatureId("u") - case CREDENTIALS_PROFILE_PROCESS extends BusinessMetricFeatureId("v") - case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") - case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") - case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") - case UNKNOWN extends BusinessMetricFeatureId("Unknown") + case CREDENTIALS_STS_ASSUME_ROLE extends BusinessMetricFeatureId("i") + case CREDENTIALS_STS_ASSUME_ROLE_SAML extends BusinessMetricFeatureId("j") + case CREDENTIALS_STS_ASSUME_ROLE_WEB_ID extends BusinessMetricFeatureId("k") + case CREDENTIALS_STS_FEDERATION_TOKEN extends BusinessMetricFeatureId("l") + case CREDENTIALS_STS_SESSION_TOKEN extends BusinessMetricFeatureId("m") + case CREDENTIALS_PROFILE extends BusinessMetricFeatureId("n") + case CREDENTIALS_PROFILE_SOURCE_PROFILE extends BusinessMetricFeatureId("o") + case CREDENTIALS_PROFILE_NAMED_PROVIDER extends BusinessMetricFeatureId("p") + case CREDENTIALS_PROFILE_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("q") + case CREDENTIALS_PROFILE_SSO extends BusinessMetricFeatureId("r") + case CREDENTIALS_SSO extends BusinessMetricFeatureId("s") + case CREDENTIALS_PROFILE_SSO_LEGACY extends BusinessMetricFeatureId("t") + case CREDENTIALS_SSO_LEGACY extends BusinessMetricFeatureId("u") + case CREDENTIALS_PROFILE_PROCESS extends BusinessMetricFeatureId("v") + case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") + case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") + case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") + case UNKNOWN extends BusinessMetricFeatureId("Unknown") diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala index 0d86f554c..58e35d88d 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala @@ -69,9 +69,11 @@ enum SdkSystemSetting(val systemProperty: String, val defaultValue: Option[Strin * This allows a service running in EC2 to automatically load its credentials and region without needing to configure them * in the SdkClientBuilder. */ - case AWS_EC2_METADATA_SERVICE_ENDPOINT extends SdkSystemSetting("aws.ec2MetadataServiceEndpoint", Some("http://169.254.169.254")) + case AWS_EC2_METADATA_SERVICE_ENDPOINT + extends SdkSystemSetting("aws.ec2MetadataServiceEndpoint", Some("http://169.254.169.254")) - case AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE extends SdkSystemSetting("aws.ec2MetadataServiceEndpointMode", Some("IPv4")) + case AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE + extends SdkSystemSetting("aws.ec2MetadataServiceEndpointMode", Some("IPv4")) /** * The number of seconds (either as an integer or double) before a connection to the instance @@ -92,7 +94,8 @@ enum SdkSystemSetting(val systemProperty: String, val defaultValue: Option[Strin * * This is not used if the [[AWS_CONTAINER_CREDENTIALS_RELATIVE_URI]] is not specified. */ - case AWS_CONTAINER_SERVICE_ENDPOINT extends SdkSystemSetting("aws.containerServiceEndpoint", Some("http://169.254.170.2")) + case AWS_CONTAINER_SERVICE_ENDPOINT + extends SdkSystemSetting("aws.containerServiceEndpoint", Some("http://169.254.170.2")) /** * The elastic container metadata service path that should be called by the ContainerCredentialsProvider when @@ -226,7 +229,6 @@ enum SdkSystemSetting(val systemProperty: String, val defaultValue: Option[Strin */ case AWS_SIGV4A_SIGNING_REGION_SET extends SdkSystemSetting("aws.sigv4a.signing.region.set", None) - /** * Configure the preferred auth scheme to use. * This is a comma-delimited list of AWS auth scheme names used during signing. From f18c5124fa290124b63b8d16aa9041d791049831 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 16:59:36 +0900 Subject: [PATCH 061/215] Create MysqlClearPasswordPlugin --- .../MysqlClearPasswordPlugin.scala | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala new file mode 100644 index 000000000..ebdfe5caf --- /dev/null +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala @@ -0,0 +1,23 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.connector.authenticator + +import java.nio.charset.StandardCharsets + +import scodec.bits.ByteVector + +import cats.effect.kernel.Sync + +class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: + + override def name: String = "mysql_clear_password" + override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = + if password.isEmpty then Sync[F].pure(ByteVector.empty) + else Sync[F].delay(ByteVector(password.getBytes(StandardCharsets.UTF_8))) + +object MysqlClearPasswordPlugin: + def apply[F[_]: Sync](): MysqlClearPasswordPlugin[F] = new MysqlClearPasswordPlugin[F]() From d8f3296135baec5f248aa9a979c1c8e23de186bc Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 17:01:41 +0900 Subject: [PATCH 062/215] Added defaultAuthenticationPlugin property --- .../scala/ldbc/connector/Connection.scala | 16 +++- .../scala/ldbc/connector/net/Protocol.scala | 23 +++-- .../scala/ldbc/connector/ConnectionTest.scala | 88 +++++++++++++++++++ 3 files changed, 116 insertions(+), 11 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index f3eb6089f..492ccfcab 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -28,6 +28,7 @@ import ldbc.connector.data.* import ldbc.connector.exception.* import ldbc.connector.net.* import ldbc.connector.net.protocol.* +import ldbc.connector.authenticator.AuthenticationPlugin type Connection[F[_]] = ldbc.sql.Connection[F] object Connection: @@ -75,7 +76,8 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG) + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default[F, Unit]( host, port, @@ -90,6 +92,7 @@ object Connection: useCursorFetch, useServerPrepStmts, databaseTerm, + defaultAuthenticationPlugin, unitBefore, unitAfter ) @@ -109,7 +112,8 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG) + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default( host, port, @@ -124,6 +128,7 @@ object Connection: useCursorFetch, useServerPrepStmts, databaseTerm, + defaultAuthenticationPlugin, before, after ) @@ -142,6 +147,7 @@ object Connection: useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, before: Connection[F] => F[A], after: (A, Connection[F]) => F[Unit] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = @@ -165,6 +171,7 @@ object Connection: useCursorFetch, useServerPrepStmts, databaseTerm, + defaultAuthenticationPlugin, before, after ) @@ -184,6 +191,7 @@ object Connection: useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], acquire: Connection[F] => F[A], release: (A, Connection[F]) => F[Unit] ): Resource[F, LdbcConnection[F]] = @@ -194,7 +202,7 @@ object Connection: for given Exchange[F] <- Resource.eval(Exchange[F]) protocol <- - Protocol[F](sockets, hostInfo, debug, sslOptions, allowPublicKeyRetrieval, readTimeout, capabilityFlags) + Protocol[F](sockets, hostInfo, debug, sslOptions, allowPublicKeyRetrieval, readTimeout, capabilityFlags, defaultAuthenticationPlugin) _ <- Resource.eval(protocol.startAuthentication(user, password.getOrElse(""))) serverVariables <- Resource.eval(protocol.serverVariables()) readOnly <- Resource.eval(Ref[F].of[Boolean](false)) @@ -234,6 +242,7 @@ object Connection: useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], acquire: Connection[F] => F[A], release: (A, Connection[F]) => F[Unit] )(using ev: Async[F]): Resource[F, LdbcConnection[F]] = @@ -262,6 +271,7 @@ object Connection: useCursorFetch, useServerPrepStmts, databaseTerm, + defaultAuthenticationPlugin, acquire, release ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index 71141a5eb..a2e4c346b 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -145,7 +145,8 @@ object Protocol: useSSL: Boolean = false, allowPublicKeyRetrieval: Boolean = false, capabilityFlags: Set[CapabilitiesFlags], - sequenceIdRef: Ref[F, Byte] + sequenceIdRef: Ref[F, Byte], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], )(using ev: MonadError[F, Throwable], ex: Exchange[F]) extends Protocol[F]: @@ -503,9 +504,11 @@ object Protocol: Attribute("username", username) ))* ) *> ( - determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match - case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) - case Right(plugin) => handshake(plugin, username, password) *> readUntilOk(plugin, password) + defaultAuthenticationPlugin match + case Some(plugin) => handshake(plugin, username, password) *> readUntilOk(plugin, password) + case None => determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match + case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) + case Right(plugin) => handshake(plugin, username, password) *> readUntilOk(plugin, password) ) } @@ -539,7 +542,8 @@ object Protocol: sslOptions: Option[SSLNegotiation.Options[F]], allowPublicKeyRetrieval: Boolean = false, readTimeout: Duration, - capabilitiesFlags: Set[CapabilitiesFlags] + capabilitiesFlags: Set[CapabilitiesFlags], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], ): Resource[F, Protocol[F]] = for sequenceIdRef <- Resource.eval(Ref[F].of[Byte](0x01)) @@ -554,7 +558,8 @@ object Protocol: allowPublicKeyRetrieval, capabilitiesFlags, sequenceIdRef, - initialPacketRef + initialPacketRef, + defaultAuthenticationPlugin ) ) yield protocol @@ -566,7 +571,8 @@ object Protocol: allowPublicKeyRetrieval: Boolean = false, capabilitiesFlags: Set[CapabilitiesFlags], sequenceIdRef: Ref[F, Byte], - initialPacketRef: Ref[F, Option[InitialPacket]] + initialPacketRef: Ref[F, Option[InitialPacket]], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], )(using ev: Async[F]): F[Protocol[F]] = initialPacketRef.get.flatMap { case Some(initialPacket) => @@ -578,7 +584,8 @@ object Protocol: sslOptions.isDefined, allowPublicKeyRetrieval, capabilitiesFlags, - sequenceIdRef + sequenceIdRef, + defaultAuthenticationPlugin ) ) case None => diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index 29ef0837b..517c175b7 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -17,6 +17,21 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData import ldbc.connector.exception.* +import ldbc.connector.authenticator.MysqlClearPasswordPlugin + +class ClearTest extends FTestPlatform: + given Tracer[IO] = Tracer.noop[IO] + + test("A user using mysql_clear_password can establish a connection with the MySQL server.") { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_native_user", + password = Some("ldbc_mysql_native_password"), + defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()) + ) + assertIOBoolean(connection.use(_ => IO(true))) + } class ConnectionTest extends FTestPlatform: @@ -267,6 +282,79 @@ class ConnectionTest extends FTestPlatform: assertIOBoolean(connection.use(_ => IO(true))) } + test("A user using mysql_clear_password can establish a connection with the MySQL server.") { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_clear_user", + password = Some("ldbc_mysql_clear_password") + ) + assertIOBoolean(connection.use(_ => IO(true))) + } + + test( + "Connections to MySQL servers using users with mysql_clear_password will succeed if allowPublicKeyRetrieval is enabled for non-SSL connections." + ) { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_clear_user", + password = Some("ldbc_mysql_clear_password"), + allowPublicKeyRetrieval = true + ) + assertIOBoolean(connection.use(_ => IO(true))) + } + + test("Connections to MySQL servers using users with mysql_clear_password will succeed for SSL connections.") { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_clear_user", + password = Some("ldbc_mysql_clear_password"), + ssl = SSL.Trusted + ) + assertIOBoolean(connection.use(_ => IO(true))) + } + + test("Users using mysql_clear_password can establish a connection with the MySQL server by specifying database.") { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_clear_user", + password = Some("ldbc_mysql_clear_password"), + database = Some("connector_test") + ) + assertIOBoolean(connection.use(_ => IO(true))) + } + + test( + "If allowPublicKeyRetrieval is enabled for non-SSL connections, a connection to a MySQL server specifying a database using a user with mysql_clear_password will succeed." + ) { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_clear_user", + password = Some("ldbc_mysql_clear_password"), + database = Some("connector_test"), + allowPublicKeyRetrieval = true + ) + assertIOBoolean(connection.use(_ => IO(true))) + } + + test( + "A connection to a MySQL server with a database specified using a user with mysql_clear_password will succeed with an SSL connection." + ) { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_clear_user", + password = Some("ldbc_mysql_clear_password"), + database = Some("connector_test"), + ssl = SSL.Trusted + ) + assertIOBoolean(connection.use(_ => IO(true))) + } + test("Catalog change will change the currently connected Catalog.") { val connection = Connection[IO]( host = "127.0.0.1", From 158e8ade3481256c667a97d1345fe3012094e656 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 17:11:16 +0900 Subject: [PATCH 063/215] Added defaultAuthenticationPlugin property and set function --- .../ldbc/connector/MySQLDataSource.scala | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index c7d41c83f..328628e09 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -20,6 +20,7 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData import ldbc.connector.pool.* +import ldbc.connector.authenticator.AuthenticationPlugin import ldbc.DataSource @@ -47,6 +48,7 @@ import ldbc.DataSource * @param tracer optional OpenTelemetry tracer for distributed tracing * @param useCursorFetch whether to use cursor-based fetching for result sets * @param useServerPrepStmts whether to use server-side prepared statements + * @param defaultAuthenticationPlugin The authentication plugin used first for communication with the server * @param before optional hook to execute before a connection is acquired * @param after optional hook to execute after a connection is used * @@ -81,6 +83,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen tracer: Option[Tracer[F]] = None, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ) extends DataSource[F]: @@ -113,7 +116,8 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + defaultAuthenticationPlugin = defaultAuthenticationPlugin ) case (Some(b), None) => Connection.withBeforeAfter( @@ -131,7 +135,8 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + defaultAuthenticationPlugin = defaultAuthenticationPlugin ) case (None, _) => Connection( @@ -147,7 +152,8 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + defaultAuthenticationPlugin = defaultAuthenticationPlugin ) /** Sets the hostname or IP address of the MySQL server. @@ -245,6 +251,14 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen def setUseServerPrepStmts(newUseServerPrepStmts: Boolean): MySQLDataSource[F, A] = copy(useServerPrepStmts = newUseServerPrepStmts) + /** Sets whether to authentication plugin to be used first for communication with the server. + * @param defaultAuthenticationPlugin + * The authentication plugin used first for communication with the server + * @return a new MySQLDataSource with the updated setting + */ + def setDefaultAuthenticationPlugin(defaultAuthenticationPlugin: AuthenticationPlugin[F]): MySQLDataSource[F, A] = + copy(defaultAuthenticationPlugin = Some(defaultAuthenticationPlugin)) + /** * Adds a before hook that will be executed when a connection is acquired. * From 3ed7cdab48959d68e4969acb26a715ba8acb1a45 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 17:13:21 +0900 Subject: [PATCH 064/215] Added defaultAuthenticationPlugin test --- .../scala/ldbc/connector/ConnectionTest.scala | 84 +------------------ 1 file changed, 4 insertions(+), 80 deletions(-) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index 517c175b7..9478da516 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -19,20 +19,6 @@ import ldbc.sql.DatabaseMetaData import ldbc.connector.exception.* import ldbc.connector.authenticator.MysqlClearPasswordPlugin -class ClearTest extends FTestPlatform: - given Tracer[IO] = Tracer.noop[IO] - - test("A user using mysql_clear_password can establish a connection with the MySQL server.") { - val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_native_user", - password = Some("ldbc_mysql_native_password"), - defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()) - ) - assertIOBoolean(connection.use(_ => IO(true))) - } - class ConnectionTest extends FTestPlatform: given Tracer[IO] = Tracer.noop[IO] @@ -282,75 +268,13 @@ class ConnectionTest extends FTestPlatform: assertIOBoolean(connection.use(_ => IO(true))) } - test("A user using mysql_clear_password can establish a connection with the MySQL server.") { - val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_clear_user", - password = Some("ldbc_mysql_clear_password") - ) - assertIOBoolean(connection.use(_ => IO(true))) - } - - test( - "Connections to MySQL servers using users with mysql_clear_password will succeed if allowPublicKeyRetrieval is enabled for non-SSL connections." - ) { - val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_clear_user", - password = Some("ldbc_mysql_clear_password"), - allowPublicKeyRetrieval = true - ) - assertIOBoolean(connection.use(_ => IO(true))) - } - - test("Connections to MySQL servers using users with mysql_clear_password will succeed for SSL connections.") { - val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_clear_user", - password = Some("ldbc_mysql_clear_password"), - ssl = SSL.Trusted - ) - assertIOBoolean(connection.use(_ => IO(true))) - } - - test("Users using mysql_clear_password can establish a connection with the MySQL server by specifying database.") { - val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_clear_user", - password = Some("ldbc_mysql_clear_password"), - database = Some("connector_test") - ) - assertIOBoolean(connection.use(_ => IO(true))) - } - - test( - "If allowPublicKeyRetrieval is enabled for non-SSL connections, a connection to a MySQL server specifying a database using a user with mysql_clear_password will succeed." - ) { - val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_clear_user", - password = Some("ldbc_mysql_clear_password"), - database = Some("connector_test"), - allowPublicKeyRetrieval = true - ) - assertIOBoolean(connection.use(_ => IO(true))) - } - - test( - "A connection to a MySQL server with a database specified using a user with mysql_clear_password will succeed with an SSL connection." - ) { + test("You can connect to the database by specifying the default authentication plugin.") { val connection = Connection[IO]( host = "127.0.0.1", port = 13306, - user = "ldbc_mysql_clear_user", - password = Some("ldbc_mysql_clear_password"), - database = Some("connector_test"), - ssl = SSL.Trusted + user = "ldbc_mysql_native_user", + password = Some("ldbc_mysql_native_password"), + defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()) ) assertIOBoolean(connection.use(_ => IO(true))) } From 5d3f5a1da382f5a9df8c0abb6debb89273ca01d0 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 17:44:16 +0900 Subject: [PATCH 065/215] Added scaladoc --- .../authenticator/AuthenticationPlugin.scala | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala index 8a90704e3..459b7b830 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala @@ -8,8 +8,57 @@ package ldbc.connector.authenticator import scodec.bits.ByteVector +/** + * A trait representing a MySQL authentication plugin for database connections. + * + * This trait defines the contract for various authentication mechanisms supported by MySQL, + * including traditional password-based authentication (mysql_native_password) and + * modern authentication methods like mysql_clear_password for IAM authentication. + * + * Authentication plugins are used during the MySQL handshake process to validate + * client credentials and establish secure database connections. + * + * @tparam F The effect type that wraps the authentication operations + */ trait AuthenticationPlugin[F[_]]: + /** + * The name of the authentication plugin as recognized by the MySQL server. + * + * Common plugin names include: + * - "mysql_native_password" for traditional SHA1-based password authentication + * - "mysql_clear_password" for plaintext password transmission over SSL + * - "caching_sha2_password" for SHA256-based password authentication + * - "mysql_old_password" for legacy MySQL authentication (deprecated) + * + * @return The plugin name string that identifies this authentication method + */ def name: String + /** + * Indicates whether this authentication plugin requires a secure (encrypted) connection. + * + * Some authentication plugins, particularly those that transmit passwords in cleartext + * (like mysql_clear_password), require SSL/TLS encryption to ensure data security. + * Traditional hashing-based plugins may optionally use encryption but don't strictly require it. + * + * @return true if SSL/TLS connection is mandatory for this plugin, false otherwise + */ + def requiresConfidentiality: Boolean + + /** + * Processes the password according to the authentication plugin's requirements. + * + * Different authentication plugins handle passwords differently: + * - mysql_native_password: Performs SHA1-based hashing with the server's scramble + * - mysql_clear_password: Returns the password as plaintext bytes (requires SSL) + * - caching_sha2_password: Performs SHA256-based hashing with salt + * + * @param password The user's password in plaintext + * @param scramble The random challenge bytes sent by the MySQL server during handshake. + * Used as salt/seed for cryptographic hashing in most authentication methods. + * May be ignored by plugins that don't use server-side challenges. + * @return The processed password data wrapped in the effect type F, ready for transmission + * to the MySQL server during authentication + */ def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] From 741c41a53352512727a1715aba7517052f289f01 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 17:44:34 +0900 Subject: [PATCH 066/215] Added checkRequiresConfidentiality --- .../MysqlClearPasswordPlugin.scala | 1 + .../MysqlNativePasswordPlugin.scala | 1 + .../authenticator/Sha256PasswordPlugin.scala | 1 + .../scala/ldbc/connector/net/Protocol.scala | 17 +++++++++++++++-- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala index ebdfe5caf..251a0582b 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala @@ -15,6 +15,7 @@ import cats.effect.kernel.Sync class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: override def name: String = "mysql_clear_password" + override def requiresConfidentiality: Boolean = true override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else Sync[F].delay(ByteVector(password.getBytes(StandardCharsets.UTF_8))) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala index e695ba6ce..5ba643362 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala @@ -22,6 +22,7 @@ import fs2.Chunk class MysqlNativePasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F]: override def name: String = "mysql_native_password" + override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala index 784a1a493..f138725e7 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala @@ -24,6 +24,7 @@ trait Sha256PasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F] (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray override def name: String = "sha256_password" + override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index a2e4c346b..5f899df93 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -505,10 +505,10 @@ object Protocol: ))* ) *> ( defaultAuthenticationPlugin match - case Some(plugin) => handshake(plugin, username, password) *> readUntilOk(plugin, password) + case Some(plugin) => checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk(plugin, password) case None => determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) - case Right(plugin) => handshake(plugin, username, password) *> readUntilOk(plugin, password) + case Right(plugin) => checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk(plugin, password) ) } @@ -535,6 +535,19 @@ object Protocol: ) } + private def checkRequiresConfidentiality(plugin: AuthenticationPlugin[F], span: Span[F]): F[Unit] = + if plugin.requiresConfidentiality && !useSSL then + val error = new SQLInvalidAuthorizationSpecException( + s"SSL connection required for plugin “${plugin.name}”. Check if ‘ssl’ is enabled.", + hint = Some( + """// You can enable SSL. + | MySQLDataSource.build[IO](....).setSSL(SSL.Trusted) + |""".stripMargin + ) + ) + span.recordException(error) *> ev.raiseError(error) + else ev.unit + def apply[F[_]: Async: Console: Tracer: Exchange: Hashing]( sockets: Resource[F, Socket[F]], hostInfo: HostInfo, From 5660641b5b53b4f972589dbce3397c623b21168c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 17:44:47 +0900 Subject: [PATCH 067/215] Added SQLInvalidAuthorizationSpecException check test --- .../scala/ldbc/connector/ConnectionTest.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index 9478da516..5eeb510b6 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -274,11 +274,24 @@ class ConnectionTest extends FTestPlatform: port = 13306, user = "ldbc_mysql_native_user", password = Some("ldbc_mysql_native_password"), - defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()) + defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()), + ssl = SSL.Trusted ) assertIOBoolean(connection.use(_ => IO(true))) } + test("Using the MySQL Clear Password Plugin when SSL is not enabled causes an SQLInvalidAuthorizationSpecException to occur.") { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_native_user", + password = Some("ldbc_mysql_native_password"), + defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()) + ) + interceptIO[SQLInvalidAuthorizationSpecException](connection.use(_ => IO(true))) + } + + test("Catalog change will change the currently connected Catalog.") { val connection = Connection[IO]( host = "127.0.0.1", From 2f720bcdfdf650793835b0065e8cac22fc73f061 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 17:45:56 +0900 Subject: [PATCH 068/215] Action sbt scalafmtAll --- .../scala/ldbc/connector/Connection.scala | 167 +++++++++--------- .../ldbc/connector/MySQLDataSource.scala | 122 ++++++------- .../MysqlClearPasswordPlugin.scala | 2 +- .../MysqlNativePasswordPlugin.scala | 2 +- .../authenticator/Sha256PasswordPlugin.scala | 2 +- .../scala/ldbc/connector/net/Protocol.scala | 67 ++++--- .../scala/ldbc/connector/ConnectionTest.scala | 25 +-- 7 files changed, 203 insertions(+), 184 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index 492ccfcab..440420e2c 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -24,11 +24,11 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData +import ldbc.connector.authenticator.AuthenticationPlugin import ldbc.connector.data.* import ldbc.connector.exception.* import ldbc.connector.net.* import ldbc.connector.net.protocol.* -import ldbc.connector.authenticator.AuthenticationPlugin type Connection[F[_]] = ldbc.sql.Connection[F] object Connection: @@ -64,20 +64,20 @@ object Connection: this.default[F, Unit](host, port, user, before = unitBefore, after = unitAfter) def apply[F[_]: Async: Network: Console: Hashing: UUIDGen]( - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default[F, Unit]( host, port, @@ -98,22 +98,22 @@ object Connection: ) def withBeforeAfter[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - before: Connection[F] => F[A], - after: (A, Connection[F]) => F[Unit], - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + host: String, + port: Int, + user: String, + before: Connection[F] => F[A], + after: (A, Connection[F]) => F[Unit], + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default( host, port, @@ -134,22 +134,22 @@ object Connection: ) def default[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, - before: Connection[F] => F[A], - after: (A, Connection[F]) => F[Unit] + before: Connection[F] => F[A], + after: (A, Connection[F]) => F[Unit] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = val logger: String => F[Unit] = s => Console[F].println(s"TLS: $s") @@ -171,29 +171,29 @@ object Connection: useCursorFetch, useServerPrepStmts, databaseTerm, - defaultAuthenticationPlugin, + defaultAuthenticationPlugin, before, after ) yield connection def fromSockets[F[_]: Async: Tracer: Console: Hashing: UUIDGen, A]( - sockets: Resource[F, Socket[F]], - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - sslOptions: Option[SSLNegotiation.Options[F]], - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, + sockets: Resource[F, Socket[F]], + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + sslOptions: Option[SSLNegotiation.Options[F]], + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], - acquire: Connection[F] => F[A], - release: (A, Connection[F]) => F[Unit] + acquire: Connection[F] => F[A], + release: (A, Connection[F]) => F[Unit] ): Resource[F, LdbcConnection[F]] = val capabilityFlags = defaultCapabilityFlags ++ (if database.isDefined then Set(CapabilitiesFlags.CLIENT_CONNECT_WITH_DB) else Set.empty) ++ @@ -202,7 +202,16 @@ object Connection: for given Exchange[F] <- Resource.eval(Exchange[F]) protocol <- - Protocol[F](sockets, hostInfo, debug, sslOptions, allowPublicKeyRetrieval, readTimeout, capabilityFlags, defaultAuthenticationPlugin) + Protocol[F]( + sockets, + hostInfo, + debug, + sslOptions, + allowPublicKeyRetrieval, + readTimeout, + capabilityFlags, + defaultAuthenticationPlugin + ) _ <- Resource.eval(protocol.startAuthentication(user, password.getOrElse(""))) serverVariables <- Resource.eval(protocol.serverVariables()) readOnly <- Resource.eval(Ref[F].of[Boolean](false)) @@ -228,23 +237,23 @@ object Connection: yield connection def fromSocketGroup[F[_]: Tracer: Console: Hashing: UUIDGen, A]( - socketGroup: SocketGroup[F], - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - socketOptions: List[SocketOption], - sslOptions: Option[SSLNegotiation.Options[F]], - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, + socketGroup: SocketGroup[F], + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + socketOptions: List[SocketOption], + sslOptions: Option[SSLNegotiation.Options[F]], + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], - acquire: Connection[F] => F[A], - release: (A, Connection[F]) => F[Unit] + acquire: Connection[F] => F[A], + release: (A, Connection[F]) => F[Unit] )(using ev: Async[F]): Resource[F, LdbcConnection[F]] = def fail[B](msg: String): Resource[F, B] = diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 328628e09..ec3f5cfea 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -19,8 +19,8 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.connector.pool.* import ldbc.connector.authenticator.AuthenticationPlugin +import ldbc.connector.pool.* import ldbc.DataSource @@ -69,23 +69,23 @@ import ldbc.DataSource * }}} */ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = MySQLConfig.defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - tracer: Option[Tracer[F]] = None, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, - before: Option[Connection[F] => F[A]] = None, - after: Option[(A, Connection[F]) => F[Unit]] = None + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = MySQLConfig.defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + tracer: Option[Tracer[F]] = None, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + before: Option[Connection[F] => F[A]] = None, + after: Option[(A, Connection[F]) => F[Unit]] = None ) extends DataSource[F]: given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) @@ -102,57 +102,57 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen (before, after) match case (Some(b), Some(a)) => Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = a, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm, + host = host, + port = port, + user = user, + before = b, + after = a, + password = password, + database = database, + debug = debug, + ssl = ssl, + socketOptions = socketOptions, + readTimeout = readTimeout, + allowPublicKeyRetrieval = allowPublicKeyRetrieval, + useCursorFetch = useCursorFetch, + useServerPrepStmts = useServerPrepStmts, + databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin ) case (Some(b), None) => Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = (_, _) => Async[F].unit, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm, + host = host, + port = port, + user = user, + before = b, + after = (_, _) => Async[F].unit, + password = password, + database = database, + debug = debug, + ssl = ssl, + socketOptions = socketOptions, + readTimeout = readTimeout, + allowPublicKeyRetrieval = allowPublicKeyRetrieval, + useCursorFetch = useCursorFetch, + useServerPrepStmts = useServerPrepStmts, + databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin ) case (None, _) => Connection( - host = host, - port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm, + host = host, + port = port, + user = user, + password = password, + database = database, + debug = debug, + ssl = ssl, + socketOptions = socketOptions, + readTimeout = readTimeout, + allowPublicKeyRetrieval = allowPublicKeyRetrieval, + useCursorFetch = useCursorFetch, + useServerPrepStmts = useServerPrepStmts, + databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala index 251a0582b..8735ab59a 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala @@ -15,7 +15,7 @@ import cats.effect.kernel.Sync class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: override def name: String = "mysql_clear_password" - override def requiresConfidentiality: Boolean = true + override def requiresConfidentiality: Boolean = true override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else Sync[F].delay(ByteVector(password.getBytes(StandardCharsets.UTF_8))) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala index 5ba643362..d9c7dfb92 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala @@ -22,7 +22,7 @@ import fs2.Chunk class MysqlNativePasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F]: override def name: String = "mysql_native_password" - override def requiresConfidentiality: Boolean = false + override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala index f138725e7..7c77bf4bb 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala @@ -24,7 +24,7 @@ trait Sha256PasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F] (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray override def name: String = "sha256_password" - override def requiresConfidentiality: Boolean = false + override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index 5f899df93..fe5a2103b 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -139,14 +139,14 @@ object Protocol: "SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout" private[ldbc] case class Impl[F[_]: Async: Tracer: Hashing]( - initialPacket: InitialPacket, - hostInfo: HostInfo, - socket: PacketSocket[F], - useSSL: Boolean = false, - allowPublicKeyRetrieval: Boolean = false, - capabilityFlags: Set[CapabilitiesFlags], - sequenceIdRef: Ref[F, Byte], - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + initialPacket: InitialPacket, + hostInfo: HostInfo, + socket: PacketSocket[F], + useSSL: Boolean = false, + allowPublicKeyRetrieval: Boolean = false, + capabilityFlags: Set[CapabilitiesFlags], + sequenceIdRef: Ref[F, Byte], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] )(using ev: MonadError[F, Throwable], ex: Exchange[F]) extends Protocol[F]: @@ -505,10 +505,19 @@ object Protocol: ))* ) *> ( defaultAuthenticationPlugin match - case Some(plugin) => checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk(plugin, password) - case None => determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match - case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) - case Right(plugin) => checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk(plugin, password) + case Some(plugin) => + checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk( + plugin, + password + ) + case None => + determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match + case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) + case Right(plugin) => + checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk( + plugin, + password + ) ) } @@ -538,7 +547,7 @@ object Protocol: private def checkRequiresConfidentiality(plugin: AuthenticationPlugin[F], span: Span[F]): F[Unit] = if plugin.requiresConfidentiality && !useSSL then val error = new SQLInvalidAuthorizationSpecException( - s"SSL connection required for plugin “${plugin.name}”. Check if ‘ssl’ is enabled.", + s"SSL connection required for plugin “${ plugin.name }”. Check if ‘ssl’ is enabled.", hint = Some( """// You can enable SSL. | MySQLDataSource.build[IO](....).setSSL(SSL.Trusted) @@ -549,14 +558,14 @@ object Protocol: else ev.unit def apply[F[_]: Async: Console: Tracer: Exchange: Hashing]( - sockets: Resource[F, Socket[F]], - hostInfo: HostInfo, - debug: Boolean, - sslOptions: Option[SSLNegotiation.Options[F]], - allowPublicKeyRetrieval: Boolean = false, - readTimeout: Duration, - capabilitiesFlags: Set[CapabilitiesFlags], - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + sockets: Resource[F, Socket[F]], + hostInfo: HostInfo, + debug: Boolean, + sslOptions: Option[SSLNegotiation.Options[F]], + allowPublicKeyRetrieval: Boolean = false, + readTimeout: Duration, + capabilitiesFlags: Set[CapabilitiesFlags], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] ): Resource[F, Protocol[F]] = for sequenceIdRef <- Resource.eval(Ref[F].of[Byte](0x01)) @@ -578,14 +587,14 @@ object Protocol: yield protocol def fromPacketSocket[F[_]: Tracer: Exchange: Hashing]( - packetSocket: PacketSocket[F], - hostInfo: HostInfo, - sslOptions: Option[SSLNegotiation.Options[F]], - allowPublicKeyRetrieval: Boolean = false, - capabilitiesFlags: Set[CapabilitiesFlags], - sequenceIdRef: Ref[F, Byte], - initialPacketRef: Ref[F, Option[InitialPacket]], - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + packetSocket: PacketSocket[F], + hostInfo: HostInfo, + sslOptions: Option[SSLNegotiation.Options[F]], + allowPublicKeyRetrieval: Boolean = false, + capabilitiesFlags: Set[CapabilitiesFlags], + sequenceIdRef: Ref[F, Byte], + initialPacketRef: Ref[F, Option[InitialPacket]], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] )(using ev: Async[F]): F[Protocol[F]] = initialPacketRef.get.flatMap { case Some(initialPacket) => diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index 5eeb510b6..f32668f5c 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -16,8 +16,8 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.connector.exception.* import ldbc.connector.authenticator.MysqlClearPasswordPlugin +import ldbc.connector.exception.* class ConnectionTest extends FTestPlatform: @@ -270,28 +270,29 @@ class ConnectionTest extends FTestPlatform: test("You can connect to the database by specifying the default authentication plugin.") { val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_native_user", - password = Some("ldbc_mysql_native_password"), + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_native_user", + password = Some("ldbc_mysql_native_password"), defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()), - ssl = SSL.Trusted + ssl = SSL.Trusted ) assertIOBoolean(connection.use(_ => IO(true))) } - test("Using the MySQL Clear Password Plugin when SSL is not enabled causes an SQLInvalidAuthorizationSpecException to occur.") { + test( + "Using the MySQL Clear Password Plugin when SSL is not enabled causes an SQLInvalidAuthorizationSpecException to occur." + ) { val connection = Connection[IO]( - host = "127.0.0.1", - port = 13306, - user = "ldbc_mysql_native_user", - password = Some("ldbc_mysql_native_password"), + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_native_user", + password = Some("ldbc_mysql_native_password"), defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()) ) interceptIO[SQLInvalidAuthorizationSpecException](connection.use(_ => IO(true))) } - test("Catalog change will change the currently connected Catalog.") { val connection = Connection[IO]( host = "127.0.0.1", From 3a6716a6ec7443161944fd971cb72765be751ea4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 20:15:49 +0900 Subject: [PATCH 069/215] Added mysql_clear_password --- .../main/scala/ldbc/connector/net/protocol/Authentication.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala index 85cdcb723..022dfc7ab 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala @@ -40,6 +40,7 @@ trait Authentication[F[_]: Hashing: Sync]: */ protected def determinatePlugin(pluginName: String, version: Version): Either[SQLException, AuthenticationPlugin[F]] = pluginName match + case "mysql_clear_password" => Right(MysqlClearPasswordPlugin[F]()) case "mysql_native_password" => Right(MysqlNativePasswordPlugin[F]()) case "sha256_password" => Right(Sha256PasswordPlugin[F]()) case "caching_sha2_password" => Right(CachingSha2PasswordPlugin[F](version)) From 5490a27a14a4550fbbd7a84ff08611471532e7b0 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 20:16:47 +0900 Subject: [PATCH 070/215] Action sbt scalafmtAll --- .../main/scala/ldbc/connector/net/protocol/Authentication.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala index 022dfc7ab..8df586eb7 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala @@ -40,7 +40,7 @@ trait Authentication[F[_]: Hashing: Sync]: */ protected def determinatePlugin(pluginName: String, version: Version): Either[SQLException, AuthenticationPlugin[F]] = pluginName match - case "mysql_clear_password" => Right(MysqlClearPasswordPlugin[F]()) + case "mysql_clear_password" => Right(MysqlClearPasswordPlugin[F]()) case "mysql_native_password" => Right(MysqlNativePasswordPlugin[F]()) case "sha256_password" => Right(Sha256PasswordPlugin[F]()) case "caching_sha2_password" => Right(CachingSha2PasswordPlugin[F](version)) From 25528018b80ec0257294007500c22e79f7799534 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 20:52:31 +0900 Subject: [PATCH 071/215] Added clear paswword document --- docs/src/main/mdoc/en/reference/Connector.md | 8 ++++++++ docs/src/main/mdoc/en/tutorial/Connection.md | 2 ++ docs/src/main/mdoc/ja/reference/Connector.md | 8 ++++++++ docs/src/main/mdoc/ja/tutorial/Connection.md | 2 ++ 4 files changed, 20 insertions(+) diff --git a/docs/src/main/mdoc/en/reference/Connector.md b/docs/src/main/mdoc/en/reference/Connector.md index e5f721618..9d4d958a9 100644 --- a/docs/src/main/mdoc/en/reference/Connector.md +++ b/docs/src/main/mdoc/en/reference/Connector.md @@ -115,9 +115,17 @@ ldbc currently supports the following authentication plugins: - Native Pluggable Authentication - SHA-256 Pluggable Authentication - SHA-2 Pluggable Authentication Cache +- Cleartext Pluggable Authentication *Note: Native Pluggable Authentication and SHA-256 Pluggable Authentication are deprecated plugins from MySQL 8.x. It is recommended to use the SHA-2 Pluggable Authentication Cache unless there is a specific reason.* +@:callout(warning) +**Important Security Precautions:** +MySQL cleartext pluggable authentication is an authentication method that sends passwords in plain text to the server. When using this authentication plugin, you must enable SSL/TLS connections. Using it without SSL/TLS connections is extremely dangerous from a security standpoint, as passwords are transmitted over the network in plain text. + +This authentication plugin is primarily used for integration with AWS IAM authentication and other external authentication systems. For instructions on configuring SSL/TLS connections, refer to the [SSL configuration section of the tutorial](/en/tutorial/Connection.md#connection-with-ssl-configuration). +@:@ + You do not need to be aware of authentication plugins in the ldbc application code. Users should create users with the desired authentication plugin on the MySQL database and use those users to attempt connections to MySQL in the ldbc application code. ldbc internally determines the authentication plugin and connects to MySQL using the appropriate authentication plugin. diff --git a/docs/src/main/mdoc/en/tutorial/Connection.md b/docs/src/main/mdoc/en/tutorial/Connection.md index 5796aec32..0d3321429 100644 --- a/docs/src/main/mdoc/en/tutorial/Connection.md +++ b/docs/src/main/mdoc/en/tutorial/Connection.md @@ -148,6 +148,8 @@ You can add SSL configuration to establish a secure connection: ※ Note that Trusted accepts all certificates. This is a setting for development environments. +※ For security reasons, SSL/TLS connections are required for certain authentication plugins, such as MySQL cleartext pluggable authentication. For details, see the [authentication section of the reference](/en/reference/Connector.md#authentication). + ```scala import cats.effect.IO import ldbc.connector.* diff --git a/docs/src/main/mdoc/ja/reference/Connector.md b/docs/src/main/mdoc/ja/reference/Connector.md index 510463307..c15a01464 100644 --- a/docs/src/main/mdoc/ja/reference/Connector.md +++ b/docs/src/main/mdoc/ja/reference/Connector.md @@ -115,9 +115,17 @@ ldbcは現時点で以下の認証プラグインをサポートしています - ネイティブプラガブル認証 - SHA-256 プラガブル認証 - SHA-2 プラガブル認証のキャッシュ +- クリアテキストプラガブル認証 ※ ネイティブプラガブル認証とSHA-256 プラガブル認証はMySQL 8.xから非推奨となったプラグインです。特段理由がない場合はSHA-2 プラガブル認証のキャッシュを使用することを推奨します。 +@:callout(warning) +**重要なセキュリティ注意事項:** +MySQL クリアテキストプラガブル認証は、パスワードを平文でサーバーに送信する認証方式です。この認証プラグインを使用する場合は、必ずSSL/TLS接続を有効にしてください。SSL/TLS接続なしでの使用は、パスワードがネットワーク上で平文として送信されるため、セキュリティ上非常に危険です。 + +この認証プラグインは主にAWS IAM認証やその他の外部認証システムとの統合で使用されます。SSL/TLS接続の設定方法については、[チュートリアルのSSL設定セクション](/ja/tutorial/Connection.md#ssl設定を使用したコネクション)を参照してください。 +@:@ + ldbcのアプリケーションコード上で認証プラグインを意識する必要はありません。ユーザーはMySQLのデータベース上で使用したい認証プラグインで作成されたユーザーを作成し、ldbcのアプリケーションコード上ではそのユーザーを使用してMySQLへの接続を試みるだけで問題ありません。 ldbcが内部で認証プラグインを判断し、適切な認証プラグインを使用してMySQLへの接続を行います。 diff --git a/docs/src/main/mdoc/ja/tutorial/Connection.md b/docs/src/main/mdoc/ja/tutorial/Connection.md index ab15892ae..9c3903b9a 100644 --- a/docs/src/main/mdoc/ja/tutorial/Connection.md +++ b/docs/src/main/mdoc/ja/tutorial/Connection.md @@ -146,6 +146,8 @@ val program = datasource.getConnection.use { connection => ※ Trustedは全ての証明書を受け入れることに注意してください。これは開発環境向けの設定です。 +※ MySQL クリアテキストプラガブル認証などの一部の認証プラグインでは、セキュリティ上の理由からSSL/TLS接続が必須です。詳細は[リファレンスの認証セクション](/ja/reference/Connector.md#認証)を参照してください。 + ```scala import cats.effect.IO import ldbc.connector.* From 86f04bfba0f1f3838e14b285662ad6432c598ddc Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 21:59:42 +0900 Subject: [PATCH 072/215] Added fs2 dependencies --- build.sbt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/build.sbt b/build.sbt index 4898cf6fb..6204b6daa 100644 --- a/build.sbt +++ b/build.sbt @@ -146,6 +146,13 @@ lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) .module("aws-authentication-plugin", "") + .settings( + libraryDependencies ++= Seq( + "co.fs2" %%% "fs2-core" % "3.12.2", + "co.fs2" %%% "fs2-io" % "3.12.2", + "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test + ) + ) .jsSettings( Test / scalaJSLinkerConfig ~= (_.withModuleKind(ModuleKind.CommonJSModule)) ) From def09dbd148b6a4658e7ed8e1b06cab6cc77d008 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 22:00:05 +0900 Subject: [PATCH 073/215] Create AuthTokenGenerator --- .../auth/token/AuthTokenGenerator.scala | 6 + .../auth/token/RdsIamAuthTokenGenerator.scala | 126 ++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala new file mode 100644 index 000000000..4c966b88b --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala @@ -0,0 +1,6 @@ +package ldbc.amazon.auth.token + +import ldbc.amazon.identity.AwsCredentials + +trait AuthTokenGenerator[F[_]]: + def generateToken(credentials: AwsCredentials): F[String] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala new file mode 100644 index 000000000..4bd0d54c2 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -0,0 +1,126 @@ +package ldbc.amazon.auth.token + +import java.net.URLEncoder +import java.nio.charset.StandardCharsets +import java.time.format.DateTimeFormatter +import java.time.{Instant, ZoneOffset} +import javax.crypto.Mac +import javax.crypto.spec.SecretKeySpec + +import cats.syntax.all.* + +import cats.effect.kernel.{Clock, Sync} + +import fs2.hashing.{HashAlgorithm, Hashing} +import fs2.{Chunk, Stream} + +import ldbc.amazon.identity.AwsCredentials + +class RdsIamAuthTokenGenerator[F[_]: Hashing]( + hostname: String, + port: Int, + username: String, + region: String +)(using clock: Clock[F]) extends AuthTokenGenerator[F]: + + private val ALGORITHM = "AWS4-HMAC-SHA256" + private val SERVICE = "rds-db" + private val EXPIRES_SECONDS = 900 + private val TERMINATOR = "aws4_request" + + override def generateToken(credentials: AwsCredentials): F[String] = + for + now <- clock.realTimeInstant + dateTime = formatDateTime(now) + date = dateTime.substring(0, 8) + credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" + credential = s"${credentials.accessKeyId}/$credentialScope" + queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) + canonicalRequest = buildCanonicalRequest(s"$hostname:$port", queryParams) + canonicalRequestHash <- sha256Hex(canonicalRequest) + stringToSign = buildStringToSign(dateTime, credentialScope, canonicalRequestHash) + signature <- calculateSignature(credentials.secretAccessKey, date, region, stringToSign) + yield s"${config.hostname}:${config.port}/?$queryParams&X-Amz-Signature=$signature" + + private def formatDateTime(instant: Instant): String = + DateTimeFormatter + .ofPattern("yyyyMMdd'T'HHmmss'Z'") + .withZone(ZoneOffset.UTC) + .format(instant) + + private def buildQueryParams( + credential: String, + dateTime: String, + sessionToken: String, + username: String + ): String = + val params = List( + "Action" -> "connect", + "DBUser" -> username, + "X-Amz-Algorithm" -> ALGORITHM, + "X-Amz-Credential" -> credential, + "X-Amz-Date" -> dateTime, + "X-Amz-Expires" -> EXPIRESSECONDS.toString, + "X-Amz-Security-Token" -> sessionToken, + "X-Amz-SignedHeaders" -> "host" + ) + params + .sortBy(_._1) + .map { case (k, v) => s"${urlEncode(k)}=${urlEncode(v)}" } + .mkString("&") + + private def buildCanonicalRequest(host: String, queryString: String): String = + val method = "GET" + val canonicalUri = "/" + val canonicalHeaders = s"host:$host\n" + val signedHeaders = "host" + val payloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string + s"$method\n$canonicalUri\n$queryString\n$canonicalHeaders\n$signedHeaders\n$payloadHash" + + private def bytesToHex(bytes: Array[Byte]): String = + bytes.map("%02x".format(_)).mkString + + private def sha256Hex(data: String): F[String] = + Stream + .chunk(Chunk.array(data.getBytes(StandardCharsets.UTF_8))) + .through(Hashing[F].hash(HashAlgorithm.SHA256)) + .compile + .lastOrError + .map(hash => bytesToHex(hash.bytes.toArray)) + + private def buildStringToSign( + dateTime: String, + credentialScope: String, + canonicalRequestHash: String + ): String = + s"$Algorithm\n$dateTime\n$credentialScope\n$canonicalRequestHash" + + private def hmacSha256(key: Array[Byte], data: String): F[Array[Byte]] = + Hashing[F] + .hmac(HashAlgorithm.SHA256, Chunk.array(key)) + .use { hmac => + for + _ <- hmac.update(Chunk.array(data.getBytes(StandardCharsets.UTF_8))) + hash <- hmac.hash + yield hash.bytes.toArray + } + + private def calculateSignature( + secretKey: String, + date: String, + region: String, + stringToSign: String + ): F[String] = + for + kDate <- hmacSha256(s"AWS4$secretKey".getBytes(StandardCharsets.UTF_8), date) + kRegion <- hmacSha256(kDate, region) + kService <- hmacSha256(kRegion, Service) + kSigning <- hmacSha256(kService, Terminator) + sig <- hmacSha256(kSigning, stringToSign) + yield bytesToHex(sig) + + private def urlEncode(value: String): String = + URLEncoder.encode(value, "UTF-8") + .replace("+", "%20") + .replace("*", "%2A") + .replace("%7E", "~") From f51802de3301145f8bccb296c7199b0d1f244492 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 29 Nov 2025 22:58:01 +0900 Subject: [PATCH 074/215] Added scaladoc comment --- .../auth/token/AuthTokenGenerator.scala | 24 ++++++++ .../auth/token/RdsIamAuthTokenGenerator.scala | 57 +++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala index 4c966b88b..a550b9f91 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala @@ -2,5 +2,29 @@ package ldbc.amazon.auth.token import ldbc.amazon.identity.AwsCredentials +/** + * A trait for generating authentication tokens using AWS credentials. + * + * This trait defines the contract for creating authentication tokens that can be used + * for various AWS services that support IAM-based authentication. The generated tokens + * provide temporary access based on IAM credentials and policies, eliminating the need + * to store long-term credentials in application code. + * + * @tparam F The effect type that wraps the token generation operations + */ trait AuthTokenGenerator[F[_]]: + + /** + * Generates an authentication token using the provided AWS credentials. + * + * The generated token provides temporary access to AWS services that support + * IAM-based authentication. The token is typically valid for a limited time + * period and provides access based on the IAM permissions associated with + * the provided credentials. + * + * @param credentials The AWS credentials containing access key ID, secret access key, + * and optional session token for temporary credentials + * @return The authentication token wrapped in the effect type F, ready to be used + * for service authentication + */ def generateToken(credentials: AwsCredentials): F[String] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 4bd0d54c2..5982f4f92 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -16,6 +16,46 @@ import fs2.{Chunk, Stream} import ldbc.amazon.identity.AwsCredentials +/** + * Implementation of AuthTokenGenerator for Amazon RDS IAM database authentication. + * + * This class generates authentication tokens that can be used to connect to Amazon RDS + * database instances using IAM credentials instead of traditional database passwords. + * The generated tokens are signed using AWS Signature Version 4 and are valid for + * 15 minutes from the time of generation. + * + * The tokens enable secure, temporary access to RDS databases based on IAM policies + * and eliminate the need to store database passwords in application code. + * + * @param hostname The RDS instance hostname or endpoint + * @param port The database port number (typically 3306 for MySQL) + * @param username The database username for which to generate the token + * @param region The AWS region where the RDS instance is located + * @param clock Clock instance for timestamp generation + * @tparam F The effect type that wraps the token generation operations + * + * @example + * {{{ + * import cats.effect.IO + * import ldbc.amazon.auth.token.RdsIamAuthTokenGenerator + * import ldbc.amazon.identity.AwsCredentials + * + * val generator = new RdsIamAuthTokenGenerator[IO]( + * hostname = "my-db-instance.region.rds.amazonaws.com", + * port = 3306, + * username = "db_user", + * region = "us-east-1" + * ) + * + * val credentials = AwsCredentials( + * accessKeyId = "AKIAIOSFODNN7EXAMPLE", + * secretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + * sessionToken = None + * ) + * + * val token: IO[String] = generator.generateToken(credentials) + * }}} + */ class RdsIamAuthTokenGenerator[F[_]: Hashing]( hostname: String, port: Int, @@ -23,11 +63,28 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( region: String )(using clock: Clock[F]) extends AuthTokenGenerator[F]: + /** AWS Signature Version 4 algorithm identifier */ private val ALGORITHM = "AWS4-HMAC-SHA256" + /** AWS service identifier for RDS database connections */ private val SERVICE = "rds-db" + /** Token expiration time in seconds (15 minutes) */ private val EXPIRES_SECONDS = 900 + /** AWS4 request terminator string */ private val TERMINATOR = "aws4_request" + /** + * Generates a signed authentication token for RDS IAM database authentication. + * + * This method creates a presigned URL-style token that can be used as a password + * when connecting to an RDS database instance. The token is signed using AWS + * Signature Version 4 and includes the necessary IAM credentials and metadata. + * + * The generated token follows the format required by MySQL's mysql_clear_password + * authentication plugin and can be used with SSL/TLS connections to RDS instances. + * + * @param credentials AWS credentials containing access key, secret key, and optional session token + * @return The signed authentication token that can be used as a database password + */ override def generateToken(credentials: AwsCredentials): F[String] = for now <- clock.realTimeInstant From db1822c27155846a4c0bdf45e09c05540822eb56 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 30 Nov 2025 03:02:39 +0900 Subject: [PATCH 075/215] Added scaladoc comment --- .../auth/token/RdsIamAuthTokenGenerator.scala | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 5982f4f92..72bc518e2 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -99,12 +99,34 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( signature <- calculateSignature(credentials.secretAccessKey, date, region, stringToSign) yield s"${config.hostname}:${config.port}/?$queryParams&X-Amz-Signature=$signature" + /** + * Formats an Instant to ISO 8601 basic format string required by AWS Signature Version 4. + * + * Converts a timestamp to the format "yyyyMMddTHHmmssZ" in UTC timezone, + * which is required for AWS authentication requests. + * + * @param instant The timestamp to format + * @return Formatted datetime string in AWS SigV4 format (e.g., "20230101T120000Z") + */ private def formatDateTime(instant: Instant): String = DateTimeFormatter .ofPattern("yyyyMMdd'T'HHmmss'Z'") .withZone(ZoneOffset.UTC) .format(instant) + /** + * Builds the query parameters for the RDS authentication request. + * + * Creates a URL-encoded query string containing all the required AWS Signature Version 4 + * parameters for RDS IAM authentication. The parameters are sorted alphabetically + * as required by the AWS signing process. + * + * @param credential The AWS credential string in format "accessKeyId/scope" + * @param dateTime The formatted timestamp for the request + * @param sessionToken The AWS session token (for temporary credentials) + * @param username The database username to authenticate as + * @return URL-encoded query string with all required authentication parameters + */ private def buildQueryParams( credential: String, dateTime: String, @@ -126,6 +148,17 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( .map { case (k, v) => s"${urlEncode(k)}=${urlEncode(v)}" } .mkString("&") + /** + * Constructs the canonical request string according to AWS Signature Version 4 specification. + * + * The canonical request is a standardized representation of the HTTP request that will be + * used for signature calculation. It includes the HTTP method, URI path, query string, + * headers, signed headers list, and payload hash. + * + * @param host The database host including port (e.g., "host.rds.amazonaws.com:3306") + * @param queryString The URL-encoded query parameters + * @return The canonical request string formatted according to AWS SigV4 requirements + */ private def buildCanonicalRequest(host: String, queryString: String): String = val method = "GET" val canonicalUri = "/" @@ -134,9 +167,28 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( val payloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string s"$method\n$canonicalUri\n$queryString\n$canonicalHeaders\n$signedHeaders\n$payloadHash" + /** + * Converts a byte array to lowercase hexadecimal string representation. + * + * Used for converting hash outputs to the hex format required by AWS signatures. + * Each byte is formatted as a two-character lowercase hex string. + * + * @param bytes The byte array to convert + * @return Lowercase hexadecimal string representation + */ private def bytesToHex(bytes: Array[Byte]): String = bytes.map("%02x".format(_)).mkString + /** + * Computes the SHA-256 hash of a string and returns it as a lowercase hex string. + * + * Uses fs2's streaming hash functionality to compute the SHA-256 hash of the input + * string encoded as UTF-8 bytes. The result is converted to lowercase hexadecimal + * format as required by AWS Signature Version 4. + * + * @param data The string to hash + * @return The SHA-256 hash as a lowercase hexadecimal string + */ private def sha256Hex(data: String): F[String] = Stream .chunk(Chunk.array(data.getBytes(StandardCharsets.UTF_8))) @@ -145,6 +197,18 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( .lastOrError .map(hash => bytesToHex(hash.bytes.toArray)) + /** + * Creates the "String to Sign" for AWS Signature Version 4. + * + * The string to sign is a formatted string that combines the algorithm identifier, + * timestamp, credential scope, and canonical request hash. This string will be + * signed using the AWS signing key to produce the final signature. + * + * @param dateTime The ISO 8601 formatted timestamp + * @param credentialScope The credential scope (date/region/service/terminator) + * @param canonicalRequestHash The SHA-256 hash of the canonical request + * @return The formatted string to be signed + */ private def buildStringToSign( dateTime: String, credentialScope: String, @@ -152,6 +216,17 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( ): String = s"$Algorithm\n$dateTime\n$credentialScope\n$canonicalRequestHash" + /** + * Computes HMAC-SHA256 hash using the provided key and data. + * + * Uses fs2's streaming HMAC functionality to compute the HMAC-SHA256 of the input + * data using the provided key. This is a core cryptographic operation used in + * AWS Signature Version 4 key derivation and final signature calculation. + * + * @param key The cryptographic key as byte array + * @param data The data to authenticate as string (will be UTF-8 encoded) + * @return The HMAC-SHA256 result as byte array + */ private def hmacSha256(key: Array[Byte], data: String): F[Array[Byte]] = Hashing[F] .hmac(HashAlgorithm.SHA256, Chunk.array(key)) @@ -162,6 +237,26 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( yield hash.bytes.toArray } + /** + * Derives the AWS Signature Version 4 signing key and calculates the final signature. + * + * Implements the AWS SigV4 key derivation algorithm by computing a series of + * HMAC-SHA256 operations to derive the signing key, then signs the string-to-sign + * with that key. The process follows AWS specifications for signature calculation. + * + * Key derivation steps: + * 1. kDate = HMAC-SHA256("AWS4" + SecretKey, Date) + * 2. kRegion = HMAC-SHA256(kDate, Region) + * 3. kService = HMAC-SHA256(kRegion, Service) + * 4. kSigning = HMAC-SHA256(kService, "aws4_request") + * 5. signature = HMAC-SHA256(kSigning, StringToSign) + * + * @param secretKey The AWS secret access key + * @param date The date in YYYYMMDD format + * @param region The AWS region + * @param stringToSign The formatted string to be signed + * @return The final signature as lowercase hexadecimal string + */ private def calculateSignature( secretKey: String, date: String, @@ -176,6 +271,19 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( sig <- hmacSha256(kSigning, stringToSign) yield bytesToHex(sig) + /** + * URL encodes a string according to AWS requirements. + * + * Performs URL encoding with specific character replacements required by AWS: + * - Spaces are encoded as %20 (not +) + * - Asterisks (*) are encoded as %2A + * - Tildes (~) remain unencoded + * + * This encoding is required for query parameter values in AWS authentication. + * + * @param value The string to URL encode + * @return The URL encoded string with AWS-specific character handling + */ private def urlEncode(value: String): String = URLEncoder.encode(value, "UTF-8") .replace("+", "%20") From 16d64a2dee390849ea5219b55c55a9654e3e48e1 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 30 Nov 2025 03:03:21 +0900 Subject: [PATCH 076/215] Action sbt scalafmtAll --- .../auth/token/RdsIamAuthTokenGenerator.scala | 96 ++++++++++--------- 1 file changed, 51 insertions(+), 45 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 72bc518e2..2393de422 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -2,17 +2,18 @@ package ldbc.amazon.auth.token import java.net.URLEncoder import java.nio.charset.StandardCharsets +import java.time.{ Instant, ZoneOffset } import java.time.format.DateTimeFormatter -import java.time.{Instant, ZoneOffset} -import javax.crypto.Mac + import javax.crypto.spec.SecretKeySpec +import javax.crypto.Mac import cats.syntax.all.* -import cats.effect.kernel.{Clock, Sync} +import cats.effect.kernel.{ Clock, Sync } -import fs2.hashing.{HashAlgorithm, Hashing} -import fs2.{Chunk, Stream} +import fs2.{ Chunk, Stream } +import fs2.hashing.{ HashAlgorithm, Hashing } import ldbc.amazon.identity.AwsCredentials @@ -58,17 +59,21 @@ import ldbc.amazon.identity.AwsCredentials */ class RdsIamAuthTokenGenerator[F[_]: Hashing]( hostname: String, - port: Int, + port: Int, username: String, - region: String -)(using clock: Clock[F]) extends AuthTokenGenerator[F]: + region: String +)(using clock: Clock[F]) + extends AuthTokenGenerator[F]: /** AWS Signature Version 4 algorithm identifier */ private val ALGORITHM = "AWS4-HMAC-SHA256" + /** AWS service identifier for RDS database connections */ private val SERVICE = "rds-db" + /** Token expiration time in seconds (15 minutes) */ private val EXPIRES_SECONDS = 900 + /** AWS4 request terminator string */ private val TERMINATOR = "aws4_request" @@ -88,16 +93,16 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( override def generateToken(credentials: AwsCredentials): F[String] = for now <- clock.realTimeInstant - dateTime = formatDateTime(now) - date = dateTime.substring(0, 8) - credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" - credential = s"${credentials.accessKeyId}/$credentialScope" - queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) + dateTime = formatDateTime(now) + date = dateTime.substring(0, 8) + credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" + credential = s"${ credentials.accessKeyId }/$credentialScope" + queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) canonicalRequest = buildCanonicalRequest(s"$hostname:$port", queryParams) canonicalRequestHash <- sha256Hex(canonicalRequest) stringToSign = buildStringToSign(dateTime, credentialScope, canonicalRequestHash) signature <- calculateSignature(credentials.secretAccessKey, date, region, stringToSign) - yield s"${config.hostname}:${config.port}/?$queryParams&X-Amz-Signature=$signature" + yield s"${ config.hostname }:${ config.port }/?$queryParams&X-Amz-Signature=$signature" /** * Formats an Instant to ISO 8601 basic format string required by AWS Signature Version 4. @@ -128,24 +133,24 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( * @return URL-encoded query string with all required authentication parameters */ private def buildQueryParams( - credential: String, - dateTime: String, - sessionToken: String, - username: String - ): String = + credential: String, + dateTime: String, + sessionToken: String, + username: String + ): String = val params = List( - "Action" -> "connect", - "DBUser" -> username, - "X-Amz-Algorithm" -> ALGORITHM, - "X-Amz-Credential" -> credential, - "X-Amz-Date" -> dateTime, - "X-Amz-Expires" -> EXPIRESSECONDS.toString, + "Action" -> "connect", + "DBUser" -> username, + "X-Amz-Algorithm" -> ALGORITHM, + "X-Amz-Credential" -> credential, + "X-Amz-Date" -> dateTime, + "X-Amz-Expires" -> EXPIRESSECONDS.toString, "X-Amz-Security-Token" -> sessionToken, - "X-Amz-SignedHeaders" -> "host" + "X-Amz-SignedHeaders" -> "host" ) params .sortBy(_._1) - .map { case (k, v) => s"${urlEncode(k)}=${urlEncode(v)}" } + .map { case (k, v) => s"${ urlEncode(k) }=${ urlEncode(v) }" } .mkString("&") /** @@ -160,11 +165,11 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( * @return The canonical request string formatted according to AWS SigV4 requirements */ private def buildCanonicalRequest(host: String, queryString: String): String = - val method = "GET" - val canonicalUri = "/" + val method = "GET" + val canonicalUri = "/" val canonicalHeaders = s"host:$host\n" - val signedHeaders = "host" - val payloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string + val signedHeaders = "host" + val payloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string s"$method\n$canonicalUri\n$queryString\n$canonicalHeaders\n$signedHeaders\n$payloadHash" /** @@ -210,10 +215,10 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( * @return The formatted string to be signed */ private def buildStringToSign( - dateTime: String, - credentialScope: String, - canonicalRequestHash: String - ): String = + dateTime: String, + credentialScope: String, + canonicalRequestHash: String + ): String = s"$Algorithm\n$dateTime\n$credentialScope\n$canonicalRequestHash" /** @@ -232,7 +237,7 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( .hmac(HashAlgorithm.SHA256, Chunk.array(key)) .use { hmac => for - _ <- hmac.update(Chunk.array(data.getBytes(StandardCharsets.UTF_8))) + _ <- hmac.update(Chunk.array(data.getBytes(StandardCharsets.UTF_8))) hash <- hmac.hash yield hash.bytes.toArray } @@ -258,17 +263,17 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( * @return The final signature as lowercase hexadecimal string */ private def calculateSignature( - secretKey: String, - date: String, - region: String, - stringToSign: String - ): F[String] = + secretKey: String, + date: String, + region: String, + stringToSign: String + ): F[String] = for - kDate <- hmacSha256(s"AWS4$secretKey".getBytes(StandardCharsets.UTF_8), date) - kRegion <- hmacSha256(kDate, region) + kDate <- hmacSha256(s"AWS4$secretKey".getBytes(StandardCharsets.UTF_8), date) + kRegion <- hmacSha256(kDate, region) kService <- hmacSha256(kRegion, Service) kSigning <- hmacSha256(kService, Terminator) - sig <- hmacSha256(kSigning, stringToSign) + sig <- hmacSha256(kSigning, stringToSign) yield bytesToHex(sig) /** @@ -285,7 +290,8 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( * @return The URL encoded string with AWS-specific character handling */ private def urlEncode(value: String): String = - URLEncoder.encode(value, "UTF-8") + URLEncoder + .encode(value, "UTF-8") .replace("+", "%20") .replace("*", "%2A") .replace("%7E", "~") From ec05f38c64b162a044f19968771b4906a888e55b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 30 Nov 2025 21:01:26 +0900 Subject: [PATCH 077/215] Change add effect type --- .../SystemPropertyCredentialsProvider.scala | 9 +++-- .../SystemSettingsCredentialsProvider.scala | 38 ++++++------------- .../identity/AwsCredentialsProvider.scala | 6 +-- 3 files changed, 19 insertions(+), 34 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala index 6e7b67921..ea64997ef 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala @@ -6,6 +6,9 @@ package ldbc.amazon.auth.credentials +import cats.MonadThrow +import cats.effect.std.SystemProperties + import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider import ldbc.amazon.useragent.BusinessMetricFeatureId import ldbc.amazon.util.SdkSystemSetting @@ -14,12 +17,10 @@ import ldbc.amazon.util.SdkSystemSetting * [[AwsCredentialsProvider]] implementation that loads credentials from the aws.accessKeyId, aws.secretAccessKey and * aws.sessionToken system properties. */ -final class SystemPropertyCredentialsProvider extends SystemSettingsCredentialsProvider: +final class SystemPropertyCredentialsProvider[F[_]: SystemProperties: MonadThrow] extends SystemSettingsCredentialsProvider[F]: // Customers should be able to specify a credentials provider that only looks at the system properties, // but not the environment variables. For that reason, we're only checking the system properties here. - override def loadSetting(setting: SdkSystemSetting): Option[String] = Option( - System.getProperty(setting.systemProperty) - ) + override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = SystemProperties[F].get(setting.systemProperty) override def provider: String = BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala index 8a84d46fe..da743c294 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala @@ -6,38 +6,24 @@ package ldbc.amazon.auth.credentials.internal +import cats.MonadThrow +import cats.syntax.all.* + import ldbc.amazon.auth.credentials.* import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.util.SdkSystemSetting -trait SystemSettingsCredentialsProvider extends AwsCredentialsProvider: - - override def resolveCredentials(): Either[SdkClientException, AwsCredentials] = - val accessKeyOpt = loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map(_.trim) - val secretKeyOpt = loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY).map(_.trim) - val sessionTokenOpt = loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN).map(_.trim) - val accountId = loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID).map(_.trim) +trait SystemSettingsCredentialsProvider[F[_]](using ev: MonadThrow[F]) extends AwsCredentialsProvider[F]: + override def resolveCredentials(): F[AwsCredentials] = for - accessKey <- accessKeyOpt match { - case Some(value) if value.nonEmpty => Right(value) - case _ => - Left( - new SdkClientException( - s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${ SdkSystemSetting.AWS_ACCESS_KEY_ID }) or system property (${ SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty })." - ) - ) - } - secretKey <- secretKeyOpt match { - case Some(value) if value.isEmpty => Right(value) - case _ => - Left( - new SdkClientException( - s"Unable to load credentials from system settings. Secret key must be specified either via environment variable (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY }) or system property (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty })." - ) - ) - } + accessKeyOpt <- loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map(_.map(_.trim)) + secretKeyOpt <- loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY).map(_.map(_.trim)) + sessionTokenOpt <- loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN).map(_.map(_.trim)) + accountId <- loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID).map(_.map(_.trim)) + accessKey <- ev.fromOption(accessKeyOpt, new SdkClientException(s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${ SdkSystemSetting.AWS_ACCESS_KEY_ID }) or system property (${ SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty }).")) + secretKey <- ev.fromOption(secretKeyOpt, new SdkClientException(s"Unable to load credentials from system settings. Secret key must be specified either via environment variable (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY }) or system property (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty }).")) yield sessionTokenOpt match { case None => AwsBasicCredentials( @@ -60,6 +46,6 @@ trait SystemSettingsCredentialsProvider extends AwsCredentialsProvider: ) } - def loadSetting(setting: SdkSystemSetting): Option[String] + def loadSetting(setting: SdkSystemSetting): F[Option[String]] def provider: String diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala index 51c278c73..830eea731 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala @@ -6,9 +6,7 @@ package ldbc.amazon.identity -import ldbc.amazon.exception.SdkClientException - -trait AwsCredentialsProvider: +trait AwsCredentialsProvider[F[_]]: /** * Returns [[AwsCredentials]] that can be used to authorize an AWS request. Each implementation of AWSCredentialsProvider @@ -20,4 +18,4 @@ trait AwsCredentialsProvider: * * @return AwsCredentials which the caller can use to authorize an AWS request. */ - def resolveCredentials(): Either[SdkClientException, AwsCredentials] + def resolveCredentials(): F[AwsCredentials] From b1f20a550b66b38407352d3b8766981cc5f86f10 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 30 Nov 2025 21:01:41 +0900 Subject: [PATCH 078/215] Create EnvironmentVariableCredentialsProvider --- ...vironmentVariableCredentialsProvider.scala | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala new file mode 100644 index 000000000..5b4273d64 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala @@ -0,0 +1,26 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.MonadThrow +import cats.effect.std.Env + +import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider +import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SdkSystemSetting + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and + * AWS_SESSION_TOKEN environment variables. + */ +final class EnvironmentVariableCredentialsProvider[F[_]: Env: MonadThrow] extends SystemSettingsCredentialsProvider[F]: + + // Customers should be able to specify a credentials provider that only looks at the system properties, + // but not the environment variables. For that reason, we're only checking the system properties here. + override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = Env[F].get(setting.toString) + + override def provider: String = BusinessMetricFeatureId.CREDENTIALS_ENV_VARS.code From 3045538db87ea38ecf742661ec3bb9c228f09278 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 30 Nov 2025 22:24:51 +0900 Subject: [PATCH 079/215] Create ProfileCredentialsProvider --- .../ProfileCredentialsProvider.scala | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala new file mode 100644 index 000000000..38f29e313 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala @@ -0,0 +1,148 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import cats.MonadThrow + +import cats.effect.* +import cats.effect.std.* +import cats.syntax.all.* + +import fs2.io.file.{Files, Path} + +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* + +import ProfileCredentialsProvider.* + +final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent]( + profileName: String, + cacheRef: Ref[F, Option[(ProfileFile, AwsCredentials)]], + semaphore: Semaphore[F] +)(using ev: MonadThrow[F]) extends AwsCredentialsProvider[F]: + + override def resolveCredentials(): F[AwsCredentials] = + for + currentFile <- loadFile + cached <- cacheRef.get + credentials <- cached match + case Some((cachedFile, creds)) if cachedFile.lastModified == currentFile.lastModified => ev.pure(creds) + case _ => updateCredentials(currentFile) + yield credentials + + private def loadFile: F[ProfileFile] = + for + homeOpt <- SystemProperties[F].get("user.home") + home <- ev.fromOption(homeOpt, new SdkClientException("")) + credentialsPath = Path(s"$home/.aws/credentials") + exists <- Files[F].exists(credentialsPath) + _ <- ev.raiseUnless(exists)(new SdkClientException(s"File not found: $credentialsPath")) + content <- Files[F].readUtf8(credentialsPath).compile.string + lastMod <- Files[F].getLastModifiedTime(credentialsPath) + profiles <- ev.fromEither(parseProfiles(content)) + yield ProfileFile(profiles, Instant.ofEpochMilli(lastMod.toMillis)) + + private def parseProfiles(content: String): Either[Throwable, Map[String, Profile]] = + val profilePattern = "\\[(?:profile\\s+)?(.+)\\]".r + val propertyPattern = "^\\s*([^=]+)\\s*=\\s*(.+)\\s*$".r + + val lines = content.linesIterator.toList + + case class State( + currentProfile: Option[String] = None, + profiles: Map[String, Map[String, String]] = Map.empty + ) + + val finalState = lines.foldLeft(State()): (state, line) => + line.trim match + case _ if line.trim.isEmpty || line.trim.startsWith("#") => + state + case profilePattern(name) => + State(Some(name.trim), state.profiles + (name.trim -> Map.empty)) + case propertyPattern(key, value) => + state.currentProfile match + case Some(profile) => + val updatedProps = state.profiles.getOrElse(profile, Map.empty) + (key.trim -> value.trim) + state.copy(profiles = state.profiles + (profile -> updatedProps)) + case None => state + case _ => state + + Right(finalState.profiles.map: (name, props) => + name -> Profile(name, props) + ) + + private def parseStaticCredentials(props: Map[String, String]): Option[AwsCredentials] = + for + accessKeyId <- props.get("aws_access_key_id") + secretAccessKey <- props.get("aws_secret_access_key") + yield props.get("aws_session_token") match { + case Some(sessionToken) => + AwsSessionCredentials( + accessKeyId = accessKeyId, + secretAccessKey = secretAccessKey, + sessionToken = sessionToken, + validateCredentials = false, + providerName = None, + accountId = None, + expirationTime = None + ) + case None => + AwsBasicCredentials( + accessKeyId = accessKeyId, + secretAccessKey = secretAccessKey, + validateCredentials = false, + providerName = None, + accountId = None, + expirationTime = None + ) + } + + private def extractCredentials(profileFile: ProfileFile): F[AwsCredentials] = + for + profile <- ev.fromOption( + profileFile.profiles.get(profileName), + new SdkClientException("") + ) + credentials <- ev.fromOption( + parseStaticCredentials(profile.properties), + new SdkClientException("") + ) + yield credentials + + private def updateCredentials(profileFile: ProfileFile): F[AwsCredentials] = + semaphore.permit.use { _ => + for + cached <- cacheRef.get + credentials <- cached match + case Some((cachedFile, creds)) if cachedFile.lastModified == profileFile.lastModified => ev.pure(creds) + case _ => + for + creds <- extractCredentials(profileFile) + _ <- cacheRef.set(Some((profileFile, creds))) + yield creds + yield credentials + } + +object ProfileCredentialsProvider: + + final case class Profile( + name: String, + properties: Map[String, String] + ) + + final case class ProfileFile( + profiles: Map[String, Profile], + lastModified: Instant + ) + + def default[F[_]: SystemProperties: Files: Concurrent](profileName: String = "default"): F[ProfileCredentialsProvider[F]] = + for + cacheRef <- Ref.of[F, Option[(ProfileFile, AwsCredentials)]](None) + semaphore <- Semaphore[F](1) + yield new ProfileCredentialsProvider[F](profileName, cacheRef, semaphore) From 038556811e58dee63863c7e487ce20836bb3b8af Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 1 Dec 2025 01:00:21 +0900 Subject: [PATCH 080/215] Create CredentialsFetchError --- .../ldbc/amazon/exception/CredentialsFetchError.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala new file mode 100644 index 000000000..1601deebb --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala @@ -0,0 +1,10 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +class CredentialsFetchError(message: String) extends Exception: + override def getMessage: String = message From 4d835fdb92cbfc9f560aeebfd2e5c66425ae32d3 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 1 Dec 2025 01:00:29 +0900 Subject: [PATCH 081/215] Create HttpClient --- .../scala/ldbc/amazon/client/HttpClient.scala | 13 ++ .../ldbc/amazon/client/HttpResponse.scala | 13 ++ .../ldbc/amazon/client/SimpleHttpClient.scala | 117 ++++++++++++++++++ 3 files changed, 143 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala new file mode 100644 index 000000000..66c0b5709 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala @@ -0,0 +1,13 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +trait HttpClient[F[_]]: + + def get(uri: URI, headers: Map[String, String]): F[HttpResponse] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala new file mode 100644 index 000000000..b22a66ff5 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala @@ -0,0 +1,13 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +final case class HttpResponse( + statusCode: Int, + headers: Map[String, String], + body: String + ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala new file mode 100644 index 000000000..5ce0a7e79 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -0,0 +1,117 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +import scala.concurrent.duration.* + +import cats.MonadThrow +import cats.syntax.all.* + +import cats.effect.* +import cats.effect.syntax.all.* + +import com.comcast.ip4s.* + +import fs2.* +import fs2.io.net.* + +import ldbc.amazon.exception.* + +class SimpleHttpClient[F[_]: Network: Async]( + connectTimeout: Duration, + readTimeout: Duration +)(using ev: MonadThrow[F]) extends HttpClient[F]: + + override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = + val host = uri.getHost + val port = if uri.getPort > 0 then uri.getPort else 80 + val path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + + for + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, path, headers) + yield response + + private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = + for + h <- ev.fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- ev.fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + yield SocketAddress(h, p) + + private def sendRequest( + socket: Socket[F], + host: String, + port: Int, + path: String, + headers: Map[String, String] + ): F[Unit] = + val hostHeader = if port == 80 then host else s"$host:$port" + val allHeaders = headers + ("Host" -> hostHeader) + ("Connection" -> "close") + + val requestLine = s"GET $path HTTP/1.1\r\n" + val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString + val request = requestLine + headerLines + "\r\n" + + Stream.emit(request) + .through(text.utf8.encode) + .through(socket.writes) + .compile + .drain + + private def parseStatusLine(line: String): F[Int] = + // "HTTP/1.1 200 OK" -> 200 + line.split(" ").toList match + case _ :: code :: _ => + code.toIntOption match + case Some(c) => ev.pure(c) + case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => ev.raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + + private def parseHeaderLine(line: String): Option[(String, String)] = + line.split(": ", 2).toList match + case key :: value :: Nil => Some(key -> value) + case _ => None + + private def parseHttpResponse(raw: String): F[HttpResponse] = + val lines = raw.split("\r\n").toList + + lines match + case statusLine :: rest => + parseStatusLine(statusLine).flatMap: statusCode => + val (headerLines, bodyLines) = rest.span(_.nonEmpty) + val headers = headerLines.flatMap(parseHeaderLine).toMap + val body = bodyLines.drop(1).mkString("\r\n") // drop empty line + + ev.pure(HttpResponse(statusCode, headers, body)) + case _ => + ev.raiseError(CredentialsFetchError("Empty response")) + + private def receiveResponse(socket: Socket[F]): F[HttpResponse] = + socket.reads + .through(text.utf8.decode) + .compile + .string + .flatMap(parseHttpResponse) + + private def makeRequest( + address: SocketAddress[Host], + host: String, + port: Int, + path: String, + headers: Map[String, String] + ): F[HttpResponse] = + Network[F].client(address) + .use { socket => + for + _ <- sendRequest(socket, host, port, path, headers) + response <- receiveResponse(socket) + yield response + } + .timeout(connectTimeout + readTimeout) From 41ded8551b71b261942772307199eab01cc5c9bd Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 1 Dec 2025 21:42:39 +0900 Subject: [PATCH 082/215] Change sessionToken to Option --- .../auth/token/AuthTokenGenerator.scala | 6 +++ .../auth/token/RdsIamAuthTokenGenerator.scala | 46 +++++++++++-------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala index a550b9f91..9fea79bb5 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.auth.token import ldbc.amazon.identity.AwsCredentials diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 2393de422..07ff1c3fb 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -1,3 +1,9 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + package ldbc.amazon.auth.token import java.net.URLEncoder @@ -5,9 +11,6 @@ import java.nio.charset.StandardCharsets import java.time.{ Instant, ZoneOffset } import java.time.format.DateTimeFormatter -import javax.crypto.spec.SecretKeySpec -import javax.crypto.Mac - import cats.syntax.all.* import cats.effect.kernel.{ Clock, Sync } @@ -57,7 +60,7 @@ import ldbc.amazon.identity.AwsCredentials * val token: IO[String] = generator.generateToken(credentials) * }}} */ -class RdsIamAuthTokenGenerator[F[_]: Hashing]( +class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( hostname: String, port: Int, username: String, @@ -97,12 +100,13 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( date = dateTime.substring(0, 8) credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" credential = s"${ credentials.accessKeyId }/$credentialScope" - queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) + //queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) + queryParams = buildQueryParams(credential, dateTime, ???, username) canonicalRequest = buildCanonicalRequest(s"$hostname:$port", queryParams) canonicalRequestHash <- sha256Hex(canonicalRequest) stringToSign = buildStringToSign(dateTime, credentialScope, canonicalRequestHash) signature <- calculateSignature(credentials.secretAccessKey, date, region, stringToSign) - yield s"${ config.hostname }:${ config.port }/?$queryParams&X-Amz-Signature=$signature" + yield s"$hostname:$port/?$queryParams&X-Amz-Signature=$signature" /** * Formats an Instant to ISO 8601 basic format string required by AWS Signature Version 4. @@ -135,19 +139,23 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( private def buildQueryParams( credential: String, dateTime: String, - sessionToken: String, + sessionToken: Option[String], username: String ): String = - val params = List( - "Action" -> "connect", - "DBUser" -> username, - "X-Amz-Algorithm" -> ALGORITHM, - "X-Amz-Credential" -> credential, - "X-Amz-Date" -> dateTime, - "X-Amz-Expires" -> EXPIRESSECONDS.toString, - "X-Amz-Security-Token" -> sessionToken, - "X-Amz-SignedHeaders" -> "host" + val baseParams = List( + "Action" -> "connect", + "DBUser" -> username, + "X-Amz-Algorithm" -> ALGORITHM, + "X-Amz-Credential" -> credential, + "X-Amz-Date" -> dateTime, + "X-Amz-Expires" -> EXPIRES_SECONDS.toString, + "X-Amz-SignedHeaders" -> "host" ) + + val params = sessionToken match + case Some(token) => baseParams :+ ("X-Amz-Security-Token" -> token) + case None => baseParams + params .sortBy(_._1) .map { case (k, v) => s"${ urlEncode(k) }=${ urlEncode(v) }" } @@ -219,7 +227,7 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( credentialScope: String, canonicalRequestHash: String ): String = - s"$Algorithm\n$dateTime\n$credentialScope\n$canonicalRequestHash" + s"$ALGORITHM\n$dateTime\n$credentialScope\n$canonicalRequestHash" /** * Computes HMAC-SHA256 hash using the provided key and data. @@ -271,8 +279,8 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing]( for kDate <- hmacSha256(s"AWS4$secretKey".getBytes(StandardCharsets.UTF_8), date) kRegion <- hmacSha256(kDate, region) - kService <- hmacSha256(kRegion, Service) - kSigning <- hmacSha256(kService, Terminator) + kService <- hmacSha256(kRegion, SERVICE) + kSigning <- hmacSha256(kService, TERMINATOR) sig <- hmacSha256(kSigning, stringToSign) yield bytesToHex(sig) From 9185ca9d2251d612ab156c64c3c24ade0f49b914 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 00:53:10 +0900 Subject: [PATCH 083/215] Added scala xml --- build.sbt | 1 + 1 file changed, 1 insertion(+) diff --git a/build.sbt b/build.sbt index 6204b6daa..fa8cdc043 100644 --- a/build.sbt +++ b/build.sbt @@ -148,6 +148,7 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP .module("aws-authentication-plugin", "") .settings( libraryDependencies ++= Seq( + "org.scala-lang.modules" %%% "scala-xml" % "2.2.0", "co.fs2" %%% "fs2-core" % "3.12.2", "co.fs2" %%% "fs2-io" % "3.12.2", "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test From 4b51e46a0249ee049fafe2ac69baa65435bfdb5d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 00:53:22 +0900 Subject: [PATCH 084/215] Create StsClient --- .../scala/ldbc/amazon/client/StsClient.scala | 214 ++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala new file mode 100644 index 000000000..bc448363a --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -0,0 +1,214 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.{URI, URLEncoder} +import java.time.{Instant, ZoneOffset} +import java.time.format.DateTimeFormatter +import java.util.UUID + +import scala.xml.XML + +import cats.MonadThrow +import cats.effect.Concurrent +import cats.syntax.all.* + +import ldbc.amazon.exception.{SdkClientException, StsException} + +/** + * Trait for AWS STS (Security Token Service) client operations. + * + * This client implements the AWS STS AssumeRoleWithWebIdentity operation to exchange + * a Web Identity Token (JWT) for temporary AWS credentials. + */ +trait StsClient[F[_]]: + + /** + * Performs AssumeRoleWithWebIdentity operation. + * + * @param request The STS request parameters + * @param region The AWS region for STS endpoint + * @param httpClient HTTP client for making requests + * @return STS response with temporary credentials + */ + def assumeRoleWithWebIdentity( + request: StsClient.AssumeRoleWithWebIdentityRequest, + region: String, + httpClient: HttpClient[F] + ): F[StsClient.AssumeRoleWithWebIdentityResponse] + +object StsClient: + + /** + * AssumeRoleWithWebIdentity request parameters. + * + * @param roleArn The ARN of the IAM role to assume + * @param webIdentityToken The Web Identity Token (JWT) + * @param roleSessionName Optional session name for the assumed role session + * @param durationSeconds Optional duration of the session (default 3600 seconds) + */ + case class AssumeRoleWithWebIdentityRequest( + roleArn: String, + webIdentityToken: String, + roleSessionName: Option[String] = None, + durationSeconds: Option[Int] = None + ) + + /** + * STS response containing temporary credentials. + * + * @param accessKeyId The temporary access key ID + * @param secretAccessKey The temporary secret access key + * @param sessionToken The temporary session token + * @param expiration When the credentials expire + * @param assumedRoleArn The ARN of the assumed role + */ + case class AssumeRoleWithWebIdentityResponse( + accessKeyId: String, + secretAccessKey: String, + sessionToken: String, + expiration: Instant, + assumedRoleArn: String + ) + + private case class Impl[F[_]: Concurrent]() extends StsClient[F]: + + def assumeRoleWithWebIdentity( + request: AssumeRoleWithWebIdentityRequest, + region: String, + httpClient: HttpClient[F] + ): F[AssumeRoleWithWebIdentityResponse] = + for + timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) + sessionName = request.roleSessionName.getOrElse(s"ldbc-session-${UUID.randomUUID()}") + duration = request.durationSeconds.getOrElse(3600) + + // Build STS request + stsEndpoint = s"https://sts.$region.amazonaws.com/" + requestBody = buildRequestBody(request.copy( + roleSessionName = Some(sessionName), + durationSeconds = Some(duration) + )) + + // Make HTTP request + headers = Map( + "Content-Type" -> "application/x-amz-json-1.0", + "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRoleWithWebIdentity", + "X-Amz-Date" -> timestamp + ) + + response <- httpClient.get(URI.create(stsEndpoint), headers) + _ <- validateHttpResponse(response) + stsResponse <- parseAssumeRoleResponse(response.body) + yield stsResponse + + /** + * Creates a default implementation of StsClient. + * + * @tparam F The effect type + * @return A StsClient instance + */ + def default[F[_]: Concurrent]: StsClient[F] = Impl[F]() + + /** + * Builds the STS request body in AWS Query format. + */ + private def buildRequestBody(request: AssumeRoleWithWebIdentityRequest): String = + val params = Map( + "Action" -> "AssumeRoleWithWebIdentity", + "Version" -> "2011-06-15", + "RoleArn" -> request.roleArn, + "WebIdentityToken" -> request.webIdentityToken, + "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), + "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString + ) + + params.map { case (key, value) => + s"${URLEncoder.encode(key, "UTF-8")}=${URLEncoder.encode(value, "UTF-8")}" + }.mkString("&") + + /** + * Gets current timestamp in AWS format. + */ + private def getCurrentTimestamp(): Either[Throwable, String] = + try { + Right(DateTimeFormatter + .ofPattern("yyyyMMdd'T'HHmmss'Z'") + .withZone(ZoneOffset.UTC) + .format(Instant.now())) + } catch { + case ex: Exception => Left(new SdkClientException("Failed to generate timestamp")) + } + + /** + * Validates HTTP response status. + */ + private def validateHttpResponse[F[_]: MonadThrow](response: HttpResponse): F[Unit] = + if (response.statusCode >= 200 && response.statusCode < 300) { + MonadThrow[F].unit + } else { + MonadThrow[F].raiseError(new StsException( + s"STS request failed with status ${response.statusCode}: ${response.body}" + )) + } + + /** + * Parses STS XML response to extract credentials. + * + * Expected XML structure: + * ```xml + * + * + * + * ASIA... + * ... + * ... + * 2023-12-01T12:00:00Z + * + * + * arn:aws:sts::123456789012:assumed-role/MyRole/MySession + * AROA....:MySession + * + * + * + * ``` + */ + private def parseAssumeRoleResponse[F[_]: MonadThrow](xmlBody: String): F[AssumeRoleWithWebIdentityResponse] = + MonadThrow[F].catchNonFatal { + val xml = XML.loadString(xmlBody) + + // Extract credentials + val credentials = (xml \ "AssumeRoleWithWebIdentityResult" \ "Credentials").head + val accessKeyId = (credentials \ "AccessKeyId").text.trim + val secretAccessKey = (credentials \ "SecretAccessKey").text.trim + val sessionToken = (credentials \ "SessionToken").text.trim + val expirationStr = (credentials \ "Expiration").text.trim + + // Extract assumed role information + val assumedRoleUser = (xml \ "AssumeRoleWithWebIdentityResult" \ "AssumedRoleUser").head + val assumedRoleArn = (assumedRoleUser \ "Arn").text.trim + + // Parse expiration time + val expiration = Instant.parse(expirationStr) + + // Validate required fields + if (accessKeyId.isEmpty) throw new StsException("AccessKeyId not found in STS response") + if (secretAccessKey.isEmpty) throw new StsException("SecretAccessKey not found in STS response") + if (sessionToken.isEmpty) throw new StsException("SessionToken not found in STS response") + if (assumedRoleArn.isEmpty) throw new StsException("AssumedRoleArn not found in STS response") + + AssumeRoleWithWebIdentityResponse( + accessKeyId = accessKeyId, + secretAccessKey = secretAccessKey, + sessionToken = sessionToken, + expiration = expiration, + assumedRoleArn = assumedRoleArn + ) + }.adaptError { case ex => + new StsException(s"Failed to parse STS response: ${ex.getMessage}") + } + From f1ef8a116deaf3b028cec0024fce69ee598a07b7 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 00:53:41 +0900 Subject: [PATCH 085/215] Create exception --- .../exception/InvalidTokenException.scala | 35 ++++++++++++ .../amazon/exception/SdkClientException.scala | 2 +- .../ldbc/amazon/exception/StsException.scala | 46 ++++++++++++++++ .../exception/TokenFileAccessException.scala | 34 ++++++++++++ .../TokenFileNotFoundException.scala | 53 +++++++++++++++++++ .../exception/WebIdentityTokenException.scala | 29 ++++++++++ 6 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala new file mode 100644 index 000000000..3beeff31d --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala @@ -0,0 +1,35 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Thrown when the Web Identity Token is invalid or malformed. + * + * This exception is typically thrown when: + * - The JWT token does not have the correct format (header.payload.signature) + * - The token file is empty or contains only whitespace + * - The token contains invalid characters or encoding + * - The JWT structure is corrupted + * + * Valid JWT token format: + * ``` + * eyJhbGciOiJSUzI1NiIsImtpZCI6IjEyMyJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMS5hbWF6b25hd3MuY29tL2lkLzEyMyIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0Om15LWFwcCJ9.signature + * ``` + * + * @param message The detailed error message + * @param cause The underlying cause of the exception (optional) + */ +class InvalidTokenException( + message: String, + cause: Option[Throwable] = None +) extends WebIdentityTokenException(message, cause): + + /** + * Constructor with cause + */ + def this(message: String, cause: Throwable) = + this(message, Some(cause)) \ No newline at end of file diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala index a38ea56dc..5892e222a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala @@ -8,4 +8,4 @@ package ldbc.amazon.exception class SdkClientException(message: String) extends RuntimeException: - final override def getMessage: String = message + override def getMessage: String = message diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala new file mode 100644 index 000000000..bb4481524 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala @@ -0,0 +1,46 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +import scala.util.control.NoStackTrace + +/** + * Exception thrown when AWS STS (Security Token Service) operations fail. + * + * This exception is typically thrown when: + * - STS AssumeRoleWithWebIdentity request fails (HTTP 400/403/500 errors) + * - Invalid or expired Web Identity Token + * - IAM role doesn't exist or lacks required trust relationships + * - Network connectivity issues to STS endpoints + * - STS response parsing failures + * + * Common STS error scenarios: + * - HTTP 403: Token expired or invalid trust policy + * - HTTP 400: Malformed request or invalid parameters + * - HTTP 500: STS service temporary issues + * + * Example STS endpoint: + * ``` + * https://sts.us-east-1.amazonaws.com/ + * ``` + * + * @param message The detailed error message including STS response details + * @param cause The underlying cause of the exception (optional) + */ +class StsException( + message: String, + cause: Option[Throwable] = None +) extends SdkClientException(message) with NoStackTrace: + + // Set the cause if provided + cause.foreach(initCause) + + /** + * Constructor with cause + */ + def this(message: String, cause: Throwable) = + this(message, Some(cause)) \ No newline at end of file diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala new file mode 100644 index 000000000..54b8c4416 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala @@ -0,0 +1,34 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Thrown when the Web Identity Token file cannot be accessed due to permissions or I/O issues. + * + * This exception is typically thrown when: + * - The file exists but cannot be read due to insufficient permissions + * - I/O errors occur while reading the token file + * - File system issues prevent access to the token file + * + * Common solutions: + * - Verify file permissions (typically 600 for token files) + * - Check if the process has read access to the file + * - Ensure the file system is mounted and accessible + * + * @param message The detailed error message + * @param cause The underlying cause of the exception (optional) + */ +class TokenFileAccessException( + message: String, + cause: Option[Throwable] = None +) extends WebIdentityTokenException(message, cause): + + /** + * Constructor with cause + */ + def this(message: String, cause: Throwable) = + this(message, Some(cause)) \ No newline at end of file diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala new file mode 100644 index 000000000..ea57c510b --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala @@ -0,0 +1,53 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Thrown when the Web Identity Token file cannot be found. + * + * This exception is typically thrown when: + * - The file path specified in AWS_WEB_IDENTITY_TOKEN_FILE does not exist + * - The token file has been moved or deleted + * - The file path is incorrectly configured + * + * Example usage in EKS/IRSA: + * ``` + * AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token + * ``` + * + * @param message The detailed error message + * @param tokenFilePath The path to the missing token file (optional) + * @param cause The underlying cause of the exception (optional) + */ +class TokenFileNotFoundException( + message: String, + tokenFilePath: Option[String] = None, + cause: Option[Throwable] = None +) extends WebIdentityTokenException(message, cause): + + /** + * Constructor with cause only + */ + def this(message: String, cause: Throwable) = + this(message, None, Some(cause)) + + /** + * Constructor with token file path only + */ + def this(message: String, tokenFilePath: String) = + this(message, Some(tokenFilePath), None) + + /** + * Constructor with both token file path and cause + */ + def this(message: String, tokenFilePath: String, cause: Throwable) = + this(message, Some(tokenFilePath), Some(cause)) + + override def getMessage: String = + tokenFilePath match + case Some(path) => s"$message (Token file path: $path)" + case None => message \ No newline at end of file diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala new file mode 100644 index 000000000..fcb41fdb4 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala @@ -0,0 +1,29 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +import scala.util.control.NoStackTrace + +/** + * Base exception for Web Identity Token operations. + * + * @param message The error message + * @param cause The underlying cause (optional) + */ +abstract class WebIdentityTokenException( + message: String, + cause: Option[Throwable] = None +) extends SdkClientException(message) with NoStackTrace: + + // Set the cause if provided + cause.foreach(initCause) + + /** + * Constructor with cause + */ + def this(message: String, cause: Throwable) = + this(message, Some(cause)) \ No newline at end of file From c8d4a30d166cfcd1c8ee14b31a0d00dacb614d0e Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 00:54:06 +0900 Subject: [PATCH 086/215] Create WebIdentityCredentialsUtils --- .../WebIdentityCredentialsUtils.scala | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala new file mode 100644 index 000000000..bfe12d4a1 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -0,0 +1,164 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials.internal + +import cats.MonadThrow +import cats.effect.Concurrent +import cats.syntax.all.* + +import fs2.io.file.{Files, Path} + +import ldbc.amazon.auth.credentials.* +import ldbc.amazon.client.{HttpClient, StsClient} +import ldbc.amazon.exception.{TokenFileNotFoundException, InvalidTokenException} +import ldbc.amazon.identity.AwsCredentials + +/** + * Trait for handling Web Identity Token credentials and STS AssumeRoleWithWebIdentity operations. + * + * This trait provides functionality to: + * - Read JWT tokens from file system + * - Parse STS AssumeRoleWithWebIdentity responses + * - Generate AWS credentials from temporary session tokens + */ +trait WebIdentityCredentialsUtils[F[_]]: + + /** + * Assumes an IAM role using Web Identity Token and returns AWS credentials. + * + * This method performs the STS AssumeRoleWithWebIdentity operation, exchanging + * a Web Identity Token (JWT) for temporary AWS credentials. + * + * @param config The Web Identity Token configuration + * @param region The AWS region for STS endpoint (default: us-east-1) + * @param httpClient HTTP client for making STS requests + * @return AWS credentials with session token + */ + def assumeRoleWithWebIdentity( + config: WebIdentityTokenCredentialProperties, + region: String, + httpClient: HttpClient[F] + ): F[AwsCredentials] + +object WebIdentityCredentialsUtils: + + private case class Impl[F[_]: Files: Concurrent]( + stsClient: StsClient[F] + ) extends WebIdentityCredentialsUtils[F]: + + def assumeRoleWithWebIdentity( + config: WebIdentityTokenCredentialProperties, + region: String, + httpClient: HttpClient[F] + ): F[AwsCredentials] = + for + token <- readTokenFromFile(config.webIdentityTokenFile) + _ <- validateToken(token) + stsRequest = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = config.roleArn, + webIdentityToken = token, + roleSessionName = config.roleSessionName + ) + stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest, region, httpClient) + credentials = convertStsResponseToCredentials(stsResponse, config) + yield credentials + + /** + * Reads JWT token from the specified file path. + * + * @param tokenFilePath Path to the JWT token file + * @return The JWT token content as a string + */ + private def readTokenFromFile(tokenFilePath: Path): F[String] = + for + exists <- Files[F].exists(tokenFilePath) + _ <- Concurrent[F].raiseUnless(exists)( + new TokenFileNotFoundException(s"Web Identity Token file not found: $tokenFilePath") + ) + token <- Files[F].readUtf8(tokenFilePath).compile.string.map(_.trim) + _ <- Concurrent[F].raiseWhen(token.isEmpty)( + new InvalidTokenException(s"Web Identity Token file is empty: $tokenFilePath") + ) + yield token + + /** + * Validates the JWT token format. + * + * This is a basic validation to ensure the token has a JWT-like structure. + * A full implementation would include signature verification and claims validation. + * + * @param token The JWT token to validate + * @return Unit if token is valid + */ + private def validateToken(token: String): F[Unit] = + MonadThrow[F].fromEither { + // Basic JWT format validation (header.payload.signature) + val parts = token.split("\\.") + if (parts.length != 3) { + Left(new InvalidTokenException(s"Invalid JWT token format. Expected 3 parts, got ${parts.length}")) + } else if (parts.exists(_.isEmpty)) { + Left(new InvalidTokenException("JWT token contains empty parts")) + } else { + Right(()) + } + } + + /** + * Converts STS response to AWS credentials. + * + * @param stsResponse The STS AssumeRoleWithWebIdentity response + * @param config The Web Identity Token configuration + * @return AWS session credentials + */ + private def convertStsResponseToCredentials( + stsResponse: StsClient.AssumeRoleWithWebIdentityResponse, + config: WebIdentityTokenCredentialProperties + ): AwsCredentials = + AwsSessionCredentials( + accessKeyId = stsResponse.accessKeyId, + secretAccessKey = stsResponse.secretAccessKey, + sessionToken = stsResponse.sessionToken, + validateCredentials = false, + providerName = Some(config.providerName), + accountId = extractAccountIdFromArn(stsResponse.assumedRoleArn), + expirationTime = Some(stsResponse.expiration) + ) + + /** + * Extracts AWS account ID from ARN. + * + * @param arn The AWS ARN (e.g., arn:aws:sts::123456789012:assumed-role/MyRole/MySession) + * @return Optional account ID + */ + private def extractAccountIdFromArn(arn: String): Option[String] = + // ARN format: arn:aws:sts::ACCOUNT_ID:assumed-role/ROLE_NAME/SESSION_NAME + val arnParts = arn.split(":") + if (arnParts.length >= 5) { + Some(arnParts(4)) + } else { + None + } + + /** + * Creates a default implementation of WebIdentityCredentialsUtils. + * + * @tparam F The effect type + * @return A WebIdentityCredentialsUtils instance + */ + def default[F[_]: Files: Concurrent]: WebIdentityCredentialsUtils[F] = + val stsClient = StsClient.default[F] + Impl[F](stsClient) + + /** + * Creates a WebIdentityCredentialsUtils with custom StsClient. + * + * @param stsClient Custom STS client implementation + * @tparam F The effect type + * @return A WebIdentityCredentialsUtils instance + */ + def create[F[_]: Files: Concurrent](stsClient: StsClient[F]): WebIdentityCredentialsUtils[F] = + Impl[F](stsClient) From 64a7e408ae13e959dc9b33e42bd33db0d406f4df Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 00:54:25 +0900 Subject: [PATCH 087/215] Create WebIdentityTokenFileCredentialsProvider --- ...IdentityTokenFileCredentialsProvider.scala | 207 ++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala new file mode 100644 index 000000000..926443ac4 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -0,0 +1,207 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import scala.concurrent.duration.* + +import cats.effect.std.{Env, SystemProperties} +import cats.effect.* +import cats.syntax.all.* + +import fs2.io.file.{Files, Path} +import fs2.io.net.* + +import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils +import ldbc.amazon.client.{HttpClient, SimpleHttpClient} +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* +import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SdkSystemSetting + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from AWS Web Identity Token File. + * + * This provider is primarily used in environments that support OIDC (OpenID Connect) authentication, + * such as Kubernetes with service accounts, EKS with IAM Roles for Service Accounts (IRSA), + * or other containerized environments that provide JWT tokens for AWS authentication. + * + * The provider reads configuration from: + * - Environment variables: AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN, AWS_ROLE_SESSION_NAME + * - System properties: aws.webIdentityTokenFile, aws.roleArn, aws.roleSessionName + * - AWS config file: web_identity_token_file, role_arn, role_session_name (profile-specific) + * + * Authentication flow: + * 1. Read JWT token from the specified file path + * 2. Use the token to assume the specified IAM role via STS AssumeRoleWithWebIdentity + * 3. Return temporary AWS credentials (access key, secret key, session token) + * + * The JWT token file is typically mounted by container orchestration systems and rotated automatically. + * + * Example Kubernetes configuration with IRSA: + * {{{ + * apiVersion: v1 + * kind: ServiceAccount + * metadata: + * annotations: + * eks.amazonaws.com/role-arn: arn:aws:iam::123456789012:role/my-role + * }}} + * + * This will automatically set: + * - AWS_ROLE_ARN=arn:aws:iam::123456789012:role/my-role + * - AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token + * + * @param webIdentityUtils Web Identity credentials utility for STS operations + * @param httpClient HTTP client for STS requests + * @param region AWS region for STS endpoint + * @tparam F The effect type + */ +final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: Concurrent]( + webIdentityUtils: WebIdentityCredentialsUtils[F], + httpClient: HttpClient[F], + region: String = "us-east-1" +) extends AwsCredentialsProvider[F]: + + override def resolveCredentials(): F[AwsCredentials] = + for + config <- loadWebIdentityConfig() + credentials <- config match { + case None => + Concurrent[F].raiseError(new SdkClientException( + "Unable to load Web Identity Token credentials. " + + "Required environment variables (AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN) or " + + "system properties (aws.webIdentityTokenFile, aws.roleArn) are not set." + )) + case Some(webIdentityConfig) => + webIdentityUtils.assumeRoleWithWebIdentity(webIdentityConfig, region, httpClient) + } + yield credentials + + private def loadWebIdentityConfig(): F[Option[WebIdentityTokenCredentialProperties]] = + for + tokenFilePath <- loadTokenFilePath() + roleArn <- loadRoleArn() + roleSessionName <- loadRoleSessionName() + yield (tokenFilePath, roleArn) match { + case (Some(tokenFile), Some(arn)) => + Some(WebIdentityTokenCredentialProperties( + webIdentityTokenFile = Path(tokenFile), + roleArn = arn, + roleSessionName = roleSessionName, + providerName = BusinessMetricFeatureId.CREDENTIALS_WEB_IDENTITY_TOKEN.code + )) + case _ => None + } + + private def loadTokenFilePath(): F[Option[String]] = + for + envValue <- Env[F].get("AWS_WEB_IDENTITY_TOKEN_FILE") + sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) + yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + + private def loadRoleArn(): F[Option[String]] = + for + envValue <- Env[F].get("AWS_ROLE_ARN") + sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) + yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + + private def loadRoleSessionName(): F[Option[String]] = + for + envValue <- Env[F].get("AWS_ROLE_SESSION_NAME") + sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_SESSION_NAME.systemProperty) + yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + +/** + * Configuration properties for Web Identity Token credentials. + * + * @param webIdentityTokenFile Path to the JWT token file + * @param roleArn The ARN of the IAM role to assume + * @param roleSessionName Optional session name for the assumed role session + * @param providerName Provider identifier for logging and metrics + */ +case class WebIdentityTokenCredentialProperties( + webIdentityTokenFile: Path, + roleArn: String, + roleSessionName: Option[String], + providerName: String +) + +object WebIdentityTokenFileCredentialsProvider: + + /** + * Creates a new Web Identity Token File credentials provider with default settings. + * + * @param region The AWS region for STS endpoint (default: us-east-1) + * @tparam F The effect type + * @return A new WebIdentityTokenFileCredentialsProvider instance + */ + def apply[F[_]: Files: Env: SystemProperties: Network: Async]( + region: String = "us-east-1" + ): F[WebIdentityTokenFileCredentialsProvider[F]] = + for + httpClient <- createDefaultHttpClient[F]() + webIdentityUtils = WebIdentityCredentialsUtils.default[F] + yield new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) + + /** + * Creates a new Web Identity Token File credentials provider with custom HTTP client. + * + * @param httpClient The HTTP client for STS requests + * @param region The AWS region for STS endpoint + * @tparam F The effect type + * @return A new WebIdentityTokenFileCredentialsProvider instance + */ + def default[F[_]: Env: SystemProperties: Files: Concurrent]( + httpClient: HttpClient[F], + region: String = "us-east-1" + ): WebIdentityTokenFileCredentialsProvider[F] = + val webIdentityUtils = WebIdentityCredentialsUtils.default[F] + new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) + + /** + * Creates a new Web Identity Token File credentials provider with custom WebIdentityCredentialsUtils. + * + * @param webIdentityUtils Custom Web Identity credentials utility + * @param httpClient The HTTP client for STS requests + * @param region The AWS region for STS endpoint + * @tparam F The effect type + * @return A new WebIdentityTokenFileCredentialsProvider instance + */ + def create[F[_]: Env: SystemProperties: Concurrent]( + webIdentityUtils: WebIdentityCredentialsUtils[F], + httpClient: HttpClient[F], + region: String = "us-east-1" + ): WebIdentityTokenFileCredentialsProvider[F] = + new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) + + /** + * Creates a default HTTP client for STS operations. + * + * @tparam F The effect type + * @return A configured HTTP client + */ + private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = + Async[F].pure(new SimpleHttpClient[F]( + connectTimeout = 30.seconds, + readTimeout = 30.seconds + )) + + /** + * Checks if Web Identity Token authentication is available by verifying + * the presence of required environment variables or system properties. + * + * @tparam F The effect type + * @return true if Web Identity Token authentication is properly configured + */ + def isAvailable[F[_]: Env: SystemProperties: Concurrent](): F[Boolean] = + for + tokenFile <- Env[F].get("AWS_WEB_IDENTITY_TOKEN_FILE") + .flatMap(envValue => SystemProperties[F].get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) + .map(envValue.orElse(_))) + roleArn <- Env[F].get("AWS_ROLE_ARN") + .flatMap(envValue => SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) + .map(envValue.orElse(_))) + yield tokenFile.exists(_.trim.nonEmpty) && roleArn.exists(_.trim.nonEmpty) From 100fe99afb943aab6a9a3589bce15ae340142262 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 01:07:52 +0900 Subject: [PATCH 088/215] Added CREDENTIALS_WEB_IDENTITY_TOKEN --- .../scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala index b245509a8..14768d02e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -60,4 +60,5 @@ enum BusinessMetricFeatureId(val code: String): case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") + case CREDENTIALS_WEB_IDENTITY_TOKEN extends BusinessMetricFeatureId("k") case UNKNOWN extends BusinessMetricFeatureId("Unknown") From 7db27381e2305330a1fd7d60aa034c460a3e5f02 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 01:10:29 +0900 Subject: [PATCH 089/215] Action sbt scalafmtAll --- .../DefaultCredentialsProviderChain.scala | 101 ++++++++++ ...vironmentVariableCredentialsProvider.scala | 1 + .../ProfileCredentialsProvider.scala | 84 +++++---- .../SystemPropertyCredentialsProvider.scala | 7 +- ...IdentityTokenFileCredentialsProvider.scala | 101 +++++----- .../SystemSettingsCredentialsProvider.scala | 16 +- .../WebIdentityCredentialsUtils.scala | 63 +++---- .../auth/token/RdsIamAuthTokenGenerator.scala | 24 +-- .../scala/ldbc/amazon/client/HttpClient.scala | 2 +- .../ldbc/amazon/client/HttpResponse.scala | 8 +- .../ldbc/amazon/client/SimpleHttpClient.scala | 55 +++--- .../scala/ldbc/amazon/client/StsClient.scala | 174 ++++++++++-------- .../exception/InvalidTokenException.scala | 6 +- .../ldbc/amazon/exception/StsException.scala | 9 +- .../exception/TokenFileAccessException.scala | 6 +- .../TokenFileNotFoundException.scala | 8 +- .../exception/WebIdentityTokenException.scala | 7 +- 17 files changed, 412 insertions(+), 260 deletions(-) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala new file mode 100644 index 000000000..7f1ac9527 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -0,0 +1,101 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +/* +package ldbc.amazon.auth.credentials + +import cats.MonadThrow +import cats.effect.std.{Env, SystemProperties} +import cats.effect.{Concurrent, Network} +import cats.syntax.all.* + +import fs2.io.file.Files + +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* + +/** + * Default AWS credentials provider chain that matches AWS SDK v2 behavior. + * + * The provider chain attempts to resolve credentials from the following sources in order: + * 1. Java system properties (aws.accessKeyId and aws.secretAccessKey) + * 2. Environment variables (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY) + * 3. Web identity token file (AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN) + * 4. AWS credentials profile files (~/.aws/credentials) + * 5. Container credentials provider (ECS/EKS) + * 6. Instance profile credentials provider (EC2) + * + * The first provider in the chain that successfully provides credentials will be used, + * and the search will stop there. If no provider can provide credentials, an exception is thrown. + * + * Usage example: + * ```scala + * import cats.effect.IO + * + * val credentialsProvider = DefaultCredentialsProviderChain[IO]() + * val credentials: IO[AwsCredentials] = credentialsProvider.resolveCredentials() + * ``` + * + * @tparam F The effect type + */ +class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent] + extends AwsCredentialsProvider[F]: + + private lazy val providers: F[List[AwsCredentialsProvider[F]]] = + for + webIdentityProvider <- WebIdentityTokenFileCredentialsProvider[F]() + profileProvider <- ProfileCredentialsProvider.default[F]() + yield List( + new SystemPropertyCredentialsProvider[F](), + new EnvironmentVariableCredentialsProvider[F](), + webIdentityProvider, + profileProvider + // ContainerCredentialsProvider - TODO: implement + // InstanceProfileCredentialsProvider - TODO: implement + ) + + override def resolveCredentials(): F[AwsCredentials] = + for + providerList <- providers + credentials <- tryProvidersInOrder(providerList, Nil) + yield credentials + + private def tryProvidersInOrder( + providers: List[AwsCredentialsProvider[F]], + exceptions: List[String] + ): F[AwsCredentials] = + providers match + case Nil => + MonadThrow[F].raiseError(new SdkClientException( + s"Unable to load AWS credentials from any provider in the chain: ${exceptions.mkString(", ")}" + )) + + case provider :: remainingProviders => + provider.resolveCredentials().recoverWith { ex => + val errorMsg = s"${provider.getClass.getSimpleName}: ${ex.getMessage}" + tryProvidersInOrder(remainingProviders, exceptions :+ errorMsg) + } + +object DefaultCredentialsProviderChain: + + /** + * Creates a new default credentials provider chain. + * + * @tparam F The effect type + * @return A new DefaultCredentialsProviderChain instance + */ + def apply[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent](): DefaultCredentialsProviderChain[F] = + new DefaultCredentialsProviderChain[F]() + + /** + * Convenience method to resolve credentials using the default chain. + * + * @tparam F The effect type + * @return AWS credentials from the first successful provider + */ + def resolveCredentials[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent](): F[AwsCredentials] = + DefaultCredentialsProviderChain[F]().resolveCredentials() + */ \ No newline at end of file diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala index 5b4273d64..ffa36a783 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala @@ -7,6 +7,7 @@ package ldbc.amazon.auth.credentials import cats.MonadThrow + import cats.effect.std.Env import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala index 38f29e313..9a0fb1cc9 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala @@ -8,13 +8,13 @@ package ldbc.amazon.auth.credentials import java.time.Instant +import cats.syntax.all.* import cats.MonadThrow import cats.effect.* import cats.effect.std.* -import cats.syntax.all.* -import fs2.io.file.{Files, Path} +import fs2.io.file.{ Files, Path } import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* @@ -22,42 +22,44 @@ import ldbc.amazon.identity.* import ProfileCredentialsProvider.* final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent]( - profileName: String, - cacheRef: Ref[F, Option[(ProfileFile, AwsCredentials)]], - semaphore: Semaphore[F] -)(using ev: MonadThrow[F]) extends AwsCredentialsProvider[F]: + profileName: String, + cacheRef: Ref[F, Option[(ProfileFile, AwsCredentials)]], + semaphore: Semaphore[F] +)(using ev: MonadThrow[F]) + extends AwsCredentialsProvider[F]: override def resolveCredentials(): F[AwsCredentials] = for currentFile <- loadFile cached <- cacheRef.get credentials <- cached match - case Some((cachedFile, creds)) if cachedFile.lastModified == currentFile.lastModified => ev.pure(creds) - case _ => updateCredentials(currentFile) + case Some((cachedFile, creds)) if cachedFile.lastModified == currentFile.lastModified => + ev.pure(creds) + case _ => updateCredentials(currentFile) yield credentials - + private def loadFile: F[ProfileFile] = for homeOpt <- SystemProperties[F].get("user.home") - home <- ev.fromOption(homeOpt, new SdkClientException("")) + home <- ev.fromOption(homeOpt, new SdkClientException("")) credentialsPath = Path(s"$home/.aws/credentials") exists <- Files[F].exists(credentialsPath) - _ <- ev.raiseUnless(exists)(new SdkClientException(s"File not found: $credentialsPath")) + _ <- ev.raiseUnless(exists)(new SdkClientException(s"File not found: $credentialsPath")) content <- Files[F].readUtf8(credentialsPath).compile.string lastMod <- Files[F].getLastModifiedTime(credentialsPath) profiles <- ev.fromEither(parseProfiles(content)) yield ProfileFile(profiles, Instant.ofEpochMilli(lastMod.toMillis)) private def parseProfiles(content: String): Either[Throwable, Map[String, Profile]] = - val profilePattern = "\\[(?:profile\\s+)?(.+)\\]".r + val profilePattern = "\\[(?:profile\\s+)?(.+)\\]".r val propertyPattern = "^\\s*([^=]+)\\s*=\\s*(.+)\\s*$".r val lines = content.linesIterator.toList case class State( - currentProfile: Option[String] = None, - profiles: Map[String, Map[String, String]] = Map.empty - ) + currentProfile: Option[String] = None, + profiles: Map[String, Map[String, String]] = Map.empty + ) val finalState = lines.foldLeft(State()): (state, line) => line.trim match @@ -74,12 +76,11 @@ final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent case _ => state Right(finalState.profiles.map: (name, props) => - name -> Profile(name, props) - ) + name -> Profile(name, props)) private def parseStaticCredentials(props: Map[String, String]): Option[AwsCredentials] = for - accessKeyId <- props.get("aws_access_key_id") + accessKeyId <- props.get("aws_access_key_id") secretAccessKey <- props.get("aws_secret_access_key") yield props.get("aws_session_token") match { case Some(sessionToken) => @@ -106,43 +107,46 @@ final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent private def extractCredentials(profileFile: ProfileFile): F[AwsCredentials] = for profile <- ev.fromOption( - profileFile.profiles.get(profileName), - new SdkClientException("") - ) + profileFile.profiles.get(profileName), + new SdkClientException("") + ) credentials <- ev.fromOption( - parseStaticCredentials(profile.properties), - new SdkClientException("") - ) + parseStaticCredentials(profile.properties), + new SdkClientException("") + ) yield credentials private def updateCredentials(profileFile: ProfileFile): F[AwsCredentials] = semaphore.permit.use { _ => for - cached <- cacheRef.get + cached <- cacheRef.get credentials <- cached match - case Some((cachedFile, creds)) if cachedFile.lastModified == profileFile.lastModified => ev.pure(creds) - case _ => - for - creds <- extractCredentials(profileFile) - _ <- cacheRef.set(Some((profileFile, creds))) - yield creds + case Some((cachedFile, creds)) if cachedFile.lastModified == profileFile.lastModified => + ev.pure(creds) + case _ => + for + creds <- extractCredentials(profileFile) + _ <- cacheRef.set(Some((profileFile, creds))) + yield creds yield credentials } - + object ProfileCredentialsProvider: final case class Profile( - name: String, - properties: Map[String, String] - ) + name: String, + properties: Map[String, String] + ) final case class ProfileFile( - profiles: Map[String, Profile], - lastModified: Instant - ) + profiles: Map[String, Profile], + lastModified: Instant + ) - def default[F[_]: SystemProperties: Files: Concurrent](profileName: String = "default"): F[ProfileCredentialsProvider[F]] = + def default[F[_]: SystemProperties: Files: Concurrent]( + profileName: String = "default" + ): F[ProfileCredentialsProvider[F]] = for - cacheRef <- Ref.of[F, Option[(ProfileFile, AwsCredentials)]](None) + cacheRef <- Ref.of[F, Option[(ProfileFile, AwsCredentials)]](None) semaphore <- Semaphore[F](1) yield new ProfileCredentialsProvider[F](profileName, cacheRef, semaphore) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala index ea64997ef..e691df18a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala @@ -7,6 +7,7 @@ package ldbc.amazon.auth.credentials import cats.MonadThrow + import cats.effect.std.SystemProperties import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider @@ -17,10 +18,12 @@ import ldbc.amazon.util.SdkSystemSetting * [[AwsCredentialsProvider]] implementation that loads credentials from the aws.accessKeyId, aws.secretAccessKey and * aws.sessionToken system properties. */ -final class SystemPropertyCredentialsProvider[F[_]: SystemProperties: MonadThrow] extends SystemSettingsCredentialsProvider[F]: +final class SystemPropertyCredentialsProvider[F[_]: SystemProperties: MonadThrow] + extends SystemSettingsCredentialsProvider[F]: // Customers should be able to specify a credentials provider that only looks at the system properties, // but not the environment variables. For that reason, we're only checking the system properties here. - override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = SystemProperties[F].get(setting.systemProperty) + override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = + SystemProperties[F].get(setting.systemProperty) override def provider: String = BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 926443ac4..6b5805f8b 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -8,15 +8,16 @@ package ldbc.amazon.auth.credentials import scala.concurrent.duration.* -import cats.effect.std.{Env, SystemProperties} -import cats.effect.* import cats.syntax.all.* -import fs2.io.file.{Files, Path} +import cats.effect.* +import cats.effect.std.{ Env, SystemProperties } + +import fs2.io.file.{ Files, Path } import fs2.io.net.* import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils -import ldbc.amazon.client.{HttpClient, SimpleHttpClient} +import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.useragent.BusinessMetricFeatureId @@ -61,56 +62,60 @@ import ldbc.amazon.util.SdkSystemSetting */ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: Concurrent]( webIdentityUtils: WebIdentityCredentialsUtils[F], - httpClient: HttpClient[F], - region: String = "us-east-1" + httpClient: HttpClient[F], + region: String = "us-east-1" ) extends AwsCredentialsProvider[F]: override def resolveCredentials(): F[AwsCredentials] = for - config <- loadWebIdentityConfig() + config <- loadWebIdentityConfig() credentials <- config match { - case None => - Concurrent[F].raiseError(new SdkClientException( - "Unable to load Web Identity Token credentials. " + - "Required environment variables (AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN) or " + - "system properties (aws.webIdentityTokenFile, aws.roleArn) are not set." - )) - case Some(webIdentityConfig) => - webIdentityUtils.assumeRoleWithWebIdentity(webIdentityConfig, region, httpClient) - } + case None => + Concurrent[F].raiseError( + new SdkClientException( + "Unable to load Web Identity Token credentials. " + + "Required environment variables (AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN) or " + + "system properties (aws.webIdentityTokenFile, aws.roleArn) are not set." + ) + ) + case Some(webIdentityConfig) => + webIdentityUtils.assumeRoleWithWebIdentity(webIdentityConfig, region, httpClient) + } yield credentials private def loadWebIdentityConfig(): F[Option[WebIdentityTokenCredentialProperties]] = for - tokenFilePath <- loadTokenFilePath() - roleArn <- loadRoleArn() + tokenFilePath <- loadTokenFilePath() + roleArn <- loadRoleArn() roleSessionName <- loadRoleSessionName() yield (tokenFilePath, roleArn) match { case (Some(tokenFile), Some(arn)) => - Some(WebIdentityTokenCredentialProperties( - webIdentityTokenFile = Path(tokenFile), - roleArn = arn, - roleSessionName = roleSessionName, - providerName = BusinessMetricFeatureId.CREDENTIALS_WEB_IDENTITY_TOKEN.code - )) + Some( + WebIdentityTokenCredentialProperties( + webIdentityTokenFile = Path(tokenFile), + roleArn = arn, + roleSessionName = roleSessionName, + providerName = BusinessMetricFeatureId.CREDENTIALS_WEB_IDENTITY_TOKEN.code + ) + ) case _ => None } private def loadTokenFilePath(): F[Option[String]] = for - envValue <- Env[F].get("AWS_WEB_IDENTITY_TOKEN_FILE") + envValue <- Env[F].get("AWS_WEB_IDENTITY_TOKEN_FILE") sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) private def loadRoleArn(): F[Option[String]] = for - envValue <- Env[F].get("AWS_ROLE_ARN") + envValue <- Env[F].get("AWS_ROLE_ARN") sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) private def loadRoleSessionName(): F[Option[String]] = for - envValue <- Env[F].get("AWS_ROLE_SESSION_NAME") + envValue <- Env[F].get("AWS_ROLE_SESSION_NAME") sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_SESSION_NAME.systemProperty) yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) @@ -124,9 +129,9 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: */ case class WebIdentityTokenCredentialProperties( webIdentityTokenFile: Path, - roleArn: String, - roleSessionName: Option[String], - providerName: String + roleArn: String, + roleSessionName: Option[String], + providerName: String ) object WebIdentityTokenFileCredentialsProvider: @@ -156,7 +161,7 @@ object WebIdentityTokenFileCredentialsProvider: */ def default[F[_]: Env: SystemProperties: Files: Concurrent]( httpClient: HttpClient[F], - region: String = "us-east-1" + region: String = "us-east-1" ): WebIdentityTokenFileCredentialsProvider[F] = val webIdentityUtils = WebIdentityCredentialsUtils.default[F] new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) @@ -172,8 +177,8 @@ object WebIdentityTokenFileCredentialsProvider: */ def create[F[_]: Env: SystemProperties: Concurrent]( webIdentityUtils: WebIdentityCredentialsUtils[F], - httpClient: HttpClient[F], - region: String = "us-east-1" + httpClient: HttpClient[F], + region: String = "us-east-1" ): WebIdentityTokenFileCredentialsProvider[F] = new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) @@ -184,10 +189,12 @@ object WebIdentityTokenFileCredentialsProvider: * @return A configured HTTP client */ private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = - Async[F].pure(new SimpleHttpClient[F]( - connectTimeout = 30.seconds, - readTimeout = 30.seconds - )) + Async[F].pure( + new SimpleHttpClient[F]( + connectTimeout = 30.seconds, + readTimeout = 30.seconds + ) + ) /** * Checks if Web Identity Token authentication is available by verifying @@ -198,10 +205,18 @@ object WebIdentityTokenFileCredentialsProvider: */ def isAvailable[F[_]: Env: SystemProperties: Concurrent](): F[Boolean] = for - tokenFile <- Env[F].get("AWS_WEB_IDENTITY_TOKEN_FILE") - .flatMap(envValue => SystemProperties[F].get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) - .map(envValue.orElse(_))) - roleArn <- Env[F].get("AWS_ROLE_ARN") - .flatMap(envValue => SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) - .map(envValue.orElse(_))) + tokenFile <- Env[F] + .get("AWS_WEB_IDENTITY_TOKEN_FILE") + .flatMap(envValue => + SystemProperties[F] + .get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) + .map(envValue.orElse(_)) + ) + roleArn <- Env[F] + .get("AWS_ROLE_ARN") + .flatMap(envValue => + SystemProperties[F] + .get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) + .map(envValue.orElse(_)) + ) yield tokenFile.exists(_.trim.nonEmpty) && roleArn.exists(_.trim.nonEmpty) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala index da743c294..3b599d7ae 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala @@ -6,8 +6,8 @@ package ldbc.amazon.auth.credentials.internal -import cats.MonadThrow import cats.syntax.all.* +import cats.MonadThrow import ldbc.amazon.auth.credentials.* import ldbc.amazon.exception.SdkClientException @@ -22,8 +22,18 @@ trait SystemSettingsCredentialsProvider[F[_]](using ev: MonadThrow[F]) extends A secretKeyOpt <- loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY).map(_.map(_.trim)) sessionTokenOpt <- loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN).map(_.map(_.trim)) accountId <- loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID).map(_.map(_.trim)) - accessKey <- ev.fromOption(accessKeyOpt, new SdkClientException(s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${ SdkSystemSetting.AWS_ACCESS_KEY_ID }) or system property (${ SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty }).")) - secretKey <- ev.fromOption(secretKeyOpt, new SdkClientException(s"Unable to load credentials from system settings. Secret key must be specified either via environment variable (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY }) or system property (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty }).")) + accessKey <- ev.fromOption( + accessKeyOpt, + new SdkClientException( + s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${ SdkSystemSetting.AWS_ACCESS_KEY_ID }) or system property (${ SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty })." + ) + ) + secretKey <- ev.fromOption( + secretKeyOpt, + new SdkClientException( + s"Unable to load credentials from system settings. Secret key must be specified either via environment variable (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY }) or system property (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty })." + ) + ) yield sessionTokenOpt match { case None => AwsBasicCredentials( diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index bfe12d4a1..18259a0f2 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -6,15 +6,16 @@ package ldbc.amazon.auth.credentials.internal +import cats.syntax.all.* import cats.MonadThrow + import cats.effect.Concurrent -import cats.syntax.all.* -import fs2.io.file.{Files, Path} +import fs2.io.file.{ Files, Path } import ldbc.amazon.auth.credentials.* -import ldbc.amazon.client.{HttpClient, StsClient} -import ldbc.amazon.exception.{TokenFileNotFoundException, InvalidTokenException} +import ldbc.amazon.client.{ HttpClient, StsClient } +import ldbc.amazon.exception.{ InvalidTokenException, TokenFileNotFoundException } import ldbc.amazon.identity.AwsCredentials /** @@ -39,8 +40,8 @@ trait WebIdentityCredentialsUtils[F[_]]: * @return AWS credentials with session token */ def assumeRoleWithWebIdentity( - config: WebIdentityTokenCredentialProperties, - region: String, + config: WebIdentityTokenCredentialProperties, + region: String, httpClient: HttpClient[F] ): F[AwsCredentials] @@ -51,18 +52,18 @@ object WebIdentityCredentialsUtils: ) extends WebIdentityCredentialsUtils[F]: def assumeRoleWithWebIdentity( - config: WebIdentityTokenCredentialProperties, - region: String, + config: WebIdentityTokenCredentialProperties, + region: String, httpClient: HttpClient[F] ): F[AwsCredentials] = for token <- readTokenFromFile(config.webIdentityTokenFile) - _ <- validateToken(token) + _ <- validateToken(token) stsRequest = StsClient.AssumeRoleWithWebIdentityRequest( - roleArn = config.roleArn, - webIdentityToken = token, - roleSessionName = config.roleSessionName - ) + roleArn = config.roleArn, + webIdentityToken = token, + roleSessionName = config.roleSessionName + ) stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest, region, httpClient) credentials = convertStsResponseToCredentials(stsResponse, config) yield credentials @@ -76,13 +77,13 @@ object WebIdentityCredentialsUtils: private def readTokenFromFile(tokenFilePath: Path): F[String] = for exists <- Files[F].exists(tokenFilePath) - _ <- Concurrent[F].raiseUnless(exists)( - new TokenFileNotFoundException(s"Web Identity Token file not found: $tokenFilePath") - ) + _ <- Concurrent[F].raiseUnless(exists)( + new TokenFileNotFoundException(s"Web Identity Token file not found: $tokenFilePath") + ) token <- Files[F].readUtf8(tokenFilePath).compile.string.map(_.trim) - _ <- Concurrent[F].raiseWhen(token.isEmpty)( - new InvalidTokenException(s"Web Identity Token file is empty: $tokenFilePath") - ) + _ <- Concurrent[F].raiseWhen(token.isEmpty)( + new InvalidTokenException(s"Web Identity Token file is empty: $tokenFilePath") + ) yield token /** @@ -98,9 +99,9 @@ object WebIdentityCredentialsUtils: MonadThrow[F].fromEither { // Basic JWT format validation (header.payload.signature) val parts = token.split("\\.") - if (parts.length != 3) { - Left(new InvalidTokenException(s"Invalid JWT token format. Expected 3 parts, got ${parts.length}")) - } else if (parts.exists(_.isEmpty)) { + if parts.length != 3 then { + Left(new InvalidTokenException(s"Invalid JWT token format. Expected 3 parts, got ${ parts.length }")) + } else if parts.exists(_.isEmpty) then { Left(new InvalidTokenException("JWT token contains empty parts")) } else { Right(()) @@ -116,16 +117,16 @@ object WebIdentityCredentialsUtils: */ private def convertStsResponseToCredentials( stsResponse: StsClient.AssumeRoleWithWebIdentityResponse, - config: WebIdentityTokenCredentialProperties + config: WebIdentityTokenCredentialProperties ): AwsCredentials = AwsSessionCredentials( - accessKeyId = stsResponse.accessKeyId, - secretAccessKey = stsResponse.secretAccessKey, - sessionToken = stsResponse.sessionToken, + accessKeyId = stsResponse.accessKeyId, + secretAccessKey = stsResponse.secretAccessKey, + sessionToken = stsResponse.sessionToken, validateCredentials = false, - providerName = Some(config.providerName), - accountId = extractAccountIdFromArn(stsResponse.assumedRoleArn), - expirationTime = Some(stsResponse.expiration) + providerName = Some(config.providerName), + accountId = extractAccountIdFromArn(stsResponse.assumedRoleArn), + expirationTime = Some(stsResponse.expiration) ) /** @@ -137,7 +138,7 @@ object WebIdentityCredentialsUtils: private def extractAccountIdFromArn(arn: String): Option[String] = // ARN format: arn:aws:sts::ACCOUNT_ID:assumed-role/ROLE_NAME/SESSION_NAME val arnParts = arn.split(":") - if (arnParts.length >= 5) { + if arnParts.length >= 5 then { Some(arnParts(4)) } else { None @@ -149,7 +150,7 @@ object WebIdentityCredentialsUtils: * @tparam F The effect type * @return A WebIdentityCredentialsUtils instance */ - def default[F[_]: Files: Concurrent]: WebIdentityCredentialsUtils[F] = + def default[F[_]: Files: Concurrent]: WebIdentityCredentialsUtils[F] = val stsClient = StsClient.default[F] Impl[F](stsClient) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 07ff1c3fb..14fd9cd9f 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -96,11 +96,11 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( override def generateToken(credentials: AwsCredentials): F[String] = for now <- clock.realTimeInstant - dateTime = formatDateTime(now) - date = dateTime.substring(0, 8) - credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" - credential = s"${ credentials.accessKeyId }/$credentialScope" - //queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) + dateTime = formatDateTime(now) + date = dateTime.substring(0, 8) + credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" + credential = s"${ credentials.accessKeyId }/$credentialScope" + // queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) queryParams = buildQueryParams(credential, dateTime, ???, username) canonicalRequest = buildCanonicalRequest(s"$hostname:$port", queryParams) canonicalRequestHash <- sha256Hex(canonicalRequest) @@ -143,18 +143,18 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( username: String ): String = val baseParams = List( - "Action" -> "connect", - "DBUser" -> username, - "X-Amz-Algorithm" -> ALGORITHM, - "X-Amz-Credential" -> credential, - "X-Amz-Date" -> dateTime, - "X-Amz-Expires" -> EXPIRES_SECONDS.toString, + "Action" -> "connect", + "DBUser" -> username, + "X-Amz-Algorithm" -> ALGORITHM, + "X-Amz-Credential" -> credential, + "X-Amz-Date" -> dateTime, + "X-Amz-Expires" -> EXPIRES_SECONDS.toString, "X-Amz-SignedHeaders" -> "host" ) val params = sessionToken match case Some(token) => baseParams :+ ("X-Amz-Security-Token" -> token) - case None => baseParams + case None => baseParams params .sortBy(_._1) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala index 66c0b5709..d9e48f52e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala @@ -9,5 +9,5 @@ package ldbc.amazon.client import java.net.URI trait HttpClient[F[_]]: - + def get(uri: URI, headers: Map[String, String]): F[HttpResponse] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala index b22a66ff5..20693ca32 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala @@ -7,7 +7,7 @@ package ldbc.amazon.client final case class HttpResponse( - statusCode: Int, - headers: Map[String, String], - body: String - ) + statusCode: Int, + headers: Map[String, String], + body: String +) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 5ce0a7e79..2ac785d76 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -10,14 +10,14 @@ import java.net.URI import scala.concurrent.duration.* -import cats.MonadThrow +import com.comcast.ip4s.* + import cats.syntax.all.* +import cats.MonadThrow import cats.effect.* import cats.effect.syntax.all.* -import com.comcast.ip4s.* - import fs2.* import fs2.io.net.* @@ -25,8 +25,9 @@ import ldbc.amazon.exception.* class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, - readTimeout: Duration -)(using ev: MonadThrow[F]) extends HttpClient[F]: + readTimeout: Duration +)(using ev: MonadThrow[F]) + extends HttpClient[F]: override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = val host = uri.getHost @@ -35,7 +36,7 @@ class SimpleHttpClient[F[_]: Network: Async]( Option(uri.getQuery).map("?" + _).getOrElse("") for - address <- resolveAddress(host, port) + address <- resolveAddress(host, port) response <- makeRequest(address, host, port, path, headers) yield response @@ -46,20 +47,21 @@ class SimpleHttpClient[F[_]: Network: Async]( yield SocketAddress(h, p) private def sendRequest( - socket: Socket[F], - host: String, - port: Int, - path: String, - headers: Map[String, String] - ): F[Unit] = + socket: Socket[F], + host: String, + port: Int, + path: String, + headers: Map[String, String] + ): F[Unit] = val hostHeader = if port == 80 then host else s"$host:$port" val allHeaders = headers + ("Host" -> hostHeader) + ("Connection" -> "close") val requestLine = s"GET $path HTTP/1.1\r\n" val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString - val request = requestLine + headerLines + "\r\n" + val request = requestLine + headerLines + "\r\n" - Stream.emit(request) + Stream + .emit(request) .through(text.utf8.encode) .through(socket.writes) .compile @@ -71,13 +73,13 @@ class SimpleHttpClient[F[_]: Network: Async]( case _ :: code :: _ => code.toIntOption match case Some(c) => ev.pure(c) - case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) case _ => ev.raiseError(new CredentialsFetchError(s"Invalid status line: $line")) private def parseHeaderLine(line: String): Option[(String, String)] = line.split(": ", 2).toList match case key :: value :: Nil => Some(key -> value) - case _ => None + case _ => None private def parseHttpResponse(raw: String): F[HttpResponse] = val lines = raw.split("\r\n").toList @@ -86,8 +88,8 @@ class SimpleHttpClient[F[_]: Network: Async]( case statusLine :: rest => parseStatusLine(statusLine).flatMap: statusCode => val (headerLines, bodyLines) = rest.span(_.nonEmpty) - val headers = headerLines.flatMap(parseHeaderLine).toMap - val body = bodyLines.drop(1).mkString("\r\n") // drop empty line + val headers = headerLines.flatMap(parseHeaderLine).toMap + val body = bodyLines.drop(1).mkString("\r\n") // drop empty line ev.pure(HttpResponse(statusCode, headers, body)) case _ => @@ -101,16 +103,17 @@ class SimpleHttpClient[F[_]: Network: Async]( .flatMap(parseHttpResponse) private def makeRequest( - address: SocketAddress[Host], - host: String, - port: Int, - path: String, - headers: Map[String, String] - ): F[HttpResponse] = - Network[F].client(address) + address: SocketAddress[Host], + host: String, + port: Int, + path: String, + headers: Map[String, String] + ): F[HttpResponse] = + Network[F] + .client(address) .use { socket => for - _ <- sendRequest(socket, host, port, path, headers) + _ <- sendRequest(socket, host, port, path, headers) response <- receiveResponse(socket) yield response } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index bc448363a..fb5d72118 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -6,18 +6,19 @@ package ldbc.amazon.client -import java.net.{URI, URLEncoder} -import java.time.{Instant, ZoneOffset} +import java.net.{ URI, URLEncoder } +import java.time.{ Instant, ZoneOffset } import java.time.format.DateTimeFormatter import java.util.UUID import scala.xml.XML +import cats.syntax.all.* import cats.MonadThrow + import cats.effect.Concurrent -import cats.syntax.all.* -import ldbc.amazon.exception.{SdkClientException, StsException} +import ldbc.amazon.exception.{ SdkClientException, StsException } /** * Trait for AWS STS (Security Token Service) client operations. @@ -36,8 +37,8 @@ trait StsClient[F[_]]: * @return STS response with temporary credentials */ def assumeRoleWithWebIdentity( - request: StsClient.AssumeRoleWithWebIdentityRequest, - region: String, + request: StsClient.AssumeRoleWithWebIdentityRequest, + region: String, httpClient: HttpClient[F] ): F[StsClient.AssumeRoleWithWebIdentityResponse] @@ -52,10 +53,10 @@ object StsClient: * @param durationSeconds Optional duration of the session (default 3600 seconds) */ case class AssumeRoleWithWebIdentityRequest( - roleArn: String, + roleArn: String, webIdentityToken: String, - roleSessionName: Option[String] = None, - durationSeconds: Option[Int] = None + roleSessionName: Option[String] = None, + durationSeconds: Option[Int] = None ) /** @@ -68,41 +69,43 @@ object StsClient: * @param assumedRoleArn The ARN of the assumed role */ case class AssumeRoleWithWebIdentityResponse( - accessKeyId: String, + accessKeyId: String, secretAccessKey: String, - sessionToken: String, - expiration: Instant, - assumedRoleArn: String + sessionToken: String, + expiration: Instant, + assumedRoleArn: String ) private case class Impl[F[_]: Concurrent]() extends StsClient[F]: def assumeRoleWithWebIdentity( - request: AssumeRoleWithWebIdentityRequest, - region: String, + request: AssumeRoleWithWebIdentityRequest, + region: String, httpClient: HttpClient[F] ): F[AssumeRoleWithWebIdentityResponse] = for timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) - sessionName = request.roleSessionName.getOrElse(s"ldbc-session-${UUID.randomUUID()}") - duration = request.durationSeconds.getOrElse(3600) - + sessionName = request.roleSessionName.getOrElse(s"ldbc-session-${ UUID.randomUUID() }") + duration = request.durationSeconds.getOrElse(3600) + // Build STS request stsEndpoint = s"https://sts.$region.amazonaws.com/" - requestBody = buildRequestBody(request.copy( - roleSessionName = Some(sessionName), - durationSeconds = Some(duration) - )) - + requestBody = buildRequestBody( + request.copy( + roleSessionName = Some(sessionName), + durationSeconds = Some(duration) + ) + ) + // Make HTTP request headers = Map( - "Content-Type" -> "application/x-amz-json-1.0", - "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRoleWithWebIdentity", - "X-Amz-Date" -> timestamp - ) - - response <- httpClient.get(URI.create(stsEndpoint), headers) - _ <- validateHttpResponse(response) + "Content-Type" -> "application/x-amz-json-1.0", + "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRoleWithWebIdentity", + "X-Amz-Date" -> timestamp + ) + + response <- httpClient.get(URI.create(stsEndpoint), headers) + _ <- validateHttpResponse(response) stsResponse <- parseAssumeRoleResponse(response.body) yield stsResponse @@ -119,27 +122,32 @@ object StsClient: */ private def buildRequestBody(request: AssumeRoleWithWebIdentityRequest): String = val params = Map( - "Action" -> "AssumeRoleWithWebIdentity", - "Version" -> "2011-06-15", - "RoleArn" -> request.roleArn, + "Action" -> "AssumeRoleWithWebIdentity", + "Version" -> "2011-06-15", + "RoleArn" -> request.roleArn, "WebIdentityToken" -> request.webIdentityToken, - "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), - "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString + "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), + "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString ) - - params.map { case (key, value) => - s"${URLEncoder.encode(key, "UTF-8")}=${URLEncoder.encode(value, "UTF-8")}" - }.mkString("&") + + params + .map { + case (key, value) => + s"${ URLEncoder.encode(key, "UTF-8") }=${ URLEncoder.encode(value, "UTF-8") }" + } + .mkString("&") /** * Gets current timestamp in AWS format. */ private def getCurrentTimestamp(): Either[Throwable, String] = try { - Right(DateTimeFormatter - .ofPattern("yyyyMMdd'T'HHmmss'Z'") - .withZone(ZoneOffset.UTC) - .format(Instant.now())) + Right( + DateTimeFormatter + .ofPattern("yyyyMMdd'T'HHmmss'Z'") + .withZone(ZoneOffset.UTC) + .format(Instant.now()) + ) } catch { case ex: Exception => Left(new SdkClientException("Failed to generate timestamp")) } @@ -148,12 +156,14 @@ object StsClient: * Validates HTTP response status. */ private def validateHttpResponse[F[_]: MonadThrow](response: HttpResponse): F[Unit] = - if (response.statusCode >= 200 && response.statusCode < 300) { + if response.statusCode >= 200 && response.statusCode < 300 then { MonadThrow[F].unit } else { - MonadThrow[F].raiseError(new StsException( - s"STS request failed with status ${response.statusCode}: ${response.body}" - )) + MonadThrow[F].raiseError( + new StsException( + s"STS request failed with status ${ response.statusCode }: ${ response.body }" + ) + ) } /** @@ -178,37 +188,39 @@ object StsClient: * ``` */ private def parseAssumeRoleResponse[F[_]: MonadThrow](xmlBody: String): F[AssumeRoleWithWebIdentityResponse] = - MonadThrow[F].catchNonFatal { - val xml = XML.loadString(xmlBody) - - // Extract credentials - val credentials = (xml \ "AssumeRoleWithWebIdentityResult" \ "Credentials").head - val accessKeyId = (credentials \ "AccessKeyId").text.trim - val secretAccessKey = (credentials \ "SecretAccessKey").text.trim - val sessionToken = (credentials \ "SessionToken").text.trim - val expirationStr = (credentials \ "Expiration").text.trim - - // Extract assumed role information - val assumedRoleUser = (xml \ "AssumeRoleWithWebIdentityResult" \ "AssumedRoleUser").head - val assumedRoleArn = (assumedRoleUser \ "Arn").text.trim - - // Parse expiration time - val expiration = Instant.parse(expirationStr) - - // Validate required fields - if (accessKeyId.isEmpty) throw new StsException("AccessKeyId not found in STS response") - if (secretAccessKey.isEmpty) throw new StsException("SecretAccessKey not found in STS response") - if (sessionToken.isEmpty) throw new StsException("SessionToken not found in STS response") - if (assumedRoleArn.isEmpty) throw new StsException("AssumedRoleArn not found in STS response") - - AssumeRoleWithWebIdentityResponse( - accessKeyId = accessKeyId, - secretAccessKey = secretAccessKey, - sessionToken = sessionToken, - expiration = expiration, - assumedRoleArn = assumedRoleArn - ) - }.adaptError { case ex => - new StsException(s"Failed to parse STS response: ${ex.getMessage}") - } - + MonadThrow[F] + .catchNonFatal { + val xml = XML.loadString(xmlBody) + + // Extract credentials + val credentials = (xml \ "AssumeRoleWithWebIdentityResult" \ "Credentials").head + val accessKeyId = (credentials \ "AccessKeyId").text.trim + val secretAccessKey = (credentials \ "SecretAccessKey").text.trim + val sessionToken = (credentials \ "SessionToken").text.trim + val expirationStr = (credentials \ "Expiration").text.trim + + // Extract assumed role information + val assumedRoleUser = (xml \ "AssumeRoleWithWebIdentityResult" \ "AssumedRoleUser").head + val assumedRoleArn = (assumedRoleUser \ "Arn").text.trim + + // Parse expiration time + val expiration = Instant.parse(expirationStr) + + // Validate required fields + if accessKeyId.isEmpty then throw new StsException("AccessKeyId not found in STS response") + if secretAccessKey.isEmpty then throw new StsException("SecretAccessKey not found in STS response") + if sessionToken.isEmpty then throw new StsException("SessionToken not found in STS response") + if assumedRoleArn.isEmpty then throw new StsException("AssumedRoleArn not found in STS response") + + AssumeRoleWithWebIdentityResponse( + accessKeyId = accessKeyId, + secretAccessKey = secretAccessKey, + sessionToken = sessionToken, + expiration = expiration, + assumedRoleArn = assumedRoleArn + ) + } + .adaptError { + case ex => + new StsException(s"Failed to parse STS response: ${ ex.getMessage }") + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala index 3beeff31d..a5e175d1c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala @@ -24,12 +24,12 @@ package ldbc.amazon.exception * @param cause The underlying cause of the exception (optional) */ class InvalidTokenException( - message: String, - cause: Option[Throwable] = None + message: String, + cause: Option[Throwable] = None ) extends WebIdentityTokenException(message, cause): /** * Constructor with cause */ def this(message: String, cause: Throwable) = - this(message, Some(cause)) \ No newline at end of file + this(message, Some(cause)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala index bb4481524..162024147 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala @@ -32,9 +32,10 @@ import scala.util.control.NoStackTrace * @param cause The underlying cause of the exception (optional) */ class StsException( - message: String, - cause: Option[Throwable] = None -) extends SdkClientException(message) with NoStackTrace: + message: String, + cause: Option[Throwable] = None +) extends SdkClientException(message) + with NoStackTrace: // Set the cause if provided cause.foreach(initCause) @@ -43,4 +44,4 @@ class StsException( * Constructor with cause */ def this(message: String, cause: Throwable) = - this(message, Some(cause)) \ No newline at end of file + this(message, Some(cause)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala index 54b8c4416..a40802e20 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala @@ -23,12 +23,12 @@ package ldbc.amazon.exception * @param cause The underlying cause of the exception (optional) */ class TokenFileAccessException( - message: String, - cause: Option[Throwable] = None + message: String, + cause: Option[Throwable] = None ) extends WebIdentityTokenException(message, cause): /** * Constructor with cause */ def this(message: String, cause: Throwable) = - this(message, Some(cause)) \ No newline at end of file + this(message, Some(cause)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala index ea57c510b..dc4ae761f 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala @@ -24,9 +24,9 @@ package ldbc.amazon.exception * @param cause The underlying cause of the exception (optional) */ class TokenFileNotFoundException( - message: String, - tokenFilePath: Option[String] = None, - cause: Option[Throwable] = None + message: String, + tokenFilePath: Option[String] = None, + cause: Option[Throwable] = None ) extends WebIdentityTokenException(message, cause): /** @@ -50,4 +50,4 @@ class TokenFileNotFoundException( override def getMessage: String = tokenFilePath match case Some(path) => s"$message (Token file path: $path)" - case None => message \ No newline at end of file + case None => message diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala index fcb41fdb4..99787ea2f 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala @@ -16,8 +16,9 @@ import scala.util.control.NoStackTrace */ abstract class WebIdentityTokenException( message: String, - cause: Option[Throwable] = None -) extends SdkClientException(message) with NoStackTrace: + cause: Option[Throwable] = None +) extends SdkClientException(message) + with NoStackTrace: // Set the cause if provided cause.foreach(initCause) @@ -26,4 +27,4 @@ abstract class WebIdentityTokenException( * Constructor with cause */ def this(message: String, cause: Throwable) = - this(message, Some(cause)) \ No newline at end of file + this(message, Some(cause)) From 3d4ad8d25d4e9a13a7a5260ea1455ae28c8d5728 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 01:11:45 +0900 Subject: [PATCH 090/215] Action sbt scalafmtSbt --- build.sbt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/build.sbt b/build.sbt index fa8cdc043..6de6ac9bb 100644 --- a/build.sbt +++ b/build.sbt @@ -148,10 +148,10 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP .module("aws-authentication-plugin", "") .settings( libraryDependencies ++= Seq( - "org.scala-lang.modules" %%% "scala-xml" % "2.2.0", - "co.fs2" %%% "fs2-core" % "3.12.2", - "co.fs2" %%% "fs2-io" % "3.12.2", - "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test + "org.scala-lang.modules" %%% "scala-xml" % "2.2.0", + "co.fs2" %%% "fs2-core" % "3.12.2", + "co.fs2" %%% "fs2-io" % "3.12.2", + "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test ) ) .jsSettings( From 620f1c51592eb0dbf84984e162bc197e7bde64b9 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 21:27:03 +0900 Subject: [PATCH 091/215] Change js compile error clock function --- .../amazon/auth/token/RdsIamAuthTokenGenerator.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 14fd9cd9f..9231da0d7 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -11,6 +11,8 @@ import java.nio.charset.StandardCharsets import java.time.{ Instant, ZoneOffset } import java.time.format.DateTimeFormatter +import scala.concurrent.duration.FiniteDuration + import cats.syntax.all.* import cats.effect.kernel.{ Clock, Sync } @@ -95,7 +97,7 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( */ override def generateToken(credentials: AwsCredentials): F[String] = for - now <- clock.realTimeInstant + now <- clock.realTime dateTime = formatDateTime(now) date = dateTime.substring(0, 8) credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" @@ -109,19 +111,19 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( yield s"$hostname:$port/?$queryParams&X-Amz-Signature=$signature" /** - * Formats an Instant to ISO 8601 basic format string required by AWS Signature Version 4. + * Formats an FiniteDuration to ISO 8601 basic format string required by AWS Signature Version 4. * * Converts a timestamp to the format "yyyyMMddTHHmmssZ" in UTC timezone, * which is required for AWS authentication requests. * - * @param instant The timestamp to format + * @param duration current time is used regardless of this value * @return Formatted datetime string in AWS SigV4 format (e.g., "20230101T120000Z") */ - private def formatDateTime(instant: Instant): String = + private def formatDateTime(duration: FiniteDuration): String = DateTimeFormatter .ofPattern("yyyyMMdd'T'HHmmss'Z'") .withZone(ZoneOffset.UTC) - .format(instant) + .format(Instant.now().plusMillis(duration.toMillis)) /** * Builds the query parameters for the RDS authentication request. From 241d4a40617efd8319c425d1eff38d472f4f2e60 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 21:27:28 +0900 Subject: [PATCH 092/215] Action sbt scalafmtAll --- .../DefaultCredentialsProviderChain.scala | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala index 7f1ac9527..733239489 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -41,7 +41,7 @@ import ldbc.amazon.identity.* * * @tparam F The effect type */ -class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent] +class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent] extends AwsCredentialsProvider[F]: private lazy val providers: F[List[AwsCredentialsProvider[F]]] = @@ -54,7 +54,7 @@ class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Networ webIdentityProvider, profileProvider // ContainerCredentialsProvider - TODO: implement - // InstanceProfileCredentialsProvider - TODO: implement + // InstanceProfileCredentialsProvider - TODO: implement ) override def resolveCredentials(): F[AwsCredentials] = @@ -64,7 +64,7 @@ class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Networ yield credentials private def tryProvidersInOrder( - providers: List[AwsCredentialsProvider[F]], + providers: List[AwsCredentialsProvider[F]], exceptions: List[String] ): F[AwsCredentials] = providers match @@ -82,20 +82,20 @@ class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Networ object DefaultCredentialsProviderChain: /** - * Creates a new default credentials provider chain. - * - * @tparam F The effect type - * @return A new DefaultCredentialsProviderChain instance - */ + * Creates a new default credentials provider chain. + * + * @tparam F The effect type + * @return A new DefaultCredentialsProviderChain instance + */ def apply[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent](): DefaultCredentialsProviderChain[F] = new DefaultCredentialsProviderChain[F]() /** - * Convenience method to resolve credentials using the default chain. - * - * @tparam F The effect type - * @return AWS credentials from the first successful provider - */ + * Convenience method to resolve credentials using the default chain. + * + * @tparam F The effect type + * @return AWS credentials from the first successful provider + */ def resolveCredentials[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent](): F[AwsCredentials] = DefaultCredentialsProviderChain[F]().resolveCredentials() - */ \ No newline at end of file + */ From b7cc3be6aeed330fe5be7cb3290788da876d0f2f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 21:28:32 +0900 Subject: [PATCH 093/215] Change scaladoc --- .../scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 9231da0d7..e7ffe65cd 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -116,7 +116,7 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( * Converts a timestamp to the format "yyyyMMddTHHmmssZ" in UTC timezone, * which is required for AWS authentication requests. * - * @param duration current time is used regardless of this value + * @param duration Duration to add to the current time * @return Formatted datetime string in AWS SigV4 format (e.g., "20230101T120000Z") */ private def formatDateTime(duration: FiniteDuration): String = From da95e23b14ea409ea3ed5ed8004002793f28ab67 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 22:02:20 +0900 Subject: [PATCH 094/215] Create ContainerCredentialsProvider --- build.sbt | 1 + .../ContainerCredentialsProvider.scala | 260 ++++++++++++++++++ .../useragent/BusinessMetricFeatureId.scala | 1 + 3 files changed, 262 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala diff --git a/build.sbt b/build.sbt index 6de6ac9bb..590b15a2f 100644 --- a/build.sbt +++ b/build.sbt @@ -151,6 +151,7 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP "org.scala-lang.modules" %%% "scala-xml" % "2.2.0", "co.fs2" %%% "fs2-core" % "3.12.2", "co.fs2" %%% "fs2-io" % "3.12.2", + "io.circe" %% "circe-parser" % "0.14.10", "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test ) ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala new file mode 100644 index 000000000..ba3f267eb --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -0,0 +1,260 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import scala.concurrent.duration.* + +import cats.effect.std.Env +import cats.effect.{Async, Concurrent} +import cats.syntax.all.* + +import fs2.io.file.{Files, Path} +import fs2.io.net.* + +import io.circe.* +import io.circe.parser.* + +import ldbc.amazon.client.{HttpClient, SimpleHttpClient} +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* +import ldbc.amazon.useragent.BusinessMetricFeatureId + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from AWS Container Credential Provider. + * + * This provider is used in containerized environments such as: + * - Amazon ECS with IAM Task Roles + * - Amazon EKS with IAM Roles for Service Accounts (IRSA) + * - Amazon EKS with Pod Identity + * + * The provider reads configuration from environment variables: + * - AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: Relative URI for ECS metadata endpoint + * - AWS_CONTAINER_CREDENTIALS_FULL_URI: Full URI for credential endpoint + * - AWS_CONTAINER_AUTHORIZATION_TOKEN: Authorization token for requests + * - AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE: Path to file containing authorization token + * + * Authentication flow: + * 1. Check environment variables for endpoint configuration + * 2. Load authorization token from direct env var or file + * 3. Make HTTP GET request to credential endpoint with Authorization header + * 4. Parse JSON response to extract temporary credentials + * + * ECS endpoint: http://169.254.170.2/ + * EKS Pod Identity endpoint: http://169.254.170.23/v1/credentials + * + * @param httpClient HTTP client for making credential requests + * @tparam F The effect type + */ +final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( + httpClient: HttpClient[F] +) extends AwsCredentialsProvider[F]: + + override def resolveCredentials(): F[AwsCredentials] = + for + config <- loadContainerCredentialsConfig() + credentials <- config match { + case None => + Concurrent[F].raiseError(new SdkClientException( + "Unable to load container credentials. " + + "Environment variables AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or " + + "AWS_CONTAINER_CREDENTIALS_FULL_URI are not set." + )) + case Some(containerConfig) => + fetchCredentialsFromEndpoint(containerConfig) + } + yield credentials + + private def loadContainerCredentialsConfig(): F[Option[ContainerCredentialsConfig]] = + for + relativeUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") + directToken <- Env[F].get("AWS_CONTAINER_AUTHORIZATION_TOKEN") + tokenFile <- Env[F].get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") + token <- loadAuthorizationToken(directToken, tokenFile) + yield (relativeUri, fullUri) match { + case (Some(relative), _) => + Some(ContainerCredentialsConfig( + endpointUri = s"http://169.254.170.2$relative", + authorizationToken = token + )) + case (_, Some(full)) => + Some(ContainerCredentialsConfig( + endpointUri = full, + authorizationToken = token + )) + case _ => None + } + + private def loadAuthorizationToken( + directToken: Option[String], + tokenFilePath: Option[String] + ): F[Option[String]] = + (directToken, tokenFilePath) match { + case (Some(token), _) => + Concurrent[F].pure(Some(token.trim).filter(_.nonEmpty)) + case (_, Some(filePath)) => + loadTokenFromFile(Path(filePath)) + case _ => + Concurrent[F].pure(None) + } + + private def loadTokenFromFile(tokenFilePath: Path): F[Option[String]] = + Files[F].exists(tokenFilePath).flatMap { exists => + if (exists) { + Files[F].readUtf8(tokenFilePath).compile.string + .map(_.trim) + .map(token => if (token.nonEmpty) Some(token) else None) + .handleErrorWith { _ => + Concurrent[F].pure(None) + } + } else { + Concurrent[F].pure(None) + } + } + + private def fetchCredentialsFromEndpoint(config: ContainerCredentialsConfig): F[AwsCredentials] = + val headers = buildRequestHeaders(config.authorizationToken) + for + response <- httpClient.get(URI.create(config.endpointUri), headers) + _ <- validateHttpResponse(response) + credentials <- parseCredentialsResponse(response.body) + yield credentials + + private def buildRequestHeaders(authToken: Option[String]): Map[String, String] = + val baseHeaders = Map( + "Accept" -> "application/json", + "User-Agent" -> "aws-sdk-scala/ldbc" + ) + authToken match { + case Some(token) => baseHeaders + ("Authorization" -> token) + case None => baseHeaders + } + + private def validateHttpResponse(response: ldbc.amazon.client.HttpResponse): F[Unit] = + if (response.statusCode >= 200 && response.statusCode < 300) { + Concurrent[F].unit + } else { + Concurrent[F].raiseError(new SdkClientException( + s"Container credentials request failed with status ${response.statusCode}: ${response.body}" + )) + } + + private def parseCredentialsResponse(jsonBody: String): F[AwsCredentials] = + Concurrent[F].fromEither( + parse(jsonBody).flatMap(_.as[ContainerCredentialsResponse]) + ).map { response => + AwsSessionCredentials( + accessKeyId = response.AccessKeyId, + secretAccessKey = response.SecretAccessKey, + sessionToken = response.Token, + validateCredentials = false, + providerName = Some(BusinessMetricFeatureId.CREDENTIALS_CONTAINER.code), + accountId = extractAccountIdFromRoleArn(response.RoleArn), + expirationTime = Some(Instant.parse(response.Expiration)) + ) + }.adaptError { case ex => + new SdkClientException(s"Failed to parse container credentials response: ${ex.getMessage}") + } + + private def extractAccountIdFromRoleArn(roleArn: Option[String]): Option[String] = + roleArn.flatMap { arn => + // ARN format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME + val arnParts = arn.split(":") + if (arnParts.length >= 5) { + Some(arnParts(4)) + } else { + None + } + } + +/** + * Configuration for container credentials endpoint. + * + * @param endpointUri The full URI to the credential endpoint + * @param authorizationToken Optional authorization token for requests + */ +private case class ContainerCredentialsConfig( + endpointUri: String, + authorizationToken: Option[String] +) + +/** + * JSON response from container credentials endpoint. + * + * @param AccessKeyId AWS access key ID + * @param SecretAccessKey AWS secret access key + * @param Token Session token for temporary credentials + * @param Expiration RFC3339 formatted expiration timestamp + * @param RoleArn Optional ARN of the assumed role + */ +private case class ContainerCredentialsResponse( + AccessKeyId: String, + SecretAccessKey: String, + Token: String, + Expiration: String, + RoleArn: Option[String] = None +) + +private object ContainerCredentialsResponse: + given Decoder[ContainerCredentialsResponse] = Decoder.forProduct5( + "AccessKeyId", + "SecretAccessKey", + "Token", + "Expiration", + "RoleArn" + )(ContainerCredentialsResponse.apply) + +object ContainerCredentialsProvider: + + /** + * Creates a new Container credentials provider with default settings. + * + * @tparam F The effect type + * @return A new ContainerCredentialsProvider instance + */ + def apply[F[_]: Files: Env: Network: Async](): F[ContainerCredentialsProvider[F]] = + createDefaultHttpClient[F]().map(httpClient => new ContainerCredentialsProvider[F](httpClient)) + + /** + * Creates a new Container credentials provider with custom HTTP client. + * + * @param httpClient The HTTP client for credential requests + * @tparam F The effect type + * @return A new ContainerCredentialsProvider instance + */ + def create[F[_]: Files: Env: Concurrent]( + httpClient: HttpClient[F] + ): ContainerCredentialsProvider[F] = + new ContainerCredentialsProvider[F](httpClient) + + /** + * Creates a default HTTP client for container credential operations. + * + * @tparam F The effect type + * @return A configured HTTP client + */ + private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = + Async[F].pure(new SimpleHttpClient[F]( + connectTimeout = 5.seconds, + readTimeout = 5.seconds + )) + + /** + * Checks if Container credentials are available by verifying + * the presence of required environment variables. + * + * @tparam F The effect type + * @return true if Container credentials are properly configured + */ + def isAvailable[F[_]: Env: Concurrent](): F[Boolean] = + for + relativeUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") + yield relativeUri.exists(_.trim.nonEmpty) || fullUri.exists(_.trim.nonEmpty) \ No newline at end of file diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala index 14768d02e..52c06a615 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -60,5 +60,6 @@ enum BusinessMetricFeatureId(val code: String): case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") + case CREDENTIALS_CONTAINER extends BusinessMetricFeatureId("1") case CREDENTIALS_WEB_IDENTITY_TOKEN extends BusinessMetricFeatureId("k") case UNKNOWN extends BusinessMetricFeatureId("Unknown") From a7180b5948842d1eede31b730f4cba3e8f4b31bc Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 22:25:51 +0900 Subject: [PATCH 095/215] Added put method --- .../scala/ldbc/amazon/client/HttpClient.scala | 2 ++ .../ldbc/amazon/client/SimpleHttpClient.scala | 36 ++++++++++++++----- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala index d9e48f52e..3a6140dc7 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala @@ -11,3 +11,5 @@ import java.net.URI trait HttpClient[F[_]]: def get(uri: URI, headers: Map[String, String]): F[HttpResponse] + + def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 2ac785d76..9c36e75c6 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -37,7 +37,18 @@ class SimpleHttpClient[F[_]: Network: Async]( for address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, path, headers) + response <- makeRequest(address, host, port, "GET", path, headers, None) + yield response + + override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + val host = uri.getHost + val port = if uri.getPort > 0 then uri.getPort else 80 + val path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + + for + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, "PUT", path, headers, Some(body)) yield response private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = @@ -48,20 +59,27 @@ class SimpleHttpClient[F[_]: Network: Async]( private def sendRequest( socket: Socket[F], + method: String, host: String, port: Int, path: String, - headers: Map[String, String] + headers: Map[String, String], + body: Option[String] ): F[Unit] = val hostHeader = if port == 80 then host else s"$host:$port" - val allHeaders = headers + ("Host" -> hostHeader) + ("Connection" -> "close") + val contentHeaders = body match { + case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) + case None => Map.empty + } + val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") - val requestLine = s"GET $path HTTP/1.1\r\n" + val requestLine = s"$method $path HTTP/1.1\r\n" val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString - val request = requestLine + headerLines + "\r\n" + val requestWithHeaders = requestLine + headerLines + "\r\n" + val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) Stream - .emit(request) + .emit(fullRequest) .through(text.utf8.encode) .through(socket.writes) .compile @@ -106,14 +124,16 @@ class SimpleHttpClient[F[_]: Network: Async]( address: SocketAddress[Host], host: String, port: Int, + method: String, path: String, - headers: Map[String, String] + headers: Map[String, String], + body: Option[String] ): F[HttpResponse] = Network[F] .client(address) .use { socket => for - _ <- sendRequest(socket, host, port, path, headers) + _ <- sendRequest(socket, method, host, port, path, headers, body) response <- receiveResponse(socket) yield response } From cf677f7c35bea70399f7bd74d88e39764ea8f040 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 22:26:05 +0900 Subject: [PATCH 096/215] Create InstanceProfileCredentialsProvider --- .../InstanceProfileCredentialsProvider.scala | 311 ++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala new file mode 100644 index 000000000..bcb0e4d8e --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -0,0 +1,311 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import scala.concurrent.duration.* + +import cats.effect.std.Env +import cats.effect.{Async, Concurrent, Ref} +import cats.syntax.all.* + +import fs2.io.net.* + +import io.circe.* +import io.circe.parser.* + +import ldbc.amazon.client.{HttpClient, SimpleHttpClient} +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* +import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.client.* + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from EC2 Instance Metadata Service (IMDS). + * + * This provider is used on Amazon EC2 instances with attached IAM instance profiles. + * It automatically retrieves temporary credentials from the EC2 metadata service at: + * http://169.254.169.254/latest/meta-data/iam/security-credentials/ + * + * The provider supports both IMDSv1 (legacy) and IMDSv2 (recommended) protocols: + * - IMDSv2: Uses session tokens for enhanced security against SSRF attacks + * - IMDSv1: Fallback mode for backward compatibility + * + * Environment variables: + * - AWS_EC2_METADATA_DISABLED: Set to 'true' to disable IMDS credential provider + * - AWS_EC2_METADATA_SERVICE_ENDPOINT: Override default IMDS endpoint (for testing) + * + * Authentication flow: + * 1. Optionally acquire IMDSv2 session token (PUT request with TTL header) + * 2. List available IAM roles from metadata service + * 3. Retrieve credentials for the first available role + * 4. Cache credentials until 4 minutes before expiration + * 5. Automatically refresh credentials in the background + * + * @param httpClient HTTP client for metadata service requests + * @param credentialsRef Mutable reference for credential caching + * @tparam F The effect type + */ +final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( + httpClient: HttpClient[F], + credentialsRef: Ref[F, Option[CachedCredentials]] +) extends AwsCredentialsProvider[F]: + + private val DEFAULT_IMD_SEND_POINT = "http://169.254.169.254" + private val METADATA_TOKEN_TTL_SECONDS = 21600 // 6 hours + private val CREDENTIAL_REFRESH_BUFFER = 4.minutes + + override def resolveCredentials(): F[AwsCredentials] = + for + disabled <- checkIfDisabled() + _ <- Concurrent[F].raiseWhen(disabled)( + new SdkClientException("EC2 metadata service is disabled via AWS_EC2_METADATA_DISABLED") + ) + cached <- credentialsRef.get + credentials <- cached match { + case Some(creds) if !isExpiringSoon(creds) => + Concurrent[F].pure(creds.credentials) + case _ => + refreshCredentials() + } + yield credentials + + private def checkIfDisabled(): F[Boolean] = + Env[F].get("AWS_EC2_METADATA_DISABLED").map { + case Some(value) => value.toLowerCase == "true" + case None => false + } + + private def refreshCredentials(): F[AwsCredentials] = + for + endpoint <- getImdsEndpoint() + token <- acquireMetadataToken(endpoint).attempt.map(_.toOption) + roleName <- getRoleName(endpoint, token) + credentials <- getCredentialsForRole(endpoint, token, roleName) + cached = CachedCredentials(credentials, Instant.now()) + _ <- credentialsRef.set(Some(cached)) + yield credentials + + private def getImdsEndpoint(): F[String] = + Env[F].get("AWS_EC2_METADATA_SERVICE_ENDPOINT").map { + case Some(endpoint) => endpoint.stripSuffix("/") + case None => DEFAULT_IMD_SEND_POINT + } + + private def acquireMetadataToken(endpoint: String): F[String] = + val tokenUrl = s"$endpoint/latest/api/token" + val headers = Map( + "X-aws-ec2-metadata-token-ttl-seconds" -> METADATA_TOKEN_TTL_SECONDS.toString + ) + + for + response <- httpClient.put(URI.create(tokenUrl), headers, "") + _ <- validateHttpResponse(response, "Failed to acquire metadata token") + yield response.body.trim + + private def getRoleName(endpoint: String, token: Option[String]): F[String] = + val roleUrl = s"$endpoint/latest/meta-data/iam/security-credentials/" + val headers = buildRequestHeaders(token) + + for + response <- httpClient.get(URI.create(roleUrl), headers) + _ <- validateHttpResponse(response, "Failed to list IAM roles") + roleName <- parseRoleListResponse(response.body) + yield roleName + + private def getCredentialsForRole( + endpoint: String, + token: Option[String], + roleName: String + ): F[AwsCredentials] = + val credentialsUrl = s"$endpoint/latest/meta-data/iam/security-credentials/$roleName" + val headers = buildRequestHeaders(token) + + for + response <- httpClient.get(URI.create(credentialsUrl), headers) + _ <- validateHttpResponse(response, s"Failed to retrieve credentials for role $roleName") + credentials <- parseCredentialsResponse(response.body, roleName) + yield credentials + + private def buildRequestHeaders(token: Option[String]): Map[String, String] = + val baseHeaders = Map( + "Accept" -> "application/json", + "User-Agent" -> "aws-sdk-scala/ldbc" + ) + token match { + case Some(t) => baseHeaders + ("X-aws-ec2-metadata-token" -> t) + case None => baseHeaders + } + + private def validateHttpResponse(response: HttpResponse, context: String): F[Unit] = + response.statusCode match { + case code if code >= 200 && code < 300 => Concurrent[F].unit + case 401 => Concurrent[F].raiseError( + new SdkClientException(s"$context: Unauthorized (401) - Invalid or expired metadata token") + ) + case 403 => Concurrent[F].raiseError( + new SdkClientException(s"$context: Forbidden (403) - No instance profile attached") + ) + case 404 => Concurrent[F].raiseError( + new SdkClientException(s"$context: Not Found (404) - Instance metadata not available") + ) + case code => Concurrent[F].raiseError( + new SdkClientException(s"$context: HTTP $code - ${response.body}") + ) + } + + private def parseRoleListResponse(body: String): F[String] = + val roles = body.trim.split('\n').map(_.trim).filter(_.nonEmpty) + roles.headOption match { + case Some(roleName) => Concurrent[F].pure(roleName) + case None => Concurrent[F].raiseError( + new SdkClientException("No IAM roles found in instance metadata") + ) + } + + private def parseCredentialsResponse(jsonBody: String, roleName: String): F[AwsCredentials] = + Concurrent[F].fromEither( + parse(jsonBody).flatMap(_.as[InstanceMetadataCredentialsResponse]) + ).flatMap { response => + if (response.Code == "Success") { + Concurrent[F].pure(AwsSessionCredentials( + accessKeyId = response.AccessKeyId, + secretAccessKey = response.SecretAccessKey, + sessionToken = response.Token, + validateCredentials = false, + providerName = Some(BusinessMetricFeatureId.CREDENTIALS_IMDS.code), + accountId = extractAccountIdFromArn(response.AccessKeyId), + expirationTime = Some(Instant.parse(response.Expiration)) + )) + } else { + Concurrent[F].raiseError( + new SdkClientException(s"Failed to retrieve credentials for role $roleName: ${response.Code}") + ) + } + }.adaptError { case ex => + new SdkClientException(s"Failed to parse instance metadata credentials response: ${ex.getMessage}") + } + + private def extractAccountIdFromArn(accessKeyId: String): Option[String] = + // For instance profile credentials, we don't have the account ID directly + // The access key ID pattern is AKIA... for long-term or ASIA... for temporary credentials + None + + private def isExpiringSoon(cached: CachedCredentials): Boolean = + cached.credentials.expirationTime match { + case Some(expiration) => + val now = Instant.now() + val bufferTime = expiration.minusSeconds(CREDENTIAL_REFRESH_BUFFER.toSeconds) + now.isAfter(bufferTime) + case None => false + } + +/** + * Cached credentials with retrieval timestamp. + */ +private case class CachedCredentials( + credentials: AwsCredentials, + retrievedAt: Instant +) + +/** + * JSON response from EC2 instance metadata service. + */ +private case class InstanceMetadataCredentialsResponse( + Code: String, + LastUpdated: String, + Type: String, + AccessKeyId: String, + SecretAccessKey: String, + Token: String, + Expiration: String +) + +private object InstanceMetadataCredentialsResponse: + given Decoder[InstanceMetadataCredentialsResponse] = Decoder.forProduct7( + "Code", + "LastUpdated", + "Type", + "AccessKeyId", + "SecretAccessKey", + "Token", + "Expiration" + )(InstanceMetadataCredentialsResponse.apply) + +object InstanceProfileCredentialsProvider: + + /** + * Creates a new Instance Profile credentials provider with default settings. + * + * @tparam F The effect type + * @return A new InstanceProfileCredentialsProvider instance + */ + def apply[F[_]: Env: Network: Async](): F[InstanceProfileCredentialsProvider[F]] = + for + httpClient <- createDefaultHttpClient[F]() + credentialsRef <- Ref.of[F, Option[CachedCredentials]](None) + yield new InstanceProfileCredentialsProvider[F](httpClient, credentialsRef) + + /** + * Creates a new Instance Profile credentials provider with custom HTTP client. + * + * @param httpClient The HTTP client for metadata service requests + * @tparam F The effect type + * @return A new InstanceProfileCredentialsProvider instance + */ + def create[F[_]: Env: Concurrent]( + httpClient: HttpClient[F] + ): F[InstanceProfileCredentialsProvider[F]] = + Ref.of[F, Option[CachedCredentials]](None).map { credentialsRef => + new InstanceProfileCredentialsProvider[F](httpClient, credentialsRef) + } + + /** + * Creates a default HTTP client optimized for EC2 metadata service. + * + * @tparam F The effect type + * @return A configured HTTP client + */ + private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = + Async[F].pure(new SimpleHttpClient[F]( + connectTimeout = 2.seconds, + readTimeout = 5.seconds + )) + + /** + * Checks if Instance Profile credentials are available by attempting + * to connect to the EC2 metadata service. + * + * @tparam F The effect type + * @return true if the metadata service is reachable + */ + def isAvailable[F[_]: Env: Network: Async](): F[Boolean] = + for + disabled <- Env[F].get("AWS_EC2_METADATA_DISABLED").map { + case Some(value) => value.toLowerCase == "true" + case None => false + } + available <- if (disabled) { + Async[F].pure(false) + } else { + checkMetadataServiceAvailability[F]() + } + yield available + + private def checkMetadataServiceAvailability[F[_]: Network: Async](): F[Boolean] = + val httpClient = new SimpleHttpClient[F]( + connectTimeout = 1.second, + readTimeout = 2.seconds + ) + + httpClient.get( + URI.create("http://169.254.169.254/latest/meta-data/"), + Map.empty + ).map(_.statusCode == 200) + .handleErrorWith(_ => Async[F].pure(false)) \ No newline at end of file From e1bc8d2af6417ef1c83113af1c60e308290529bf Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 22:33:20 +0900 Subject: [PATCH 097/215] Create DefaultCredentialsProviderChain --- .../DefaultCredentialsProviderChain.scala | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala index 733239489..b9a78e492 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -4,18 +4,20 @@ * For more information see LICENSE or https://opensource.org/licenses/MIT */ -/* package ldbc.amazon.auth.credentials -import cats.MonadThrow +import scala.concurrent.duration.* + import cats.effect.std.{Env, SystemProperties} -import cats.effect.{Concurrent, Network} +import cats.effect.* import cats.syntax.all.* import fs2.io.file.Files +import fs2.io.net.* import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* +import ldbc.amazon.client.* /** * Default AWS credentials provider chain that matches AWS SDK v2 behavior. @@ -41,20 +43,20 @@ import ldbc.amazon.identity.* * * @tparam F The effect type */ -class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent] +class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Concurrent](httpClient: HttpClient[F], region: String) extends AwsCredentialsProvider[F]: private lazy val providers: F[List[AwsCredentialsProvider[F]]] = for - webIdentityProvider <- WebIdentityTokenFileCredentialsProvider[F]() profileProvider <- ProfileCredentialsProvider.default[F]() + instanceProfileCredentialsProvider <- InstanceProfileCredentialsProvider.create[F](httpClient) yield List( new SystemPropertyCredentialsProvider[F](), new EnvironmentVariableCredentialsProvider[F](), - webIdentityProvider, - profileProvider - // ContainerCredentialsProvider - TODO: implement - // InstanceProfileCredentialsProvider - TODO: implement + WebIdentityTokenFileCredentialsProvider.default[F](httpClient, region), + profileProvider, + ContainerCredentialsProvider.create[F](httpClient), + instanceProfileCredentialsProvider ) override def resolveCredentials(): F[AwsCredentials] = @@ -69,7 +71,7 @@ class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Networ ): F[AwsCredentials] = providers match case Nil => - MonadThrow[F].raiseError(new SdkClientException( + Concurrent[F].raiseError(new SdkClientException( s"Unable to load AWS credentials from any provider in the chain: ${exceptions.mkString(", ")}" )) @@ -87,15 +89,9 @@ object DefaultCredentialsProviderChain: * @tparam F The effect type * @return A new DefaultCredentialsProviderChain instance */ - def apply[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent](): DefaultCredentialsProviderChain[F] = - new DefaultCredentialsProviderChain[F]() - - /** - * Convenience method to resolve credentials using the default chain. - * - * @tparam F The effect type - * @return AWS credentials from the first successful provider - */ - def resolveCredentials[F[_]: Files: Env: SystemProperties: Network: MonadThrow: Concurrent](): F[AwsCredentials] = - DefaultCredentialsProviderChain[F]().resolveCredentials() - */ + def default[F[_]: Files: Env: SystemProperties: Network: Async](region: String): DefaultCredentialsProviderChain[F] = + val httpClient = new SimpleHttpClient[F]( + connectTimeout = 1.second, + readTimeout = 2.seconds + ) + new DefaultCredentialsProviderChain[F](httpClient, region) From e519b97cd08b870fda4e3c329560e7abf0da9533 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 22:34:07 +0900 Subject: [PATCH 098/215] Action sbt scalafmtAll --- .../ContainerCredentialsProvider.scala | 150 +++++++------ .../DefaultCredentialsProviderChain.scala | 33 +-- .../InstanceProfileCredentialsProvider.scala | 210 ++++++++++-------- .../ldbc/amazon/client/SimpleHttpClient.scala | 10 +- 4 files changed, 221 insertions(+), 182 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala index ba3f267eb..bcdea91dd 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -11,17 +11,18 @@ import java.time.Instant import scala.concurrent.duration.* -import cats.effect.std.Env -import cats.effect.{Async, Concurrent} import cats.syntax.all.* -import fs2.io.file.{Files, Path} +import cats.effect.{ Async, Concurrent } +import cats.effect.std.Env + +import fs2.io.file.{ Files, Path } import fs2.io.net.* import io.circe.* import io.circe.parser.* -import ldbc.amazon.client.{HttpClient, SimpleHttpClient} +import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.useragent.BusinessMetricFeatureId @@ -58,59 +59,68 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( override def resolveCredentials(): F[AwsCredentials] = for - config <- loadContainerCredentialsConfig() + config <- loadContainerCredentialsConfig() credentials <- config match { - case None => - Concurrent[F].raiseError(new SdkClientException( - "Unable to load container credentials. " + - "Environment variables AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or " + - "AWS_CONTAINER_CREDENTIALS_FULL_URI are not set." - )) - case Some(containerConfig) => - fetchCredentialsFromEndpoint(containerConfig) - } + case None => + Concurrent[F].raiseError( + new SdkClientException( + "Unable to load container credentials. " + + "Environment variables AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or " + + "AWS_CONTAINER_CREDENTIALS_FULL_URI are not set." + ) + ) + case Some(containerConfig) => + fetchCredentialsFromEndpoint(containerConfig) + } yield credentials private def loadContainerCredentialsConfig(): F[Option[ContainerCredentialsConfig]] = for relativeUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") - fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") + fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") directToken <- Env[F].get("AWS_CONTAINER_AUTHORIZATION_TOKEN") - tokenFile <- Env[F].get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") - token <- loadAuthorizationToken(directToken, tokenFile) + tokenFile <- Env[F].get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") + token <- loadAuthorizationToken(directToken, tokenFile) yield (relativeUri, fullUri) match { case (Some(relative), _) => - Some(ContainerCredentialsConfig( - endpointUri = s"http://169.254.170.2$relative", - authorizationToken = token - )) + Some( + ContainerCredentialsConfig( + endpointUri = s"http://169.254.170.2$relative", + authorizationToken = token + ) + ) case (_, Some(full)) => - Some(ContainerCredentialsConfig( - endpointUri = full, - authorizationToken = token - )) + Some( + ContainerCredentialsConfig( + endpointUri = full, + authorizationToken = token + ) + ) case _ => None } private def loadAuthorizationToken( - directToken: Option[String], + directToken: Option[String], tokenFilePath: Option[String] ): F[Option[String]] = (directToken, tokenFilePath) match { - case (Some(token), _) => + case (Some(token), _) => Concurrent[F].pure(Some(token.trim).filter(_.nonEmpty)) case (_, Some(filePath)) => loadTokenFromFile(Path(filePath)) - case _ => + case _ => Concurrent[F].pure(None) } private def loadTokenFromFile(tokenFilePath: Path): F[Option[String]] = Files[F].exists(tokenFilePath).flatMap { exists => - if (exists) { - Files[F].readUtf8(tokenFilePath).compile.string + if exists then { + Files[F] + .readUtf8(tokenFilePath) + .compile + .string .map(_.trim) - .map(token => if (token.nonEmpty) Some(token) else None) + .map(token => if token.nonEmpty then Some(token) else None) .handleErrorWith { _ => Concurrent[F].pure(None) } @@ -122,52 +132,58 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( private def fetchCredentialsFromEndpoint(config: ContainerCredentialsConfig): F[AwsCredentials] = val headers = buildRequestHeaders(config.authorizationToken) for - response <- httpClient.get(URI.create(config.endpointUri), headers) - _ <- validateHttpResponse(response) + response <- httpClient.get(URI.create(config.endpointUri), headers) + _ <- validateHttpResponse(response) credentials <- parseCredentialsResponse(response.body) yield credentials private def buildRequestHeaders(authToken: Option[String]): Map[String, String] = val baseHeaders = Map( - "Accept" -> "application/json", + "Accept" -> "application/json", "User-Agent" -> "aws-sdk-scala/ldbc" ) authToken match { case Some(token) => baseHeaders + ("Authorization" -> token) - case None => baseHeaders + case None => baseHeaders } private def validateHttpResponse(response: ldbc.amazon.client.HttpResponse): F[Unit] = - if (response.statusCode >= 200 && response.statusCode < 300) { + if response.statusCode >= 200 && response.statusCode < 300 then { Concurrent[F].unit } else { - Concurrent[F].raiseError(new SdkClientException( - s"Container credentials request failed with status ${response.statusCode}: ${response.body}" - )) + Concurrent[F].raiseError( + new SdkClientException( + s"Container credentials request failed with status ${ response.statusCode }: ${ response.body }" + ) + ) } private def parseCredentialsResponse(jsonBody: String): F[AwsCredentials] = - Concurrent[F].fromEither( - parse(jsonBody).flatMap(_.as[ContainerCredentialsResponse]) - ).map { response => - AwsSessionCredentials( - accessKeyId = response.AccessKeyId, - secretAccessKey = response.SecretAccessKey, - sessionToken = response.Token, - validateCredentials = false, - providerName = Some(BusinessMetricFeatureId.CREDENTIALS_CONTAINER.code), - accountId = extractAccountIdFromRoleArn(response.RoleArn), - expirationTime = Some(Instant.parse(response.Expiration)) + Concurrent[F] + .fromEither( + parse(jsonBody).flatMap(_.as[ContainerCredentialsResponse]) ) - }.adaptError { case ex => - new SdkClientException(s"Failed to parse container credentials response: ${ex.getMessage}") - } + .map { response => + AwsSessionCredentials( + accessKeyId = response.AccessKeyId, + secretAccessKey = response.SecretAccessKey, + sessionToken = response.Token, + validateCredentials = false, + providerName = Some(BusinessMetricFeatureId.CREDENTIALS_CONTAINER.code), + accountId = extractAccountIdFromRoleArn(response.RoleArn), + expirationTime = Some(Instant.parse(response.Expiration)) + ) + } + .adaptError { + case ex => + new SdkClientException(s"Failed to parse container credentials response: ${ ex.getMessage }") + } private def extractAccountIdFromRoleArn(roleArn: Option[String]): Option[String] = roleArn.flatMap { arn => // ARN format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME val arnParts = arn.split(":") - if (arnParts.length >= 5) { + if arnParts.length >= 5 then { Some(arnParts(4)) } else { None @@ -181,7 +197,7 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( * @param authorizationToken Optional authorization token for requests */ private case class ContainerCredentialsConfig( - endpointUri: String, + endpointUri: String, authorizationToken: Option[String] ) @@ -195,17 +211,17 @@ private case class ContainerCredentialsConfig( * @param RoleArn Optional ARN of the assumed role */ private case class ContainerCredentialsResponse( - AccessKeyId: String, + AccessKeyId: String, SecretAccessKey: String, - Token: String, - Expiration: String, - RoleArn: Option[String] = None + Token: String, + Expiration: String, + RoleArn: Option[String] = None ) private object ContainerCredentialsResponse: given Decoder[ContainerCredentialsResponse] = Decoder.forProduct5( "AccessKeyId", - "SecretAccessKey", + "SecretAccessKey", "Token", "Expiration", "RoleArn" @@ -241,10 +257,12 @@ object ContainerCredentialsProvider: * @return A configured HTTP client */ private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = - Async[F].pure(new SimpleHttpClient[F]( - connectTimeout = 5.seconds, - readTimeout = 5.seconds - )) + Async[F].pure( + new SimpleHttpClient[F]( + connectTimeout = 5.seconds, + readTimeout = 5.seconds + ) + ) /** * Checks if Container credentials are available by verifying @@ -256,5 +274,5 @@ object ContainerCredentialsProvider: def isAvailable[F[_]: Env: Concurrent](): F[Boolean] = for relativeUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") - fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") - yield relativeUri.exists(_.trim.nonEmpty) || fullUri.exists(_.trim.nonEmpty) \ No newline at end of file + fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") + yield relativeUri.exists(_.trim.nonEmpty) || fullUri.exists(_.trim.nonEmpty) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala index b9a78e492..6e96e057a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -8,16 +8,17 @@ package ldbc.amazon.auth.credentials import scala.concurrent.duration.* -import cats.effect.std.{Env, SystemProperties} -import cats.effect.* import cats.syntax.all.* +import cats.effect.* +import cats.effect.std.{ Env, SystemProperties } + import fs2.io.file.Files import fs2.io.net.* +import ldbc.amazon.client.* import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* -import ldbc.amazon.client.* /** * Default AWS credentials provider chain that matches AWS SDK v2 behavior. @@ -43,12 +44,14 @@ import ldbc.amazon.client.* * * @tparam F The effect type */ -class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Concurrent](httpClient: HttpClient[F], region: String) - extends AwsCredentialsProvider[F]: +class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Concurrent]( + httpClient: HttpClient[F], + region: String +) extends AwsCredentialsProvider[F]: private lazy val providers: F[List[AwsCredentialsProvider[F]]] = for - profileProvider <- ProfileCredentialsProvider.default[F]() + profileProvider <- ProfileCredentialsProvider.default[F]() instanceProfileCredentialsProvider <- InstanceProfileCredentialsProvider.create[F](httpClient) yield List( new SystemPropertyCredentialsProvider[F](), @@ -62,22 +65,24 @@ class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Concur override def resolveCredentials(): F[AwsCredentials] = for providerList <- providers - credentials <- tryProvidersInOrder(providerList, Nil) + credentials <- tryProvidersInOrder(providerList, Nil) yield credentials private def tryProvidersInOrder( - providers: List[AwsCredentialsProvider[F]], + providers: List[AwsCredentialsProvider[F]], exceptions: List[String] ): F[AwsCredentials] = providers match case Nil => - Concurrent[F].raiseError(new SdkClientException( - s"Unable to load AWS credentials from any provider in the chain: ${exceptions.mkString(", ")}" - )) - + Concurrent[F].raiseError( + new SdkClientException( + s"Unable to load AWS credentials from any provider in the chain: ${ exceptions.mkString(", ") }" + ) + ) + case provider :: remainingProviders => provider.resolveCredentials().recoverWith { ex => - val errorMsg = s"${provider.getClass.getSimpleName}: ${ex.getMessage}" + val errorMsg = s"${ provider.getClass.getSimpleName }: ${ ex.getMessage }" tryProvidersInOrder(remainingProviders, exceptions :+ errorMsg) } @@ -92,6 +97,6 @@ object DefaultCredentialsProviderChain: def default[F[_]: Files: Env: SystemProperties: Network: Async](region: String): DefaultCredentialsProviderChain[F] = val httpClient = new SimpleHttpClient[F]( connectTimeout = 1.second, - readTimeout = 2.seconds + readTimeout = 2.seconds ) new DefaultCredentialsProviderChain[F](httpClient, region) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala index bcb0e4d8e..6b193a4e8 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -11,20 +11,21 @@ import java.time.Instant import scala.concurrent.duration.* -import cats.effect.std.Env -import cats.effect.{Async, Concurrent, Ref} import cats.syntax.all.* +import cats.effect.{ Async, Concurrent, Ref } +import cats.effect.std.Env + import fs2.io.net.* import io.circe.* import io.circe.parser.* -import ldbc.amazon.client.{HttpClient, SimpleHttpClient} +import ldbc.amazon.client.* +import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.useragent.BusinessMetricFeatureId -import ldbc.amazon.client.* /** * [[AwsCredentialsProvider]] implementation that loads credentials from EC2 Instance Metadata Service (IMDS). @@ -53,40 +54,40 @@ import ldbc.amazon.client.* * @tparam F The effect type */ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( - httpClient: HttpClient[F], + httpClient: HttpClient[F], credentialsRef: Ref[F, Option[CachedCredentials]] ) extends AwsCredentialsProvider[F]: - private val DEFAULT_IMD_SEND_POINT = "http://169.254.169.254" + private val DEFAULT_IMD_SEND_POINT = "http://169.254.169.254" private val METADATA_TOKEN_TTL_SECONDS = 21600 // 6 hours - private val CREDENTIAL_REFRESH_BUFFER = 4.minutes + private val CREDENTIAL_REFRESH_BUFFER = 4.minutes override def resolveCredentials(): F[AwsCredentials] = for disabled <- checkIfDisabled() - _ <- Concurrent[F].raiseWhen(disabled)( - new SdkClientException("EC2 metadata service is disabled via AWS_EC2_METADATA_DISABLED") - ) - cached <- credentialsRef.get + _ <- Concurrent[F].raiseWhen(disabled)( + new SdkClientException("EC2 metadata service is disabled via AWS_EC2_METADATA_DISABLED") + ) + cached <- credentialsRef.get credentials <- cached match { - case Some(creds) if !isExpiringSoon(creds) => - Concurrent[F].pure(creds.credentials) - case _ => - refreshCredentials() - } + case Some(creds) if !isExpiringSoon(creds) => + Concurrent[F].pure(creds.credentials) + case _ => + refreshCredentials() + } yield credentials private def checkIfDisabled(): F[Boolean] = Env[F].get("AWS_EC2_METADATA_DISABLED").map { case Some(value) => value.toLowerCase == "true" - case None => false + case None => false } private def refreshCredentials(): F[AwsCredentials] = for - endpoint <- getImdsEndpoint() - token <- acquireMetadataToken(endpoint).attempt.map(_.toOption) - roleName <- getRoleName(endpoint, token) + endpoint <- getImdsEndpoint() + token <- acquireMetadataToken(endpoint).attempt.map(_.toOption) + roleName <- getRoleName(endpoint, token) credentials <- getCredentialsForRole(endpoint, token, roleName) cached = CachedCredentials(credentials, Instant.now()) _ <- credentialsRef.set(Some(cached)) @@ -95,102 +96,113 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( private def getImdsEndpoint(): F[String] = Env[F].get("AWS_EC2_METADATA_SERVICE_ENDPOINT").map { case Some(endpoint) => endpoint.stripSuffix("/") - case None => DEFAULT_IMD_SEND_POINT + case None => DEFAULT_IMD_SEND_POINT } private def acquireMetadataToken(endpoint: String): F[String] = val tokenUrl = s"$endpoint/latest/api/token" - val headers = Map( + val headers = Map( "X-aws-ec2-metadata-token-ttl-seconds" -> METADATA_TOKEN_TTL_SECONDS.toString ) - + for response <- httpClient.put(URI.create(tokenUrl), headers, "") - _ <- validateHttpResponse(response, "Failed to acquire metadata token") + _ <- validateHttpResponse(response, "Failed to acquire metadata token") yield response.body.trim private def getRoleName(endpoint: String, token: Option[String]): F[String] = val roleUrl = s"$endpoint/latest/meta-data/iam/security-credentials/" val headers = buildRequestHeaders(token) - + for response <- httpClient.get(URI.create(roleUrl), headers) - _ <- validateHttpResponse(response, "Failed to list IAM roles") + _ <- validateHttpResponse(response, "Failed to list IAM roles") roleName <- parseRoleListResponse(response.body) yield roleName private def getCredentialsForRole( - endpoint: String, - token: Option[String], + endpoint: String, + token: Option[String], roleName: String ): F[AwsCredentials] = val credentialsUrl = s"$endpoint/latest/meta-data/iam/security-credentials/$roleName" - val headers = buildRequestHeaders(token) - + val headers = buildRequestHeaders(token) + for - response <- httpClient.get(URI.create(credentialsUrl), headers) - _ <- validateHttpResponse(response, s"Failed to retrieve credentials for role $roleName") + response <- httpClient.get(URI.create(credentialsUrl), headers) + _ <- validateHttpResponse(response, s"Failed to retrieve credentials for role $roleName") credentials <- parseCredentialsResponse(response.body, roleName) yield credentials private def buildRequestHeaders(token: Option[String]): Map[String, String] = val baseHeaders = Map( - "Accept" -> "application/json", + "Accept" -> "application/json", "User-Agent" -> "aws-sdk-scala/ldbc" ) token match { case Some(t) => baseHeaders + ("X-aws-ec2-metadata-token" -> t) - case None => baseHeaders + case None => baseHeaders } private def validateHttpResponse(response: HttpResponse, context: String): F[Unit] = response.statusCode match { case code if code >= 200 && code < 300 => Concurrent[F].unit - case 401 => Concurrent[F].raiseError( - new SdkClientException(s"$context: Unauthorized (401) - Invalid or expired metadata token") - ) - case 403 => Concurrent[F].raiseError( - new SdkClientException(s"$context: Forbidden (403) - No instance profile attached") - ) - case 404 => Concurrent[F].raiseError( - new SdkClientException(s"$context: Not Found (404) - Instance metadata not available") - ) - case code => Concurrent[F].raiseError( - new SdkClientException(s"$context: HTTP $code - ${response.body}") - ) + case 401 => + Concurrent[F].raiseError( + new SdkClientException(s"$context: Unauthorized (401) - Invalid or expired metadata token") + ) + case 403 => + Concurrent[F].raiseError( + new SdkClientException(s"$context: Forbidden (403) - No instance profile attached") + ) + case 404 => + Concurrent[F].raiseError( + new SdkClientException(s"$context: Not Found (404) - Instance metadata not available") + ) + case code => + Concurrent[F].raiseError( + new SdkClientException(s"$context: HTTP $code - ${ response.body }") + ) } private def parseRoleListResponse(body: String): F[String] = val roles = body.trim.split('\n').map(_.trim).filter(_.nonEmpty) roles.headOption match { case Some(roleName) => Concurrent[F].pure(roleName) - case None => Concurrent[F].raiseError( - new SdkClientException("No IAM roles found in instance metadata") - ) + case None => + Concurrent[F].raiseError( + new SdkClientException("No IAM roles found in instance metadata") + ) } private def parseCredentialsResponse(jsonBody: String, roleName: String): F[AwsCredentials] = - Concurrent[F].fromEither( - parse(jsonBody).flatMap(_.as[InstanceMetadataCredentialsResponse]) - ).flatMap { response => - if (response.Code == "Success") { - Concurrent[F].pure(AwsSessionCredentials( - accessKeyId = response.AccessKeyId, - secretAccessKey = response.SecretAccessKey, - sessionToken = response.Token, - validateCredentials = false, - providerName = Some(BusinessMetricFeatureId.CREDENTIALS_IMDS.code), - accountId = extractAccountIdFromArn(response.AccessKeyId), - expirationTime = Some(Instant.parse(response.Expiration)) - )) - } else { - Concurrent[F].raiseError( - new SdkClientException(s"Failed to retrieve credentials for role $roleName: ${response.Code}") - ) + Concurrent[F] + .fromEither( + parse(jsonBody).flatMap(_.as[InstanceMetadataCredentialsResponse]) + ) + .flatMap { response => + if response.Code == "Success" then { + Concurrent[F].pure( + AwsSessionCredentials( + accessKeyId = response.AccessKeyId, + secretAccessKey = response.SecretAccessKey, + sessionToken = response.Token, + validateCredentials = false, + providerName = Some(BusinessMetricFeatureId.CREDENTIALS_IMDS.code), + accountId = extractAccountIdFromArn(response.AccessKeyId), + expirationTime = Some(Instant.parse(response.Expiration)) + ) + ) + } else { + Concurrent[F].raiseError( + new SdkClientException(s"Failed to retrieve credentials for role $roleName: ${ response.Code }") + ) + } + } + .adaptError { + case ex => + new SdkClientException(s"Failed to parse instance metadata credentials response: ${ ex.getMessage }") } - }.adaptError { case ex => - new SdkClientException(s"Failed to parse instance metadata credentials response: ${ex.getMessage}") - } private def extractAccountIdFromArn(accessKeyId: String): Option[String] = // For instance profile credentials, we don't have the account ID directly @@ -200,7 +212,7 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( private def isExpiringSoon(cached: CachedCredentials): Boolean = cached.credentials.expirationTime match { case Some(expiration) => - val now = Instant.now() + val now = Instant.now() val bufferTime = expiration.minusSeconds(CREDENTIAL_REFRESH_BUFFER.toSeconds) now.isAfter(bufferTime) case None => false @@ -218,20 +230,20 @@ private case class CachedCredentials( * JSON response from EC2 instance metadata service. */ private case class InstanceMetadataCredentialsResponse( - Code: String, - LastUpdated: String, - Type: String, - AccessKeyId: String, + Code: String, + LastUpdated: String, + Type: String, + AccessKeyId: String, SecretAccessKey: String, - Token: String, - Expiration: String + Token: String, + Expiration: String ) private object InstanceMetadataCredentialsResponse: given Decoder[InstanceMetadataCredentialsResponse] = Decoder.forProduct7( "Code", "LastUpdated", - "Type", + "Type", "AccessKeyId", "SecretAccessKey", "Token", @@ -248,7 +260,7 @@ object InstanceProfileCredentialsProvider: */ def apply[F[_]: Env: Network: Async](): F[InstanceProfileCredentialsProvider[F]] = for - httpClient <- createDefaultHttpClient[F]() + httpClient <- createDefaultHttpClient[F]() credentialsRef <- Ref.of[F, Option[CachedCredentials]](None) yield new InstanceProfileCredentialsProvider[F](httpClient, credentialsRef) @@ -273,10 +285,12 @@ object InstanceProfileCredentialsProvider: * @return A configured HTTP client */ private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = - Async[F].pure(new SimpleHttpClient[F]( - connectTimeout = 2.seconds, - readTimeout = 5.seconds - )) + Async[F].pure( + new SimpleHttpClient[F]( + connectTimeout = 2.seconds, + readTimeout = 5.seconds + ) + ) /** * Checks if Instance Profile credentials are available by attempting @@ -288,24 +302,26 @@ object InstanceProfileCredentialsProvider: def isAvailable[F[_]: Env: Network: Async](): F[Boolean] = for disabled <- Env[F].get("AWS_EC2_METADATA_DISABLED").map { - case Some(value) => value.toLowerCase == "true" - case None => false - } - available <- if (disabled) { - Async[F].pure(false) - } else { - checkMetadataServiceAvailability[F]() - } + case Some(value) => value.toLowerCase == "true" + case None => false + } + available <- if disabled then { + Async[F].pure(false) + } else { + checkMetadataServiceAvailability[F]() + } yield available private def checkMetadataServiceAvailability[F[_]: Network: Async](): F[Boolean] = val httpClient = new SimpleHttpClient[F]( connectTimeout = 1.second, - readTimeout = 2.seconds + readTimeout = 2.seconds ) - - httpClient.get( - URI.create("http://169.254.169.254/latest/meta-data/"), - Map.empty - ).map(_.statusCode == 200) - .handleErrorWith(_ => Async[F].pure(false)) \ No newline at end of file + + httpClient + .get( + URI.create("http://169.254.169.254/latest/meta-data/"), + Map.empty + ) + .map(_.statusCode == 200) + .handleErrorWith(_ => Async[F].pure(false)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 9c36e75c6..d644b3bda 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -66,17 +66,17 @@ class SimpleHttpClient[F[_]: Network: Async]( headers: Map[String, String], body: Option[String] ): F[Unit] = - val hostHeader = if port == 80 then host else s"$host:$port" + val hostHeader = if port == 80 then host else s"$host:$port" val contentHeaders = body match { case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) - case None => Map.empty + case None => Map.empty } val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") - val requestLine = s"$method $path HTTP/1.1\r\n" - val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString + val requestLine = s"$method $path HTTP/1.1\r\n" + val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString val requestWithHeaders = requestLine + headerLines + "\r\n" - val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) + val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) Stream .emit(fullRequest) From 6f055dc8dedcd1c5f1532868886b9abcf9297f8d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 3 Dec 2025 22:36:55 +0900 Subject: [PATCH 099/215] Check credentials type --- .../ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index e7ffe65cd..2e35964bd 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -20,6 +20,7 @@ import cats.effect.kernel.{ Clock, Sync } import fs2.{ Chunk, Stream } import fs2.hashing.{ HashAlgorithm, Hashing } +import ldbc.amazon.auth.credentials.* import ldbc.amazon.identity.AwsCredentials /** @@ -102,8 +103,11 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( date = dateTime.substring(0, 8) credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" credential = s"${ credentials.accessKeyId }/$credentialScope" - // queryParams = buildQueryParams(credential, dateTime, credentials.sessionToken, username) - queryParams = buildQueryParams(credential, dateTime, ???, username) + queryParams = credentials match { + case c: AwsBasicCredentials => buildQueryParams(credential, dateTime, None, username) + case c: AwsSessionCredentials => + buildQueryParams(credential, dateTime, Some(c.sessionToken), username) + } canonicalRequest = buildCanonicalRequest(s"$hostname:$port", queryParams) canonicalRequestHash <- sha256Hex(canonicalRequest) stringToSign = buildStringToSign(dateTime, credentialScope, canonicalRequestHash) From abf523c6806a4a96015496345c172c755fe2e4a2 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 00:48:00 +0900 Subject: [PATCH 100/215] Create SimpleXmlParser --- build.sbt | 3 +- .../ldbc/amazon/util/SimpleXmlParser.scala | 47 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala diff --git a/build.sbt b/build.sbt index 590b15a2f..305cf3838 100644 --- a/build.sbt +++ b/build.sbt @@ -148,7 +148,6 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP .module("aws-authentication-plugin", "") .settings( libraryDependencies ++= Seq( - "org.scala-lang.modules" %%% "scala-xml" % "2.2.0", "co.fs2" %%% "fs2-core" % "3.12.2", "co.fs2" %%% "fs2-io" % "3.12.2", "io.circe" %% "circe-parser" % "0.14.10", @@ -221,7 +220,7 @@ lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) ) .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) .nativeSettings(Test / nativeBrewFormulas += "s2n") - .dependsOn(connector, queryBuilder, schema) + .dependsOn(connector, queryBuilder, schema, awsAuthenticationPlugin) .enablePlugins(NoPublishPlugin) lazy val benchmark = (project in file("benchmark")) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala new file mode 100644 index 000000000..764458d2c --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala @@ -0,0 +1,47 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +object SimpleXmlParser: + def decodeXmlEntities(s: String): String = + s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace(""", "\"") + .replace("'", "'") + + def extractTagContent(tagName: String, xml: String): Option[String] = { + val startTag = s"<$tagName>" + val endTag = s"" + val startIdx = xml.indexOf(startTag) + + if (startIdx < 0) None + else { + val contentStart = startIdx + startTag.length + val endIdx = xml.indexOf(endTag, contentStart) + if (endIdx < 0) None + else Some(decodeXmlEntities(xml.substring(contentStart, endIdx).trim)) + } + } + + def extractSection(tagName: String, xml: String): Option[String] = { + val startTag = s"<$tagName>" + val endTag = s"" + val startIdx = xml.indexOf(startTag) + + if (startIdx < 0) None + else { + val endIdx = xml.indexOf(endTag, startIdx) + if (endIdx < 0) None + else Some(xml.substring(startIdx, endIdx + endTag.length)) + } + } + + def requireTag(tagName: String, xml: String, errorMsg: String): String = + extractTagContent(tagName, xml) + .filter(_.nonEmpty) + .getOrElse(throw new IllegalArgumentException(errorMsg)) From 6cdbcf8234f15f4557c888278170f05b38334705 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 00:48:39 +0900 Subject: [PATCH 101/215] Change use scala-xml -> SimpleXmlParser --- .../scala/ldbc/amazon/client/StsClient.scala | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index fb5d72118..80114c2a9 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -11,14 +11,13 @@ import java.time.{ Instant, ZoneOffset } import java.time.format.DateTimeFormatter import java.util.UUID -import scala.xml.XML - import cats.syntax.all.* import cats.MonadThrow import cats.effect.Concurrent import ldbc.amazon.exception.{ SdkClientException, StsException } +import ldbc.amazon.util.SimpleXmlParser /** * Trait for AWS STS (Security Token Service) client operations. @@ -190,37 +189,27 @@ object StsClient: private def parseAssumeRoleResponse[F[_]: MonadThrow](xmlBody: String): F[AssumeRoleWithWebIdentityResponse] = MonadThrow[F] .catchNonFatal { - val xml = XML.loadString(xmlBody) - - // Extract credentials - val credentials = (xml \ "AssumeRoleWithWebIdentityResult" \ "Credentials").head - val accessKeyId = (credentials \ "AccessKeyId").text.trim - val secretAccessKey = (credentials \ "SecretAccessKey").text.trim - val sessionToken = (credentials \ "SessionToken").text.trim - val expirationStr = (credentials \ "Expiration").text.trim - - // Extract assumed role information - val assumedRoleUser = (xml \ "AssumeRoleWithWebIdentityResult" \ "AssumedRoleUser").head - val assumedRoleArn = (assumedRoleUser \ "Arn").text.trim - // Parse expiration time - val expiration = Instant.parse(expirationStr) + val accessKeyId = SimpleXmlParser.requireTag("AccessKeyId", xmlBody, "AccessKeyId not found") + val secretAccessKey = SimpleXmlParser.requireTag("SecretAccessKey", xmlBody, "SecretAccessKey not found") + val sessionToken = SimpleXmlParser.requireTag("SessionToken", xmlBody, "SessionToken not found") + val expirationStr = SimpleXmlParser.requireTag("Expiration", xmlBody, "Expiration not found") - // Validate required fields - if accessKeyId.isEmpty then throw new StsException("AccessKeyId not found in STS response") - if secretAccessKey.isEmpty then throw new StsException("SecretAccessKey not found in STS response") - if sessionToken.isEmpty then throw new StsException("SessionToken not found in STS response") - if assumedRoleArn.isEmpty then throw new StsException("AssumedRoleArn not found in STS response") + val assumedRoleArn = SimpleXmlParser.extractSection("AssumedRoleUser", xmlBody) + .flatMap(section => SimpleXmlParser.extractTagContent("Arn", section)) + .filter(_.nonEmpty) + .getOrElse(throw new StsException("AssumedRoleArn not found")) AssumeRoleWithWebIdentityResponse( - accessKeyId = accessKeyId, + accessKeyId = accessKeyId, secretAccessKey = secretAccessKey, - sessionToken = sessionToken, - expiration = expiration, - assumedRoleArn = assumedRoleArn + sessionToken = sessionToken, + expiration = Instant.parse(expirationStr), + assumedRoleArn = assumedRoleArn ) } .adaptError { + case ex: StsException => ex case ex => new StsException(s"Failed to parse STS response: ${ ex.getMessage }") } From 4f1182c3c2baabad622f3f69bf1e581161a5048d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 00:58:01 +0900 Subject: [PATCH 102/215] Change use java util UUID -> UUIDGen --- .../credentials/DefaultCredentialsProviderChain.scala | 6 +++--- .../WebIdentityTokenFileCredentialsProvider.scala | 6 +++--- .../internal/WebIdentityCredentialsUtils.scala | 3 ++- .../src/main/scala/ldbc/amazon/client/StsClient.scala | 8 ++++---- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala index 6e96e057a..9c1e06fc8 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -11,7 +11,7 @@ import scala.concurrent.duration.* import cats.syntax.all.* import cats.effect.* -import cats.effect.std.{ Env, SystemProperties } +import cats.effect.std.{ Env, SystemProperties, UUIDGen } import fs2.io.file.Files import fs2.io.net.* @@ -44,7 +44,7 @@ import ldbc.amazon.identity.* * * @tparam F The effect type */ -class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: Concurrent]( +class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: UUIDGen: Concurrent]( httpClient: HttpClient[F], region: String ) extends AwsCredentialsProvider[F]: @@ -94,7 +94,7 @@ object DefaultCredentialsProviderChain: * @tparam F The effect type * @return A new DefaultCredentialsProviderChain instance */ - def default[F[_]: Files: Env: SystemProperties: Network: Async](region: String): DefaultCredentialsProviderChain[F] = + def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async](region: String): DefaultCredentialsProviderChain[F] = val httpClient = new SimpleHttpClient[F]( connectTimeout = 1.second, readTimeout = 2.seconds diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 6b5805f8b..846bd46cb 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -11,7 +11,7 @@ import scala.concurrent.duration.* import cats.syntax.all.* import cats.effect.* -import cats.effect.std.{ Env, SystemProperties } +import cats.effect.std.{ Env, SystemProperties, UUIDGen } import fs2.io.file.{ Files, Path } import fs2.io.net.* @@ -143,7 +143,7 @@ object WebIdentityTokenFileCredentialsProvider: * @tparam F The effect type * @return A new WebIdentityTokenFileCredentialsProvider instance */ - def apply[F[_]: Files: Env: SystemProperties: Network: Async]( + def apply[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( region: String = "us-east-1" ): F[WebIdentityTokenFileCredentialsProvider[F]] = for @@ -159,7 +159,7 @@ object WebIdentityTokenFileCredentialsProvider: * @tparam F The effect type * @return A new WebIdentityTokenFileCredentialsProvider instance */ - def default[F[_]: Env: SystemProperties: Files: Concurrent]( + def default[F[_]: Env: SystemProperties: Files: UUIDGen: Concurrent]( httpClient: HttpClient[F], region: String = "us-east-1" ): WebIdentityTokenFileCredentialsProvider[F] = diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index 18259a0f2..3571e2723 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -10,6 +10,7 @@ import cats.syntax.all.* import cats.MonadThrow import cats.effect.Concurrent +import cats.effect.std.UUIDGen import fs2.io.file.{ Files, Path } @@ -150,7 +151,7 @@ object WebIdentityCredentialsUtils: * @tparam F The effect type * @return A WebIdentityCredentialsUtils instance */ - def default[F[_]: Files: Concurrent]: WebIdentityCredentialsUtils[F] = + def default[F[_]: Files: UUIDGen: Concurrent]: WebIdentityCredentialsUtils[F] = val stsClient = StsClient.default[F] Impl[F](stsClient) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 80114c2a9..932584c2a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -9,12 +9,12 @@ package ldbc.amazon.client import java.net.{ URI, URLEncoder } import java.time.{ Instant, ZoneOffset } import java.time.format.DateTimeFormatter -import java.util.UUID import cats.syntax.all.* import cats.MonadThrow import cats.effect.Concurrent +import cats.effect.std.UUIDGen import ldbc.amazon.exception.{ SdkClientException, StsException } import ldbc.amazon.util.SimpleXmlParser @@ -75,7 +75,7 @@ object StsClient: assumedRoleArn: String ) - private case class Impl[F[_]: Concurrent]() extends StsClient[F]: + private case class Impl[F[_]: UUIDGen: Concurrent]() extends StsClient[F]: def assumeRoleWithWebIdentity( request: AssumeRoleWithWebIdentityRequest, @@ -84,7 +84,7 @@ object StsClient: ): F[AssumeRoleWithWebIdentityResponse] = for timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) - sessionName = request.roleSessionName.getOrElse(s"ldbc-session-${ UUID.randomUUID() }") + sessionName = request.roleSessionName.getOrElse(s"ldbc-session-${ UUIDGen[F].randomUUID }") duration = request.durationSeconds.getOrElse(3600) // Build STS request @@ -114,7 +114,7 @@ object StsClient: * @tparam F The effect type * @return A StsClient instance */ - def default[F[_]: Concurrent]: StsClient[F] = Impl[F]() + def default[F[_]: UUIDGen: Concurrent]: StsClient[F] = Impl[F]() /** * Builds the STS request body in AWS Query format. From 60f6143208ece9a28760b78730d8b08f0b7c56a9 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 01:13:12 +0900 Subject: [PATCH 103/215] Create SimpleJsonParser --- build.sbt | 1 - .../ContainerCredentialsProvider.scala | 30 ++- .../InstanceProfileCredentialsProvider.scala | 36 ++-- .../ldbc/amazon/util/SimpleJsonParser.scala | 203 ++++++++++++++++++ 4 files changed, 245 insertions(+), 25 deletions(-) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala diff --git a/build.sbt b/build.sbt index 305cf3838..ca5271700 100644 --- a/build.sbt +++ b/build.sbt @@ -150,7 +150,6 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP libraryDependencies ++= Seq( "co.fs2" %%% "fs2-core" % "3.12.2", "co.fs2" %%% "fs2-io" % "3.12.2", - "io.circe" %% "circe-parser" % "0.14.10", "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test ) ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala index bcdea91dd..1767c5165 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -19,13 +19,11 @@ import cats.effect.std.Env import fs2.io.file.{ Files, Path } import fs2.io.net.* -import io.circe.* -import io.circe.parser.* - import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SimpleJsonParser /** * [[AwsCredentialsProvider]] implementation that loads credentials from AWS Container Credential Provider. @@ -161,7 +159,11 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( private def parseCredentialsResponse(jsonBody: String): F[AwsCredentials] = Concurrent[F] .fromEither( - parse(jsonBody).flatMap(_.as[ContainerCredentialsResponse]) + SimpleJsonParser + .parse(jsonBody) + .flatMap(ContainerCredentialsResponse.fromJson) + .left + .map(msg => new SdkClientException(s"Failed to parse JSON: $msg")) ) .map { response => AwsSessionCredentials( @@ -219,13 +221,19 @@ private case class ContainerCredentialsResponse( ) private object ContainerCredentialsResponse: - given Decoder[ContainerCredentialsResponse] = Decoder.forProduct5( - "AccessKeyId", - "SecretAccessKey", - "Token", - "Expiration", - "RoleArn" - )(ContainerCredentialsResponse.apply) + def fromJson(json: SimpleJsonParser.JsonObject): Either[String, ContainerCredentialsResponse] = + for + accessKeyId <- json.require("AccessKeyId") + secretAccessKey <- json.require("SecretAccessKey") + token <- json.require("Token") + expiration <- json.require("Expiration") + yield ContainerCredentialsResponse( + AccessKeyId = accessKeyId, + SecretAccessKey = secretAccessKey, + Token = token, + Expiration = expiration, + RoleArn = json.get("RoleArn") + ) object ContainerCredentialsProvider: diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala index 6b193a4e8..9c362e867 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -18,14 +18,12 @@ import cats.effect.std.Env import fs2.io.net.* -import io.circe.* -import io.circe.parser.* - import ldbc.amazon.client.* import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SimpleJsonParser /** * [[AwsCredentialsProvider]] implementation that loads credentials from EC2 Instance Metadata Service (IMDS). @@ -178,7 +176,11 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( private def parseCredentialsResponse(jsonBody: String, roleName: String): F[AwsCredentials] = Concurrent[F] .fromEither( - parse(jsonBody).flatMap(_.as[InstanceMetadataCredentialsResponse]) + SimpleJsonParser + .parse(jsonBody) + .flatMap(InstanceMetadataCredentialsResponse.fromJson) + .left + .map(msg => new SdkClientException(s"Failed to parse JSON: $msg")) ) .flatMap { response => if response.Code == "Success" then { @@ -240,15 +242,23 @@ private case class InstanceMetadataCredentialsResponse( ) private object InstanceMetadataCredentialsResponse: - given Decoder[InstanceMetadataCredentialsResponse] = Decoder.forProduct7( - "Code", - "LastUpdated", - "Type", - "AccessKeyId", - "SecretAccessKey", - "Token", - "Expiration" - )(InstanceMetadataCredentialsResponse.apply) + + def fromJson(json: SimpleJsonParser.JsonObject): Either[String, InstanceMetadataCredentialsResponse] = + for + code <- json.require("Code") + accessKeyId <- json.require("AccessKeyId") + secretAccessKey <- json.require("SecretAccessKey") + token <- json.require("Token") + expiration <- json.require("Expiration") + yield InstanceMetadataCredentialsResponse( + Code = code, + LastUpdated = json.getOrEmpty("LastUpdated"), + Type = json.getOrEmpty("Type"), + AccessKeyId = accessKeyId, + SecretAccessKey = secretAccessKey, + Token = token, + Expiration = expiration + ) object InstanceProfileCredentialsProvider: diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala new file mode 100644 index 000000000..0ede0f267 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala @@ -0,0 +1,203 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +/** + * Simple JSON parser for AWS credential responses. + * Avoids external dependencies for Scala.js and Scala Native compatibility. + */ +object SimpleJsonParser: + + case class JsonObject(fields: Map[String, String]): + def get(key: String): Option[String] = fields.get(key) + + def getOrEmpty(key: String): String = fields.getOrElse(key, "") + + def require(key: String): Either[String, String] = + fields.get(key).toRight(s"Required field '$key' not found") + + /** + * Parses a flat JSON object with string values. + * + * Supports: + * - String values + * - Number values (returned as string) + * - Boolean values (returned as "true"/"false") + * - Null values (returned as empty string) + * + * Does NOT support: + * - Nested objects + * - Arrays + */ + def parse(json: String): Either[String, JsonObject] = + try + val trimmed = json.trim + if !trimmed.startsWith("{") || !trimmed.endsWith("}") then + Left("Invalid JSON: must be an object") + else + val content = trimmed.substring(1, trimmed.length - 1).trim + if content.isEmpty then Right(JsonObject(Map.empty)) + else + val fields = parseFields(content) + Right(JsonObject(fields)) + catch + case ex: Exception => Left(s"JSON parse error: ${ex.getMessage}") + + private def parseFields(content: String): Map[String, String] = + val result = scala.collection.mutable.Map[String, String]() + var idx = 0 + + while idx < content.length do + // Skip whitespace + idx = skipWhitespace(content, idx) + if idx < content.length && content.charAt(idx) != '}' then + // Parse key + val (key, nextIdx1) = parseString(content, idx) + idx = skipWhitespace(content, nextIdx1) + + // Expect colon + if idx >= content.length || content.charAt(idx) != ':' then + throw new IllegalArgumentException(s"Expected ':' after key '$key' at position $idx") + + idx = skipWhitespace(content, idx + 1) + + // Parse value + val (value, nextIdx2) = parseValue(content, idx) + result(key) = value + idx = skipWhitespace(content, nextIdx2) + + // Skip comma if present + if idx < content.length && content.charAt(idx) == ',' then + idx += 1 + + result.toMap + + private def skipWhitespace(s: String, from: Int): Int = + var idx = from + while idx < s.length && s.charAt(idx).isWhitespace do + idx += 1 + idx + + private def parseString(s: String, from: Int): (String, Int) = + if from >= s.length || s.charAt(from) != '"' then + throw new IllegalArgumentException(s"Expected '\"' at position $from") + + val sb = new StringBuilder + var idx = from + 1 + var escaped = false + + while idx < s.length do + val ch = s.charAt(idx) + if escaped then + ch match + case '"' => sb.append('"') + case '\\' => sb.append('\\') + case '/' => sb.append('/') + case 'b' => sb.append('\b') + case 'f' => sb.append('\f') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'u' => + if idx + 4 >= s.length then + throw new IllegalArgumentException("Invalid unicode escape sequence") + val hex = s.substring(idx + 1, idx + 5) + sb.append(Integer.parseInt(hex, 16).toChar) + idx += 4 + case _ => sb.append(ch) + escaped = false + else if ch == '\\' then + escaped = true + else if ch == '"' then + return (sb.toString, idx + 1) + else + sb.append(ch) + idx += 1 + + throw new IllegalArgumentException("Unterminated string") + + private def parseValue(s: String, from: Int): (String, Int) = + if from >= s.length then + throw new IllegalArgumentException(s"Unexpected end of input at position $from") + + val ch = s.charAt(from) + + if ch == '"' then + // String value + parseString(s, from) + else if ch == 'n' && s.length >= from + 4 && s.substring(from, from + 4) == "null" then + // null value - return empty string + ("", from + 4) + else if ch == 't' && s.length >= from + 4 && s.substring(from, from + 4) == "true" then + ("true", from + 4) + else if ch == 'f' && s.length >= from + 5 && s.substring(from, from + 5) == "false" then + ("false", from + 5) + else if ch == '-' || ch.isDigit then + // Number value + var idx = from + while idx < s.length && isNumberChar(s.charAt(idx)) do + idx += 1 + (s.substring(from, idx), idx) + else if ch == '{' then + // Nested object - skip it entirely + val endIdx = findMatchingBrace(s, from) + ("{...}", endIdx + 1) + else if ch == '[' then + // Array - skip it entirely + val endIdx = findMatchingBracket(s, from) + ("[...]", endIdx + 1) + else + throw new IllegalArgumentException(s"Unexpected character '$ch' at position $from") + + private def isNumberChar(ch: Char): Boolean = + ch.isDigit || ch == '.' || ch == '-' || ch == '+' || ch == 'e' || ch == 'E' + + private def findMatchingBrace(s: String, from: Int): Int = + var depth = 0 + var idx = from + var inString = false + var escaped = false + + while idx < s.length do + val ch = s.charAt(idx) + if escaped then + escaped = false + else if ch == '\\' && inString then + escaped = true + else if ch == '"' then + inString = !inString + else if !inString then + if ch == '{' then depth += 1 + else if ch == '}' then + depth -= 1 + if depth == 0 then return idx + idx += 1 + + throw new IllegalArgumentException("Unmatched brace") + + private def findMatchingBracket(s: String, from: Int): Int = + var depth = 0 + var idx = from + var inString = false + var escaped = false + + while idx < s.length do + val ch = s.charAt(idx) + if escaped then + escaped = false + else if ch == '\\' && inString then + escaped = true + else if ch == '"' then + inString = !inString + else if !inString then + if ch == '[' then depth += 1 + else if ch == ']' then + depth -= 1 + if depth == 0 then return idx + idx += 1 + + throw new IllegalArgumentException("Unmatched bracket") From f706bdec65ce6af3787408a4db54186a4a271506 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 01:13:53 +0900 Subject: [PATCH 104/215] Change use Instant.EPOCH.plusNanos --- .../scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 2e35964bd..39a07cb01 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -127,7 +127,7 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( DateTimeFormatter .ofPattern("yyyyMMdd'T'HHmmss'Z'") .withZone(ZoneOffset.UTC) - .format(Instant.now().plusMillis(duration.toMillis)) + .format(Instant.EPOCH.plusNanos(duration.toNanos)) /** * Builds the query parameters for the RDS authentication request. From 6a2b2058c5f75c347ddc29ca6d1b4ca0f9ec75a8 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 01:15:25 +0900 Subject: [PATCH 105/215] Action sbt githubWorkflowGenerate --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 259b331e4..ecf318320 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -155,11 +155,11 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) From ade30791cbbdaa886f30f84234ddf424c7ad48cd Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 01:15:47 +0900 Subject: [PATCH 106/215] Action sbt scalafmtSbt --- build.sbt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build.sbt b/build.sbt index ca5271700..0ae94d85b 100644 --- a/build.sbt +++ b/build.sbt @@ -148,9 +148,9 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP .module("aws-authentication-plugin", "") .settings( libraryDependencies ++= Seq( - "co.fs2" %%% "fs2-core" % "3.12.2", - "co.fs2" %%% "fs2-io" % "3.12.2", - "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test + "co.fs2" %%% "fs2-core" % "3.12.2", + "co.fs2" %%% "fs2-io" % "3.12.2", + "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test ) ) .jsSettings( From 14cc6d4e356286178dd7fde6d99e28c155cbd53f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 4 Dec 2025 01:16:23 +0900 Subject: [PATCH 107/215] Action sbt scalafmtAll --- .../ContainerCredentialsProvider.scala | 8 +-- .../DefaultCredentialsProviderChain.scala | 4 +- .../WebIdentityCredentialsUtils.scala | 2 +- .../scala/ldbc/amazon/client/StsClient.scala | 21 ++++--- .../ldbc/amazon/util/SimpleJsonParser.scala | 61 +++++++------------ .../ldbc/amazon/util/SimpleXmlParser.scala | 14 ++--- 6 files changed, 47 insertions(+), 63 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala index 1767c5165..161c0a461 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -223,16 +223,16 @@ private case class ContainerCredentialsResponse( private object ContainerCredentialsResponse: def fromJson(json: SimpleJsonParser.JsonObject): Either[String, ContainerCredentialsResponse] = for - accessKeyId <- json.require("AccessKeyId") + accessKeyId <- json.require("AccessKeyId") secretAccessKey <- json.require("SecretAccessKey") - token <- json.require("Token") - expiration <- json.require("Expiration") + token <- json.require("Token") + expiration <- json.require("Expiration") yield ContainerCredentialsResponse( AccessKeyId = accessKeyId, SecretAccessKey = secretAccessKey, Token = token, Expiration = expiration, - RoleArn = json.get("RoleArn") + RoleArn = json.get("RoleArn") ) object ContainerCredentialsProvider: diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala index 9c1e06fc8..ac36a9d0a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -94,7 +94,9 @@ object DefaultCredentialsProviderChain: * @tparam F The effect type * @return A new DefaultCredentialsProviderChain instance */ - def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async](region: String): DefaultCredentialsProviderChain[F] = + def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( + region: String + ): DefaultCredentialsProviderChain[F] = val httpClient = new SimpleHttpClient[F]( connectTimeout = 1.second, readTimeout = 2.seconds diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index 3571e2723..5ce0158c0 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -9,8 +9,8 @@ package ldbc.amazon.auth.credentials.internal import cats.syntax.all.* import cats.MonadThrow -import cats.effect.Concurrent import cats.effect.std.UUIDGen +import cats.effect.Concurrent import fs2.io.file.{ Files, Path } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 932584c2a..868197bfd 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -13,8 +13,8 @@ import java.time.format.DateTimeFormatter import cats.syntax.all.* import cats.MonadThrow -import cats.effect.Concurrent import cats.effect.std.UUIDGen +import cats.effect.Concurrent import ldbc.amazon.exception.{ SdkClientException, StsException } import ldbc.amazon.util.SimpleXmlParser @@ -190,26 +190,27 @@ object StsClient: MonadThrow[F] .catchNonFatal { - val accessKeyId = SimpleXmlParser.requireTag("AccessKeyId", xmlBody, "AccessKeyId not found") + val accessKeyId = SimpleXmlParser.requireTag("AccessKeyId", xmlBody, "AccessKeyId not found") val secretAccessKey = SimpleXmlParser.requireTag("SecretAccessKey", xmlBody, "SecretAccessKey not found") - val sessionToken = SimpleXmlParser.requireTag("SessionToken", xmlBody, "SessionToken not found") - val expirationStr = SimpleXmlParser.requireTag("Expiration", xmlBody, "Expiration not found") + val sessionToken = SimpleXmlParser.requireTag("SessionToken", xmlBody, "SessionToken not found") + val expirationStr = SimpleXmlParser.requireTag("Expiration", xmlBody, "Expiration not found") - val assumedRoleArn = SimpleXmlParser.extractSection("AssumedRoleUser", xmlBody) + val assumedRoleArn = SimpleXmlParser + .extractSection("AssumedRoleUser", xmlBody) .flatMap(section => SimpleXmlParser.extractTagContent("Arn", section)) .filter(_.nonEmpty) .getOrElse(throw new StsException("AssumedRoleArn not found")) AssumeRoleWithWebIdentityResponse( - accessKeyId = accessKeyId, + accessKeyId = accessKeyId, secretAccessKey = secretAccessKey, - sessionToken = sessionToken, - expiration = Instant.parse(expirationStr), - assumedRoleArn = assumedRoleArn + sessionToken = sessionToken, + expiration = Instant.parse(expirationStr), + assumedRoleArn = assumedRoleArn ) } .adaptError { case ex: StsException => ex - case ex => + case ex => new StsException(s"Failed to parse STS response: ${ ex.getMessage }") } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala index 0ede0f267..4e9e75679 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala @@ -36,16 +36,14 @@ object SimpleJsonParser: def parse(json: String): Either[String, JsonObject] = try val trimmed = json.trim - if !trimmed.startsWith("{") || !trimmed.endsWith("}") then - Left("Invalid JSON: must be an object") + if !trimmed.startsWith("{") || !trimmed.endsWith("}") then Left("Invalid JSON: must be an object") else val content = trimmed.substring(1, trimmed.length - 1).trim if content.isEmpty then Right(JsonObject(Map.empty)) else val fields = parseFields(content) Right(JsonObject(fields)) - catch - case ex: Exception => Left(s"JSON parse error: ${ex.getMessage}") + catch case ex: Exception => Left(s"JSON parse error: ${ ex.getMessage }") private def parseFields(content: String): Map[String, String] = val result = scala.collection.mutable.Map[String, String]() @@ -68,18 +66,16 @@ object SimpleJsonParser: // Parse value val (value, nextIdx2) = parseValue(content, idx) result(key) = value - idx = skipWhitespace(content, nextIdx2) + idx = skipWhitespace(content, nextIdx2) // Skip comma if present - if idx < content.length && content.charAt(idx) == ',' then - idx += 1 + if idx < content.length && content.charAt(idx) == ',' then idx += 1 result.toMap private def skipWhitespace(s: String, from: Int): Int = var idx = from - while idx < s.length && s.charAt(idx).isWhitespace do - idx += 1 + while idx < s.length && s.charAt(idx).isWhitespace do idx += 1 idx private def parseString(s: String, from: Int): (String, Int) = @@ -102,27 +98,22 @@ object SimpleJsonParser: case 'n' => sb.append('\n') case 'r' => sb.append('\r') case 't' => sb.append('\t') - case 'u' => - if idx + 4 >= s.length then - throw new IllegalArgumentException("Invalid unicode escape sequence") + case 'u' => + if idx + 4 >= s.length then throw new IllegalArgumentException("Invalid unicode escape sequence") val hex = s.substring(idx + 1, idx + 5) sb.append(Integer.parseInt(hex, 16).toChar) idx += 4 case _ => sb.append(ch) escaped = false - else if ch == '\\' then - escaped = true - else if ch == '"' then - return (sb.toString, idx + 1) - else - sb.append(ch) + else if ch == '\\' then escaped = true + else if ch == '"' then return (sb.toString, idx + 1) + else sb.append(ch) idx += 1 throw new IllegalArgumentException("Unterminated string") private def parseValue(s: String, from: Int): (String, Int) = - if from >= s.length then - throw new IllegalArgumentException(s"Unexpected end of input at position $from") + if from >= s.length then throw new IllegalArgumentException(s"Unexpected end of input at position $from") val ch = s.charAt(from) @@ -132,15 +123,12 @@ object SimpleJsonParser: else if ch == 'n' && s.length >= from + 4 && s.substring(from, from + 4) == "null" then // null value - return empty string ("", from + 4) - else if ch == 't' && s.length >= from + 4 && s.substring(from, from + 4) == "true" then - ("true", from + 4) - else if ch == 'f' && s.length >= from + 5 && s.substring(from, from + 5) == "false" then - ("false", from + 5) + else if ch == 't' && s.length >= from + 4 && s.substring(from, from + 4) == "true" then ("true", from + 4) + else if ch == 'f' && s.length >= from + 5 && s.substring(from, from + 5) == "false" then ("false", from + 5) else if ch == '-' || ch.isDigit then // Number value var idx = from - while idx < s.length && isNumberChar(s.charAt(idx)) do - idx += 1 + while idx < s.length && isNumberChar(s.charAt(idx)) do idx += 1 (s.substring(from, idx), idx) else if ch == '{' then // Nested object - skip it entirely @@ -150,8 +138,7 @@ object SimpleJsonParser: // Array - skip it entirely val endIdx = findMatchingBracket(s, from) ("[...]", endIdx + 1) - else - throw new IllegalArgumentException(s"Unexpected character '$ch' at position $from") + else throw new IllegalArgumentException(s"Unexpected character '$ch' at position $from") private def isNumberChar(ch: Char): Boolean = ch.isDigit || ch == '.' || ch == '-' || ch == '+' || ch == 'e' || ch == 'E' @@ -164,12 +151,9 @@ object SimpleJsonParser: while idx < s.length do val ch = s.charAt(idx) - if escaped then - escaped = false - else if ch == '\\' && inString then - escaped = true - else if ch == '"' then - inString = !inString + if escaped then escaped = false + else if ch == '\\' && inString then escaped = true + else if ch == '"' then inString = !inString else if !inString then if ch == '{' then depth += 1 else if ch == '}' then @@ -187,12 +171,9 @@ object SimpleJsonParser: while idx < s.length do val ch = s.charAt(idx) - if escaped then - escaped = false - else if ch == '\\' && inString then - escaped = true - else if ch == '"' then - inString = !inString + if escaped then escaped = false + else if ch == '\\' && inString then escaped = true + else if ch == '"' then inString = !inString else if !inString then if ch == '[' then depth += 1 else if ch == ']' then diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala index 764458d2c..b0bc7a50e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala @@ -16,27 +16,27 @@ object SimpleXmlParser: def extractTagContent(tagName: String, xml: String): Option[String] = { val startTag = s"<$tagName>" - val endTag = s"" + val endTag = s"" val startIdx = xml.indexOf(startTag) - if (startIdx < 0) None + if startIdx < 0 then None else { val contentStart = startIdx + startTag.length - val endIdx = xml.indexOf(endTag, contentStart) - if (endIdx < 0) None + val endIdx = xml.indexOf(endTag, contentStart) + if endIdx < 0 then None else Some(decodeXmlEntities(xml.substring(contentStart, endIdx).trim)) } } def extractSection(tagName: String, xml: String): Option[String] = { val startTag = s"<$tagName>" - val endTag = s"" + val endTag = s"" val startIdx = xml.indexOf(startTag) - if (startIdx < 0) None + if startIdx < 0 then None else { val endIdx = xml.indexOf(endTag, startIdx) - if (endIdx < 0) None + if endIdx < 0 then None else Some(xml.substring(startIdx, endIdx + endTag.length)) } } From 449f875c8985bf9885db101f9393484195d16efb Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:01:33 +0900 Subject: [PATCH 108/215] Added split cummnma --- .../identity/internal/DefaultAwsCredentialsIdentity.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala index b7f981083..908f06e2f 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -23,8 +23,8 @@ final case class DefaultAwsCredentialsIdentity( val builder = new StringBuilder() builder.append("AwsCredentialsIdentity(") builder.append(s"accessKeyId=$accessKeyId") - providerName.foreach(v => builder.append(s"providerName=$v")) - accountId.foreach(v => builder.append(s"accountId=$v")) + providerName.foreach(v => builder.append(s", providerName=$v")) + accountId.foreach(v => builder.append(s", accountId=$v")) builder.result() From 1e449e9ddd1f7b6af4e399cceeed4f272e050374 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:09:11 +0900 Subject: [PATCH 109/215] Added should append ) --- .../amazon/identity/internal/DefaultAwsCredentialsIdentity.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala index 908f06e2f..c201c09ee 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -25,6 +25,7 @@ final case class DefaultAwsCredentialsIdentity( builder.append(s"accessKeyId=$accessKeyId") providerName.foreach(v => builder.append(s", providerName=$v")) accountId.foreach(v => builder.append(s", accountId=$v")) + builder.append(")") builder.result() From 8ac9bc46a8171b6f9643be8b3ba716520ab1b9f7 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:11:05 +0900 Subject: [PATCH 110/215] Fixed UUIDGen usage --- .../shared/src/main/scala/ldbc/amazon/client/StsClient.scala | 4 +++- .../shared/src/main/scala/ldbc/amazon/identity/Identity.scala | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 868197bfd..560258668 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -84,7 +84,9 @@ object StsClient: ): F[AssumeRoleWithWebIdentityResponse] = for timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) - sessionName = request.roleSessionName.getOrElse(s"ldbc-session-${ UUIDGen[F].randomUUID }") + sessionName <- request.roleSessionName.fold( + UUIDGen[F].randomUUID.map(uuid => s"ldbc-session-$uuid") + )(Concurrent[F].pure) duration = request.durationSeconds.getOrElse(3600) // Build STS request diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala index efc4b7e47..821e23888 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala @@ -25,6 +25,6 @@ trait Identity: /** * The source that resolved this identity, normally an identity provider. Note that * this string value would be set by an identity provider implementation and is - * intended to be used for for tracking purposes. Avoid building logic on its value. + * intended to be used for tracking purposes. Avoid building logic on its value. */ def providerName: Option[String] From aaada2bc5520fbf162a63791b162115bc6045ef4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:14:46 +0900 Subject: [PATCH 111/215] Change scaladoc --- .../EnvironmentVariableCredentialsProvider.scala | 4 ++-- .../credentials/internal/WebIdentityCredentialsUtils.scala | 6 +----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala index ffa36a783..d05ca9f0a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala @@ -20,8 +20,8 @@ import ldbc.amazon.util.SdkSystemSetting */ final class EnvironmentVariableCredentialsProvider[F[_]: Env: MonadThrow] extends SystemSettingsCredentialsProvider[F]: - // Customers should be able to specify a credentials provider that only looks at the system properties, - // but not the environment variables. For that reason, we're only checking the system properties here. + // Customers should be able to specify a credentials provider that only looks at the environment variables, + // but not the system properties. For that reason, we're only checking the environment variables here. override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = Env[F].get(setting.toString) override def provider: String = BusinessMetricFeatureId.CREDENTIALS_ENV_VARS.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index 5ce0158c0..8ddbe0720 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -139,11 +139,7 @@ object WebIdentityCredentialsUtils: private def extractAccountIdFromArn(arn: String): Option[String] = // ARN format: arn:aws:sts::ACCOUNT_ID:assumed-role/ROLE_NAME/SESSION_NAME val arnParts = arn.split(":") - if arnParts.length >= 5 then { - Some(arnParts(4)) - } else { - None - } + arnParts.lift(4) /** * Creates a default implementation of WebIdentityCredentialsUtils. From a9f16d4c6e2b5b997de7225b936aa3d767910d41 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:20:02 +0900 Subject: [PATCH 112/215] Fixed SimpleJsonParser escape string parse bug --- .../src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala index 4e9e75679..ad0994a1a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala @@ -103,7 +103,7 @@ object SimpleJsonParser: val hex = s.substring(idx + 1, idx + 5) sb.append(Integer.parseInt(hex, 16).toChar) idx += 4 - case _ => sb.append(ch) + case _ => throw new IllegalArgumentException(s"Invalid escape sequence: \\$ch") escaped = false else if ch == '\\' then escaped = true else if ch == '"' then return (sb.toString, idx + 1) From f0bcce22bc4f4c1f6384898f0cb527a1c33db303 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:25:41 +0900 Subject: [PATCH 113/215] Change parse null to None --- .../ldbc/amazon/util/SimpleJsonParser.scala | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala index ad0994a1a..da6a286f8 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala @@ -12,13 +12,16 @@ package ldbc.amazon.util */ object SimpleJsonParser: - case class JsonObject(fields: Map[String, String]): - def get(key: String): Option[String] = fields.get(key) + case class JsonObject(fields: Map[String, Option[String]]): + def get(key: String): Option[String] = fields.get(key).flatten - def getOrEmpty(key: String): String = fields.getOrElse(key, "") + def getOrEmpty(key: String): String = fields.get(key).flatten.getOrElse("") def require(key: String): Either[String, String] = - fields.get(key).toRight(s"Required field '$key' not found") + fields.get(key) match + case None => Left(s"Required field '$key' not found") + case Some(None) => Left(s"Field '$key' is null") + case Some(Some(value)) => Right(value) /** * Parses a flat JSON object with string values. @@ -45,8 +48,8 @@ object SimpleJsonParser: Right(JsonObject(fields)) catch case ex: Exception => Left(s"JSON parse error: ${ ex.getMessage }") - private def parseFields(content: String): Map[String, String] = - val result = scala.collection.mutable.Map[String, String]() + private def parseFields(content: String): Map[String, Option[String]] = + val result = scala.collection.mutable.Map[String, Option[String]]() var idx = 0 while idx < content.length do @@ -65,7 +68,7 @@ object SimpleJsonParser: // Parse value val (value, nextIdx2) = parseValue(content, idx) - result(key) = value + result(key) = Option(value) idx = skipWhitespace(content, nextIdx2) // Skip comma if present @@ -121,8 +124,8 @@ object SimpleJsonParser: // String value parseString(s, from) else if ch == 'n' && s.length >= from + 4 && s.substring(from, from + 4) == "null" then - // null value - return empty string - ("", from + 4) + // null value + (null, from + 4) else if ch == 't' && s.length >= from + 4 && s.substring(from, from + 4) == "true" then ("true", from + 4) else if ch == 'f' && s.length >= from + 5 && s.substring(from, from + 5) == "false" then ("false", from + 5) else if ch == '-' || ch.isDigit then From 2b4530cf45b32a62ad0fce6dfef538f06b5933d3 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:25:59 +0900 Subject: [PATCH 114/215] Create SimpleJsonParserTest --- .../amazon/util/SimpleJsonParserTest.scala | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala new file mode 100644 index 000000000..83edc8312 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala @@ -0,0 +1,176 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +import munit.CatsEffectSuite + +class SimpleJsonParserTest extends CatsEffectSuite: + + test("parse empty object") { + val result = SimpleJsonParser.parse("{}") + assert(result.isRight) + assertEquals(result.map(_.fields.size).toOption.get, 0) + } + + test("parse object with single string field") { + val json = """{"key": "value"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("key")).toOption.flatten, Some("value")) + assertEquals(result.map(_.getOrEmpty("key")).toOption.get, "value") + } + + test("parse object with multiple string fields") { + val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE", "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("AccessKeyId")).toOption.flatten, Some("AKIAIOSFODNN7EXAMPLE")) + assertEquals(result.map(_.get("SecretAccessKey")).toOption.flatten, Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY")) + } + + test("parse object with number values") { + val json = """{"port": 3306, "timeout": 30.5}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("port")).toOption.flatten, Some("3306")) + assertEquals(result.map(_.get("timeout")).toOption.flatten, Some("30.5")) + } + + test("parse object with boolean values") { + val json = """{"enabled": true, "disabled": false}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("enabled")).toOption.flatten, Some("true")) + assertEquals(result.map(_.get("disabled")).toOption.flatten, Some("false")) + } + + test("parse object with null values") { + val json = """{"nullValue": null}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("nullValue")).toOption.flatten, None) + assertEquals(result.map(_.getOrEmpty("nullValue")).toOption.get, "") + } + + test("parse object with escaped string values") { + val json = """{"escaped": "Hello \"World\"", "newline": "Line1\nLine2"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("escaped")).toOption.flatten, Some("Hello \"World\"")) + assertEquals(result.map(_.get("newline")).toOption.flatten, Some("Line1\nLine2")) + } + + test("parse object with unicode escape sequences") { + val json = """{"unicode": "\u0041\u0042\u0043"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("unicode")).toOption.flatten, Some("ABC")) + } + + test("parse object with whitespace") { + val json = """ { "key" : "value" , "another" : 42 } """ + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("key")).toOption.flatten, Some("value")) + assertEquals(result.map(_.get("another")).toOption.flatten, Some("42")) + } + + test("parse object with nested objects (should skip)") { + val json = """{"nested": {"inner": "value"}, "simple": "text"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("nested")).toOption.flatten, Some("{...}")) + assertEquals(result.map(_.get("simple")).toOption.flatten, Some("text")) + } + + test("parse object with arrays (should skip)") { + val json = """{"array": [1, 2, 3], "simple": "text"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("array")).toOption.flatten, Some("[...]")) + assertEquals(result.map(_.get("simple")).toOption.flatten, Some("text")) + } + + test("JsonObject.require returns Right for existing key") { + val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + val required = result.map(_.require("AccessKeyId")) + assert(required.toOption.get.isRight) + assertEquals(required.toOption.get.toOption.get, "AKIAIOSFODNN7EXAMPLE") + } + + test("JsonObject.require returns Left for null value") { + val json = """{"nullValue": null}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + val required = result.map(_.require("nullValue")) + assert(required.toOption.get.isLeft) + assertEquals(required.toOption.get.left.toOption.get, "Field 'nullValue' is null") + } + + test("JsonObject.require returns Left for missing key") { + val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + val required = result.map(_.require("MissingKey")) + assert(required.toOption.get.isLeft) + assertEquals(required.toOption.get.left.toOption.get, "Required field 'MissingKey' not found") + } + + test("JsonObject.getOrEmpty returns empty string for missing key") { + val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.getOrEmpty("MissingKey")).toOption.get, "") + } + + test("fail to parse invalid JSON - not an object") { + val result = SimpleJsonParser.parse("\"just a string\"") + assert(result.isLeft) + assert(result.left.toOption.get.contains("Invalid JSON: must be an object")) + } + + test("fail to parse invalid JSON - missing closing brace") { + val result = SimpleJsonParser.parse("""{"key": "value"""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - missing colon") { + val result = SimpleJsonParser.parse("""{"key" "value"}""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - unterminated string") { + val result = SimpleJsonParser.parse("""{"key": "value}""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - invalid escape sequence") { + val result = SimpleJsonParser.parse("""{"key": "value\q"}""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - incomplete unicode escape") { + val result = SimpleJsonParser.parse("""{"key": "value\u00"}""") + assert(result.isLeft) + } + + test("parse complex AWS credentials response") { + val json = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJIMEYCIQD6m6XYcCTgK8jELjQXqKE", + "Expiration": "2024-12-06T12:34:56Z" + }""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("AccessKeyId")).toOption.flatten, Some("ASIAIOSFODNN7EXAMPLE")) + assertEquals(result.map(_.get("SecretAccessKey")).toOption.flatten, Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY")) + assertEquals(result.map(_.get("Token")).toOption.flatten, Some("IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJIMEYCIQD6m6XYcCTgK8jELjQXqKE")) + assertEquals(result.map(_.get("Expiration")).toOption.flatten, Some("2024-12-06T12:34:56Z")) + } \ No newline at end of file From f5b3fd12d37d467c37e01d08924c9c444945b8ca Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:50:15 +0900 Subject: [PATCH 115/215] Added ssl support --- .../ldbc/amazon/client/SimpleHttpClient.scala | 117 +++++++++++++----- 1 file changed, 87 insertions(+), 30 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index d644b3bda..4b5c36fd1 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -7,6 +7,7 @@ package ldbc.amazon.client import java.net.URI +import javax.net.ssl.SNIHostName import scala.concurrent.duration.* @@ -20,35 +21,89 @@ import cats.effect.syntax.all.* import fs2.* import fs2.io.net.* +import fs2.io.net.tls.* import ldbc.amazon.exception.* +/** + * Secure HTTP client that supports both HTTP and HTTPS protocols. + * + * Security Features: + * - Validates URI schemes and rejects unsupported protocols + * - Uses TLS for HTTPS connections with proper certificate validation + * - Defaults to secure ports (443 for HTTPS, 80 for HTTP) + * - Prevents credentials from being sent over cleartext connections + * + * This addresses the security vulnerability where AWS credentials + * could be sent over unencrypted HTTP connections. + */ class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, readTimeout: Duration )(using ev: MonadThrow[F]) extends HttpClient[F]: - override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = - val host = uri.getHost - val port = if uri.getPort > 0 then uri.getPort else 80 - val path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") + private def isHttps(uri: URI): Boolean = + uri.getScheme != null && uri.getScheme.toLowerCase == "https" + + private def getDefaultPort(uri: URI): Int = + if uri.getPort > 0 then uri.getPort + else if isHttps(uri) then 443 + else 80 + + private def validateScheme(uri: URI): F[Unit] = + uri.getScheme match + case null => ev.raiseError(new SdkClientException("URI scheme is required")) + case scheme if scheme.toLowerCase == "http" => + // Log warning for HTTP usage, but allow it for non-sensitive endpoints + ev.unit + case scheme if scheme.toLowerCase == "https" => ev.unit + case unsupported => ev.raiseError(new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.")) + + private def validateSecurityRequirements(uri: URI): F[Unit] = + // AWS endpoints should always use HTTPS + if uri.getHost != null && uri.getHost.contains(".amazonaws.com") && !isHttps(uri) then + ev.raiseError(new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${uri.getScheme}://${uri.getHost}")) + else + ev.unit + + private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + if isSecure then + for + socket <- Network[F].client(address) + tlsContext <- Network[F].tlsContext.systemResource + tlsSocket <- tlsContext + .clientBuilder(socket) + .withParameters(TLSParameters(serverNames = Some(List(new SNIHostName(host))))) + .build + yield tlsSocket + else + Network[F].client(address) + override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, "GET", path, headers, None) + response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) yield response override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = - val host = uri.getHost - val port = if uri.getPort > 0 then uri.getPort else 80 - val path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, "PUT", path, headers, Some(body)) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) yield response private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = @@ -58,15 +113,17 @@ class SimpleHttpClient[F[_]: Network: Async]( yield SocketAddress(h, p) private def sendRequest( - socket: Socket[F], - method: String, - host: String, - port: Int, - path: String, - headers: Map[String, String], - body: Option[String] + socket: Socket[F], + method: String, + host: String, + port: Int, + isSecure: Boolean, + path: String, + headers: Map[String, String], + body: Option[String] ): F[Unit] = - val hostHeader = if port == 80 then host else s"$host:$port" + val defaultPort = if isSecure then 443 else 80 + val hostHeader = if port == defaultPort then host else s"$host:$port" val contentHeaders = body match { case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) case None => Map.empty @@ -121,19 +178,19 @@ class SimpleHttpClient[F[_]: Network: Async]( .flatMap(parseHttpResponse) private def makeRequest( - address: SocketAddress[Host], - host: String, - port: Int, - method: String, - path: String, - headers: Map[String, String], - body: Option[String] + address: SocketAddress[Host], + host: String, + port: Int, + isSecure: Boolean, + method: String, + path: String, + headers: Map[String, String], + body: Option[String] ): F[HttpResponse] = - Network[F] - .client(address) + createSocket(address, isSecure, host) .use { socket => for - _ <- sendRequest(socket, method, host, port, path, headers, body) + _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) response <- receiveResponse(socket) yield response } From 4aeb55b5e083bf74cf7a3c6e06c6a2892af69e35 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:50:27 +0900 Subject: [PATCH 116/215] Create SimpleHttpClientSecurityTest --- .../client/SimpleHttpClientSecurityTest.scala | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala new file mode 100644 index 000000000..f4388fb9f --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala @@ -0,0 +1,125 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +import scala.concurrent.duration.* + +import munit.CatsEffectSuite + +import cats.effect.IO + +import fs2.io.net.Network + +class SimpleHttpClientSecurityTest extends CatsEffectSuite: + + val client = new SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) + + test("reject URIs without scheme") { + val uri = URI.create("example.com/path") + + client.get(uri, Map.empty).attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("URI scheme is required")) + } + } + + test("reject unsupported URI schemes") { + val uri = URI.create("ftp://example.com/path") + + client.get(uri, Map.empty).attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("Unsupported URI scheme: ftp")) + } + } + + test("accept HTTP URIs") { + val uri = URI.create("http://httpbin.org/get") + + // This test will fail if network is not available, but validates the scheme is accepted + client.get(uri, Map.empty).attempt.map { result => + // Either succeeds or fails with network error (not scheme validation error) + result.fold( + error => assert(!error.getMessage.contains("Unsupported URI scheme")), + _ => assert(true) + ) + } + } + + test("accept HTTPS URIs and use correct default port") { + val uri = URI.create("https://httpbin.org/get") + + // This test validates HTTPS scheme is accepted and would use port 443 + client.get(uri, Map.empty).attempt.map { result => + // Either succeeds or fails with network error (not scheme validation error) + result.fold( + error => assert(!error.getMessage.contains("Unsupported URI scheme")), + _ => assert(true) + ) + } + } + + test("use correct default ports") { + // HTTP should default to port 80 + val httpUri = URI.create("http://example.com") + assert(httpUri.getPort == -1) // No explicit port + + // HTTPS should default to port 443 + val httpsUri = URI.create("https://example.com") + assert(httpsUri.getPort == -1) // No explicit port + + // The client should handle these correctly internally + assert(true) // This is tested implicitly by the above tests + } + + test("preserve explicit ports in URIs") { + val httpUri = URI.create("http://example.com:8080/path") + val httpsUri = URI.create("https://example.com:8443/path") + + assert(httpUri.getPort == 8080) + assert(httpsUri.getPort == 8443) + } + + test("AWS STS HTTPS endpoints should be accepted") { + // These are the actual endpoints the StsClient will use + val stsEndpoints = List( + "https://sts.us-east-1.amazonaws.com/", + "https://sts.eu-west-1.amazonaws.com/", + "https://sts.ap-northeast-1.amazonaws.com/" + ) + + stsEndpoints.foreach { endpoint => + val uri = URI.create(endpoint) + + client.get(uri, Map.empty).attempt.map { result => + // Should not fail with scheme validation error + result.fold( + error => assert(!error.getMessage.contains("Unsupported URI scheme")), + _ => assert(true) + ) + } + } + } + + test("reject HTTP requests to AWS endpoints") { + val insecureAWSEndpoint = URI.create("http://sts.us-east-1.amazonaws.com/") + + client.get(insecureAWSEndpoint, Map.empty).attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("AWS endpoints require HTTPS")) + } + } + + test("reject HTTP PUT requests to AWS endpoints") { + val insecureAWSEndpoint = URI.create("http://sts.us-east-1.amazonaws.com/") + + client.put(insecureAWSEndpoint, Map.empty, "test body").attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("AWS endpoints require HTTPS")) + } + } \ No newline at end of file From 2184175e5fcc264fd7023b1a9eaf3e2f7b8e76b3 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:51:00 +0900 Subject: [PATCH 117/215] Action sbt scalafmtAll --- .../ldbc/amazon/client/SimpleHttpClient.scala | 44 ++++++++++--------- .../scala/ldbc/amazon/client/StsClient.scala | 8 ++-- .../ldbc/amazon/util/SimpleJsonParser.scala | 4 +- .../client/SimpleHttpClientSecurityTest.scala | 34 +++++++------- .../amazon/util/SimpleJsonParserTest.scala | 22 +++++++--- 5 files changed, 63 insertions(+), 49 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 4b5c36fd1..a48e1dfdf 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -7,6 +7,7 @@ package ldbc.amazon.client import java.net.URI + import javax.net.ssl.SNIHostName import scala.concurrent.duration.* @@ -43,7 +44,7 @@ class SimpleHttpClient[F[_]: Network: Async]( )(using ev: MonadThrow[F]) extends HttpClient[F]: - private def isHttps(uri: URI): Boolean = + private def isHttps(uri: URI): Boolean = uri.getScheme != null && uri.getScheme.toLowerCase == "https" private def getDefaultPort(uri: URI): Int = @@ -53,55 +54,58 @@ class SimpleHttpClient[F[_]: Network: Async]( private def validateScheme(uri: URI): F[Unit] = uri.getScheme match - case null => ev.raiseError(new SdkClientException("URI scheme is required")) - case scheme if scheme.toLowerCase == "http" => + case null => ev.raiseError(new SdkClientException("URI scheme is required")) + case scheme if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints ev.unit case scheme if scheme.toLowerCase == "https" => ev.unit - case unsupported => ev.raiseError(new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.")) + case unsupported => + ev.raiseError( + new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") + ) private def validateSecurityRequirements(uri: URI): F[Unit] = // AWS endpoints should always use HTTPS if uri.getHost != null && uri.getHost.contains(".amazonaws.com") && !isHttps(uri) then - ev.raiseError(new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${uri.getScheme}://${uri.getHost}")) - else - ev.unit + ev.raiseError( + new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") + ) + else ev.unit private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = if isSecure then for - socket <- Network[F].client(address) + socket <- Network[F].client(address) tlsContext <- Network[F].tlsContext.systemResource - tlsSocket <- tlsContext - .clientBuilder(socket) - .withParameters(TLSParameters(serverNames = Some(List(new SNIHostName(host))))) - .build + tlsSocket <- tlsContext + .clientBuilder(socket) + .withParameters(TLSParameters(serverNames = Some(List(new SNIHostName(host))))) + .build yield tlsSocket - else - Network[F].client(address) + else Network[F].client(address) override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) host = uri.getHost port = getDefaultPort(uri) isSecure = isHttps(uri) path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) yield response override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) host = uri.getHost port = getDefaultPort(uri) isSecure = isHttps(uri) path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) yield response diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 560258668..9879d2ce2 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -83,11 +83,11 @@ object StsClient: httpClient: HttpClient[F] ): F[AssumeRoleWithWebIdentityResponse] = for - timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) + timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) sessionName <- request.roleSessionName.fold( - UUIDGen[F].randomUUID.map(uuid => s"ldbc-session-$uuid") - )(Concurrent[F].pure) - duration = request.durationSeconds.getOrElse(3600) + UUIDGen[F].randomUUID.map(uuid => s"ldbc-session-$uuid") + )(Concurrent[F].pure) + duration = request.durationSeconds.getOrElse(3600) // Build STS request stsEndpoint = s"https://sts.$region.amazonaws.com/" diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala index da6a286f8..51285af7e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala @@ -19,8 +19,8 @@ object SimpleJsonParser: def require(key: String): Either[String, String] = fields.get(key) match - case None => Left(s"Required field '$key' not found") - case Some(None) => Left(s"Field '$key' is null") + case None => Left(s"Required field '$key' not found") + case Some(None) => Left(s"Field '$key' is null") case Some(Some(value)) => Right(value) /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala index f4388fb9f..6747a5487 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala @@ -10,19 +10,19 @@ import java.net.URI import scala.concurrent.duration.* -import munit.CatsEffectSuite - import cats.effect.IO import fs2.io.net.Network +import munit.CatsEffectSuite + class SimpleHttpClientSecurityTest extends CatsEffectSuite: val client = new SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) test("reject URIs without scheme") { val uri = URI.create("example.com/path") - + client.get(uri, Map.empty).attempt.map { result => assert(result.isLeft) assert(result.left.toOption.get.getMessage.contains("URI scheme is required")) @@ -31,7 +31,7 @@ class SimpleHttpClientSecurityTest extends CatsEffectSuite: test("reject unsupported URI schemes") { val uri = URI.create("ftp://example.com/path") - + client.get(uri, Map.empty).attempt.map { result => assert(result.isLeft) assert(result.left.toOption.get.getMessage.contains("Unsupported URI scheme: ftp")) @@ -40,7 +40,7 @@ class SimpleHttpClientSecurityTest extends CatsEffectSuite: test("accept HTTP URIs") { val uri = URI.create("http://httpbin.org/get") - + // This test will fail if network is not available, but validates the scheme is accepted client.get(uri, Map.empty).attempt.map { result => // Either succeeds or fails with network error (not scheme validation error) @@ -53,10 +53,10 @@ class SimpleHttpClientSecurityTest extends CatsEffectSuite: test("accept HTTPS URIs and use correct default port") { val uri = URI.create("https://httpbin.org/get") - + // This test validates HTTPS scheme is accepted and would use port 443 client.get(uri, Map.empty).attempt.map { result => - // Either succeeds or fails with network error (not scheme validation error) + // Either succeeds or fails with network error (not scheme validation error) result.fold( error => assert(!error.getMessage.contains("Unsupported URI scheme")), _ => assert(true) @@ -68,19 +68,19 @@ class SimpleHttpClientSecurityTest extends CatsEffectSuite: // HTTP should default to port 80 val httpUri = URI.create("http://example.com") assert(httpUri.getPort == -1) // No explicit port - + // HTTPS should default to port 443 - val httpsUri = URI.create("https://example.com") + val httpsUri = URI.create("https://example.com") assert(httpsUri.getPort == -1) // No explicit port - + // The client should handle these correctly internally assert(true) // This is tested implicitly by the above tests } test("preserve explicit ports in URIs") { - val httpUri = URI.create("http://example.com:8080/path") + val httpUri = URI.create("http://example.com:8080/path") val httpsUri = URI.create("https://example.com:8443/path") - + assert(httpUri.getPort == 8080) assert(httpsUri.getPort == 8443) } @@ -92,10 +92,10 @@ class SimpleHttpClientSecurityTest extends CatsEffectSuite: "https://sts.eu-west-1.amazonaws.com/", "https://sts.ap-northeast-1.amazonaws.com/" ) - + stsEndpoints.foreach { endpoint => val uri = URI.create(endpoint) - + client.get(uri, Map.empty).attempt.map { result => // Should not fail with scheme validation error result.fold( @@ -108,7 +108,7 @@ class SimpleHttpClientSecurityTest extends CatsEffectSuite: test("reject HTTP requests to AWS endpoints") { val insecureAWSEndpoint = URI.create("http://sts.us-east-1.amazonaws.com/") - + client.get(insecureAWSEndpoint, Map.empty).attempt.map { result => assert(result.isLeft) assert(result.left.toOption.get.getMessage.contains("AWS endpoints require HTTPS")) @@ -117,9 +117,9 @@ class SimpleHttpClientSecurityTest extends CatsEffectSuite: test("reject HTTP PUT requests to AWS endpoints") { val insecureAWSEndpoint = URI.create("http://sts.us-east-1.amazonaws.com/") - + client.put(insecureAWSEndpoint, Map.empty, "test body").attempt.map { result => assert(result.isLeft) assert(result.left.toOption.get.getMessage.contains("AWS endpoints require HTTPS")) } - } \ No newline at end of file + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala index 83edc8312..1db0bd749 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala @@ -25,11 +25,15 @@ class SimpleJsonParserTest extends CatsEffectSuite: } test("parse object with multiple string fields") { - val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE", "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"}""" + val json = + """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE", "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"}""" val result = SimpleJsonParser.parse(json) assert(result.isRight) assertEquals(result.map(_.get("AccessKeyId")).toOption.flatten, Some("AKIAIOSFODNN7EXAMPLE")) - assertEquals(result.map(_.get("SecretAccessKey")).toOption.flatten, Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY")) + assertEquals( + result.map(_.get("SecretAccessKey")).toOption.flatten, + Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + ) } test("parse object with number values") { @@ -161,7 +165,7 @@ class SimpleJsonParserTest extends CatsEffectSuite: } test("parse complex AWS credentials response") { - val json = """{ + val json = """{ "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", "Token": "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJIMEYCIQD6m6XYcCTgK8jELjQXqKE", @@ -170,7 +174,13 @@ class SimpleJsonParserTest extends CatsEffectSuite: val result = SimpleJsonParser.parse(json) assert(result.isRight) assertEquals(result.map(_.get("AccessKeyId")).toOption.flatten, Some("ASIAIOSFODNN7EXAMPLE")) - assertEquals(result.map(_.get("SecretAccessKey")).toOption.flatten, Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY")) - assertEquals(result.map(_.get("Token")).toOption.flatten, Some("IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJIMEYCIQD6m6XYcCTgK8jELjQXqKE")) + assertEquals( + result.map(_.get("SecretAccessKey")).toOption.flatten, + Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + ) + assertEquals( + result.map(_.get("Token")).toOption.flatten, + Some("IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJIMEYCIQD6m6XYcCTgK8jELjQXqKE") + ) assertEquals(result.map(_.get("Expiration")).toOption.flatten, Some("2024-12-06T12:34:56Z")) - } \ No newline at end of file + } From ab48faea65b06331c25d72577cc4bff24c8f94cc Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:54:44 +0900 Subject: [PATCH 118/215] Create SimpleXmlParserTest --- .../amazon/util/SimpleXmlParserTest.scala | 247 ++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala new file mode 100644 index 000000000..47f9e6c61 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala @@ -0,0 +1,247 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +import munit.CatsEffectSuite + +class SimpleXmlParserTest extends CatsEffectSuite: + + test("decodeXmlEntities should decode all standard XML entities") { + val input = "&<>"'" + val expected = "&<>\"'" + assertEquals(SimpleXmlParser.decodeXmlEntities(input), expected) + } + + test("decodeXmlEntities should handle mixed content with entities") { + val input = "Hello & welcome to <XML> parsing "test"" + val expected = "Hello & welcome to parsing \"test\"" + assertEquals(SimpleXmlParser.decodeXmlEntities(input), expected) + } + + test("decodeXmlEntities should handle text without entities") { + val input = "Plain text without any entities" + assertEquals(SimpleXmlParser.decodeXmlEntities(input), input) + } + + test("decodeXmlEntities should handle empty string") { + assertEquals(SimpleXmlParser.decodeXmlEntities(""), "") + } + + test("extractTagContent should extract simple tag content") { + val xml = "John Doe" + val result = SimpleXmlParser.extractTagContent("name", xml) + assertEquals(result, Some("John Doe")) + } + + test("extractTagContent should extract content with whitespace trimmed") { + val xml = " John Doe " + val result = SimpleXmlParser.extractTagContent("name", xml) + assertEquals(result, Some("John Doe")) + } + + test("extractTagContent should extract content with XML entities decoded") { + val xml = "Hello & welcome to <XML>" + val result = SimpleXmlParser.extractTagContent("message", xml) + assertEquals(result, Some("Hello & welcome to ")) + } + + test("extractTagContent should return None for non-existent tag") { + val xml = "John Doe" + val result = SimpleXmlParser.extractTagContent("email", xml) + assertEquals(result, None) + } + + test("extractTagContent should return None for malformed XML (missing end tag)") { + val xml = "John Doe" + val result = SimpleXmlParser.extractTagContent("name", xml) + assertEquals(result, None) + } + + test("extractTagContent should extract from complex XML structure") { + val xml = """ + + + John Doe + john@example.com + + + """ + val nameResult = SimpleXmlParser.extractTagContent("name", xml) + val emailResult = SimpleXmlParser.extractTagContent("email", xml) + assertEquals(nameResult, Some("John Doe")) + assertEquals(emailResult, Some("john@example.com")) + } + + test("extractTagContent should handle nested tags with same name") { + val xml = "Outer NameInner Name" + val result = SimpleXmlParser.extractTagContent("name", xml) + // Should extract the first occurrence + assertEquals(result, Some("Outer Name")) + } + + test("extractSection should extract complete XML section") { + val xml = """ + + + John Doe + john@example.com + + success + + """ + val result = SimpleXmlParser.extractSection("user", xml) + assert(result.isDefined) + val userSection = result.get + assert(userSection.contains("")) + assert(userSection.contains("")) + assert(userSection.contains("John Doe")) + assert(userSection.contains("john@example.com")) + } + + test("extractSection should return None for non-existent section") { + val xml = "success" + val result = SimpleXmlParser.extractSection("user", xml) + assertEquals(result, None) + } + + test("extractSection should return None for malformed XML section") { + val xml = "John Doe" + val result = SimpleXmlParser.extractSection("user", xml) + assertEquals(result, None) + } + + test("extractSection should handle nested sections") { + val xml = """ + + + test + + + """ + val result = SimpleXmlParser.extractSection("inner", xml) + assert(result.isDefined) + val innerSection = result.get + assertEquals(innerSection.trim, "\n test\n ") + } + + test("requireTag should return content for existing tag") { + val xml = "AKIAIOSFODNN7EXAMPLE" + val result = SimpleXmlParser.requireTag("AccessKeyId", xml, "AccessKeyId not found") + assertEquals(result, "AKIAIOSFODNN7EXAMPLE") + } + + test("requireTag should throw exception for non-existent tag") { + val xml = "John Doe" + intercept[IllegalArgumentException] { + SimpleXmlParser.requireTag("email", xml, "Email tag not found") + } + } + + test("requireTag should throw exception for empty tag content") { + val xml = "" + intercept[IllegalArgumentException] { + SimpleXmlParser.requireTag("name", xml, "Name cannot be empty") + } + } + + test("requireTag should throw exception for whitespace-only content") { + val xml = " " + intercept[IllegalArgumentException] { + SimpleXmlParser.requireTag("name", xml, "Name cannot be empty") + } + } + + test("requireTag should handle valid content with entities") { + val xml = "Hello & welcome" + val result = SimpleXmlParser.requireTag("message", xml, "Message not found") + assertEquals(result, "Hello & welcome") + } + + test("parse AWS STS AssumeRoleWithWebIdentity response") { + val stsResponse = """ + + + + ASIAIOSFODNN7EXAMPLE + wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3 + 2024-12-06T12:00:00Z + + amzn1.account.AF6RHO7KZU5XRVQJGXK6HB56KR2A + + AROA3XFRBF535PLBQAARL:app-session-123 + arn:aws:sts::123456789012:assumed-role/TestRole/app-session-123 + + + + c6104cbe-af31-11e0-8154-cbc7ccf896c7 + + """ + + // Extract credentials section + val credentialsSection = SimpleXmlParser.extractSection("Credentials", stsResponse) + assert(credentialsSection.isDefined) + + // Extract individual credential fields + val credentials = credentialsSection.get + val accessKeyId = SimpleXmlParser.extractTagContent("AccessKeyId", credentials) + val secretAccessKey = SimpleXmlParser.extractTagContent("SecretAccessKey", credentials) + val sessionToken = SimpleXmlParser.extractTagContent("SessionToken", credentials) + val expiration = SimpleXmlParser.extractTagContent("Expiration", credentials) + + assertEquals(accessKeyId, Some("ASIAIOSFODNN7EXAMPLE")) + assertEquals(secretAccessKey, Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY")) + assertEquals(sessionToken, Some("IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3")) + assertEquals(expiration, Some("2024-12-06T12:00:00Z")) + + // Test requireTag functionality + val requiredAccessKeyId = SimpleXmlParser.requireTag("AccessKeyId", credentials, "AccessKeyId is required") + assertEquals(requiredAccessKeyId, "ASIAIOSFODNN7EXAMPLE") + } + + test("parse AWS STS error response") { + val errorResponse = """ + + + Sender + InvalidParameterValue + The security token included in the request is invalid + + c6104cbe-af31-11e0-8154-cbc7ccf896c7 + """ + + val errorSection = SimpleXmlParser.extractSection("Error", errorResponse) + assert(errorSection.isDefined) + + val error = errorSection.get + val errorType = SimpleXmlParser.extractTagContent("Type", error) + val errorCode = SimpleXmlParser.extractTagContent("Code", error) + val errorMessage = SimpleXmlParser.extractTagContent("Message", error) + + assertEquals(errorType, Some("Sender")) + assertEquals(errorCode, Some("InvalidParameterValue")) + assertEquals(errorMessage, Some("The security token included in the request is invalid")) + } + + test("handle XML with special characters and entities") { + val xml = """Data contains <brackets> & "quotes" 'apostrophes'""" + val result = SimpleXmlParser.extractTagContent("message", xml) + assertEquals(result, Some("Data contains & \"quotes\" 'apostrophes'")) + } + + test("handle empty XML documents") { + val xml = "" + val result = SimpleXmlParser.extractTagContent("any", xml) + assertEquals(result, None) + } + + test("handle XML with CDATA sections") { + val xml = " and & entities]]>" + // Note: This simple parser doesn't handle CDATA, but should still extract content + val result = SimpleXmlParser.extractTagContent("data", xml) + assertEquals(result, Some(" and & entities]]>")) + } \ No newline at end of file From 5a3f77308d659f8928475430744ae3f76c246b6c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 15:59:40 +0900 Subject: [PATCH 119/215] Delete unused --- .../credentials/WebIdentityTokenFileCredentialsProvider.scala | 3 --- .../scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala | 1 - 2 files changed, 4 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 846bd46cb..e5f89d3d2 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -95,7 +95,6 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: webIdentityTokenFile = Path(tokenFile), roleArn = arn, roleSessionName = roleSessionName, - providerName = BusinessMetricFeatureId.CREDENTIALS_WEB_IDENTITY_TOKEN.code ) ) case _ => None @@ -125,13 +124,11 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: * @param webIdentityTokenFile Path to the JWT token file * @param roleArn The ARN of the IAM role to assume * @param roleSessionName Optional session name for the assumed role session - * @param providerName Provider identifier for logging and metrics */ case class WebIdentityTokenCredentialProperties( webIdentityTokenFile: Path, roleArn: String, roleSessionName: Option[String], - providerName: String ) object WebIdentityTokenFileCredentialsProvider: diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala index 52c06a615..ab52cc57a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -61,5 +61,4 @@ enum BusinessMetricFeatureId(val code: String): case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") case CREDENTIALS_CONTAINER extends BusinessMetricFeatureId("1") - case CREDENTIALS_WEB_IDENTITY_TOKEN extends BusinessMetricFeatureId("k") case UNKNOWN extends BusinessMetricFeatureId("Unknown") From 657c4056d07d3ab48900e5ddbacdffbc55eeae34 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 16:16:50 +0900 Subject: [PATCH 120/215] Create DefaultAwsCredentialsIdentityTest --- .../DefaultAwsCredentialsIdentity.scala | 16 +- .../DefaultAwsCredentialsIdentityTest.scala | 389 ++++++++++++++++++ 2 files changed, 396 insertions(+), 9 deletions(-) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala index c201c09ee..78f1fe448 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -30,15 +30,13 @@ final case class DefaultAwsCredentialsIdentity( builder.result() override def equals(obj: Any): Boolean = - if this == obj then true - else if obj == null || getClass != obj.getClass then false - else - obj match - case that: AwsCredentialsIdentity => - Objects.equals(accessKeyId, that.accessKeyId) && - Objects.equals(secretAccessKey, that.secretAccessKey) && - Objects.equals(accountId, that.accountId) - case _ => false + obj match + case that: DefaultAwsCredentialsIdentity => + (this eq that) || + (Objects.equals(accessKeyId, that.accessKeyId) && + Objects.equals(secretAccessKey, that.secretAccessKey) && + Objects.equals(accountId, that.accountId)) + case _ => false override def hashCode(): Int = var hashCode = 1 diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala new file mode 100644 index 000000000..a8fddd150 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala @@ -0,0 +1,389 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.identity.internal + +import java.time.Instant + +import munit.CatsEffectSuite + +import ldbc.amazon.identity.AwsCredentialsIdentity + +class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: + + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testAccountId = "123456789012" + private val testProviderName = "test-provider" + private val testExpirationTime = Instant.parse("2024-12-06T12:00:00Z") + + test("create minimal credentials identity") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, testSecretAccessKey) + assertEquals(credentials.accountId, None) + assertEquals(credentials.expirationTime, None) + assertEquals(credentials.providerName, None) + } + + test("create full credentials identity with all fields") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, testSecretAccessKey) + assertEquals(credentials.accountId, Some(testAccountId)) + assertEquals(credentials.expirationTime, Some(testExpirationTime)) + assertEquals(credentials.providerName, Some(testProviderName)) + } + + test("create credentials using factory method") { + val credentials = AwsCredentialsIdentity.create(testAccessKeyId, testSecretAccessKey) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, testSecretAccessKey) + assertEquals(credentials.accountId, None) + assertEquals(credentials.expirationTime, None) + assertEquals(credentials.providerName, None) + } + + test("toString should not expose secret access key") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val stringRepresentation = credentials.toString + + // Should contain access key and other safe fields + assert(stringRepresentation.contains(testAccessKeyId)) + assert(stringRepresentation.contains(testAccountId)) + assert(stringRepresentation.contains(testProviderName)) + + // Should NOT contain secret access key + assert(!stringRepresentation.contains(testSecretAccessKey)) + + // Should contain class name + assert(stringRepresentation.contains("AwsCredentialsIdentity")) + } + + test("toString with minimal fields") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + val stringRepresentation = credentials.toString + + assertEquals(stringRepresentation, s"AwsCredentialsIdentity(accessKeyId=$testAccessKeyId)") + assert(!stringRepresentation.contains(testSecretAccessKey)) + } + + test("toString with partial fields") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = None + ) + + val stringRepresentation = credentials.toString + val expected = s"AwsCredentialsIdentity(accessKeyId=$testAccessKeyId, accountId=$testAccountId)" + + assertEquals(stringRepresentation, expected) + assert(!stringRepresentation.contains(testSecretAccessKey)) + } + + test("equals should work correctly for identical credentials") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials1, credentials2) + assertEquals(credentials1.hashCode(), credentials2.hashCode()) + } + + test("equals should work for credentials with different expiration and provider") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(Instant.parse("2025-01-01T00:00:00Z")), // Different expiration + providerName = Some("different-provider") // Different provider + ) + + // Should still be equal because equals only compares key fields + assertEquals(credentials1, credentials2) + assertEquals(credentials1.hashCode(), credentials2.hashCode()) + } + + test("equals should return false for different access key") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = "DIFFERENT_ACCESS_KEY", + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should return false for different secret access key") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = "DIFFERENT_SECRET_KEY", + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should return false for different account ID") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some("999999999999"), + expirationTime = None, + providerName = None + ) + + assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should handle None vs Some account ID") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should return false for null object") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assert(credentials != null) + assert(!credentials.equals(null)) + } + + test("equals should return false for different class") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assert(credentials.toString != "not a credentials object") + assert(!credentials.equals("not a credentials object")) + } + + test("equals should work with same concrete type") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, // Different account ID + expirationTime = None, + providerName = None + ) + + val credentials3 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), // Same account ID + expirationTime = None, + providerName = None + ) + + // Different account IDs (Some vs None) + assert(credentials1 != credentials2) + + // Same account IDs + assertEquals(credentials1, credentials3) + } + + test("equals should return false for different AwsCredentialsIdentity implementations") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + // Factory method creates same type, but let's test interface behavior + val credentials2: AwsCredentialsIdentity = AwsCredentialsIdentity.create(testAccessKeyId, testSecretAccessKey) + + // These should be equal since factory creates DefaultAwsCredentialsIdentity + // but with different accountId (None vs Some) + assert(credentials1 != credentials2) + + // Test with exactly same fields through factory + val credentials3 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assertEquals(credentials2, credentials3) + } + + test("hashCode should be consistent") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val hashCode1 = credentials.hashCode() + val hashCode2 = credentials.hashCode() + + assertEquals(hashCode1, hashCode2) + } + + test("implements Identity interface correctly") { + val credentials: ldbc.amazon.identity.Identity = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials.expirationTime, Some(testExpirationTime)) + assertEquals(credentials.providerName, Some(testProviderName)) + } + + test("implements AwsCredentialsIdentity interface correctly") { + val credentials: AwsCredentialsIdentity = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, testSecretAccessKey) + assertEquals(credentials.accountId, Some(testAccountId)) + assertEquals(credentials.expirationTime, Some(testExpirationTime)) + assertEquals(credentials.providerName, Some(testProviderName)) + } + + test("handle empty strings") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = "", + secretAccessKey = "", + accountId = Some(""), + expirationTime = None, + providerName = Some("") + ) + + assertEquals(credentials.accessKeyId, "") + assertEquals(credentials.secretAccessKey, "") + assertEquals(credentials.accountId, Some("")) + assertEquals(credentials.providerName, Some("")) + } \ No newline at end of file From 4dac0478e0628b42369e920107a3746e6c5d91ba Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 16:17:00 +0900 Subject: [PATCH 121/215] Fixed compile error --- .../WebIdentityTokenFileCredentialsProvider.scala | 1 - .../credentials/internal/WebIdentityCredentialsUtils.scala | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index e5f89d3d2..ae48a4e94 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -20,7 +20,6 @@ import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* -import ldbc.amazon.useragent.BusinessMetricFeatureId import ldbc.amazon.util.SdkSystemSetting /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index 8ddbe0720..851e4c329 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -66,7 +66,7 @@ object WebIdentityCredentialsUtils: roleSessionName = config.roleSessionName ) stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest, region, httpClient) - credentials = convertStsResponseToCredentials(stsResponse, config) + credentials = convertStsResponseToCredentials(stsResponse) yield credentials /** @@ -118,14 +118,13 @@ object WebIdentityCredentialsUtils: */ private def convertStsResponseToCredentials( stsResponse: StsClient.AssumeRoleWithWebIdentityResponse, - config: WebIdentityTokenCredentialProperties ): AwsCredentials = AwsSessionCredentials( accessKeyId = stsResponse.accessKeyId, secretAccessKey = stsResponse.secretAccessKey, sessionToken = stsResponse.sessionToken, validateCredentials = false, - providerName = Some(config.providerName), + providerName = None, accountId = extractAccountIdFromArn(stsResponse.assumedRoleArn), expirationTime = Some(stsResponse.expiration) ) From 3e9c07907007fd29b6e6cdf09dc53109ad827e28 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 16:17:25 +0900 Subject: [PATCH 122/215] Action sbt scalafmtAll --- ...IdentityTokenFileCredentialsProvider.scala | 4 +- .../WebIdentityCredentialsUtils.scala | 2 +- .../DefaultAwsCredentialsIdentity.scala | 6 +- .../DefaultAwsCredentialsIdentityTest.scala | 244 +++++++++--------- .../amazon/util/SimpleXmlParserTest.scala | 52 ++-- 5 files changed, 154 insertions(+), 154 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index ae48a4e94..17525aa96 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -93,7 +93,7 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: WebIdentityTokenCredentialProperties( webIdentityTokenFile = Path(tokenFile), roleArn = arn, - roleSessionName = roleSessionName, + roleSessionName = roleSessionName ) ) case _ => None @@ -127,7 +127,7 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: case class WebIdentityTokenCredentialProperties( webIdentityTokenFile: Path, roleArn: String, - roleSessionName: Option[String], + roleSessionName: Option[String] ) object WebIdentityTokenFileCredentialsProvider: diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index 851e4c329..6513ad49b 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -117,7 +117,7 @@ object WebIdentityCredentialsUtils: * @return AWS session credentials */ private def convertStsResponseToCredentials( - stsResponse: StsClient.AssumeRoleWithWebIdentityResponse, + stsResponse: StsClient.AssumeRoleWithWebIdentityResponse ): AwsCredentials = AwsSessionCredentials( accessKeyId = stsResponse.accessKeyId, diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala index 78f1fe448..a8e0826c4 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -33,9 +33,9 @@ final case class DefaultAwsCredentialsIdentity( obj match case that: DefaultAwsCredentialsIdentity => (this eq that) || - (Objects.equals(accessKeyId, that.accessKeyId) && - Objects.equals(secretAccessKey, that.secretAccessKey) && - Objects.equals(accountId, that.accountId)) + (Objects.equals(accessKeyId, that.accessKeyId) && + Objects.equals(secretAccessKey, that.secretAccessKey) && + Objects.equals(accountId, that.accountId)) case _ => false override def hashCode(): Int = diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala index a8fddd150..871368574 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala @@ -14,19 +14,19 @@ import ldbc.amazon.identity.AwsCredentialsIdentity class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: - private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" - private val testAccountId = "123456789012" - private val testProviderName = "test-provider" - private val testExpirationTime = Instant.parse("2024-12-06T12:00:00Z") + private val testAccountId = "123456789012" + private val testProviderName = "test-provider" + private val testExpirationTime = Instant.parse("2024-12-06T12:00:00Z") test("create minimal credentials identity") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = None, - expirationTime = None, - providerName = None + accountId = None, + expirationTime = None, + providerName = None ) assertEquals(credentials.accessKeyId, testAccessKeyId) @@ -38,11 +38,11 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("create full credentials identity with all fields") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) assertEquals(credentials.accessKeyId, testAccessKeyId) @@ -64,11 +64,11 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("toString should not expose secret access key") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) val stringRepresentation = credentials.toString @@ -77,21 +77,21 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: assert(stringRepresentation.contains(testAccessKeyId)) assert(stringRepresentation.contains(testAccountId)) assert(stringRepresentation.contains(testProviderName)) - + // Should NOT contain secret access key assert(!stringRepresentation.contains(testSecretAccessKey)) - + // Should contain class name assert(stringRepresentation.contains("AwsCredentialsIdentity")) } test("toString with minimal fields") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = None, - expirationTime = None, - providerName = None + accountId = None, + expirationTime = None, + providerName = None ) val stringRepresentation = credentials.toString @@ -102,15 +102,15 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("toString with partial fields") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = None + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = None ) val stringRepresentation = credentials.toString - val expected = s"AwsCredentialsIdentity(accessKeyId=$testAccessKeyId, accountId=$testAccountId)" + val expected = s"AwsCredentialsIdentity(accessKeyId=$testAccessKeyId, accountId=$testAccountId)" assertEquals(stringRepresentation, expected) assert(!stringRepresentation.contains(testSecretAccessKey)) @@ -118,19 +118,19 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should work correctly for identical credentials") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) val credentials2 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) assertEquals(credentials1, credentials2) @@ -139,19 +139,19 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should work for credentials with different expiration and provider") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) val credentials2 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(Instant.parse("2025-01-01T00:00:00Z")), // Different expiration - providerName = Some("different-provider") // Different provider + accountId = Some(testAccountId), + expirationTime = Some(Instant.parse("2025-01-01T00:00:00Z")), // Different expiration + providerName = Some("different-provider") // Different provider ) // Should still be equal because equals only compares key fields @@ -161,19 +161,19 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should return false for different access key") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) val credentials2 = DefaultAwsCredentialsIdentity( - accessKeyId = "DIFFERENT_ACCESS_KEY", + accessKeyId = "DIFFERENT_ACCESS_KEY", secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) assert(credentials1 != credentials2) @@ -182,19 +182,19 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should return false for different secret access key") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) val credentials2 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = "DIFFERENT_SECRET_KEY", - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) assert(credentials1 != credentials2) @@ -203,19 +203,19 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should return false for different account ID") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) val credentials2 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some("999999999999"), - expirationTime = None, - providerName = None + accountId = Some("999999999999"), + expirationTime = None, + providerName = None ) assert(credentials1 != credentials2) @@ -224,19 +224,19 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should handle None vs Some account ID") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = None, - expirationTime = None, - providerName = None + accountId = None, + expirationTime = None, + providerName = None ) val credentials2 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) assert(credentials1 != credentials2) @@ -245,11 +245,11 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should return false for null object") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = None, - expirationTime = None, - providerName = None + accountId = None, + expirationTime = None, + providerName = None ) assert(credentials != null) @@ -258,11 +258,11 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should return false for different class") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = None, - expirationTime = None, - providerName = None + accountId = None, + expirationTime = None, + providerName = None ) assert(credentials.toString != "not a credentials object") @@ -271,27 +271,27 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should work with same concrete type") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) val credentials2 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = None, // Different account ID - expirationTime = None, - providerName = None + accountId = None, // Different account ID + expirationTime = None, + providerName = None ) val credentials3 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), // Same account ID - expirationTime = None, - providerName = None + accountId = Some(testAccountId), // Same account ID + expirationTime = None, + providerName = None ) // Different account IDs (Some vs None) @@ -303,27 +303,27 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("equals should return false for different AwsCredentialsIdentity implementations") { val credentials1 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = None, - providerName = None + accountId = Some(testAccountId), + expirationTime = None, + providerName = None ) // Factory method creates same type, but let's test interface behavior val credentials2: AwsCredentialsIdentity = AwsCredentialsIdentity.create(testAccessKeyId, testSecretAccessKey) - - // These should be equal since factory creates DefaultAwsCredentialsIdentity + + // These should be equal since factory creates DefaultAwsCredentialsIdentity // but with different accountId (None vs Some) assert(credentials1 != credentials2) // Test with exactly same fields through factory val credentials3 = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = None, - expirationTime = None, - providerName = None + accountId = None, + expirationTime = None, + providerName = None ) assertEquals(credentials2, credentials3) @@ -331,11 +331,11 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("hashCode should be consistent") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) val hashCode1 = credentials.hashCode() @@ -346,11 +346,11 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("implements Identity interface correctly") { val credentials: ldbc.amazon.identity.Identity = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) assertEquals(credentials.expirationTime, Some(testExpirationTime)) @@ -359,11 +359,11 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("implements AwsCredentialsIdentity interface correctly") { val credentials: AwsCredentialsIdentity = DefaultAwsCredentialsIdentity( - accessKeyId = testAccessKeyId, + accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, - accountId = Some(testAccountId), - expirationTime = Some(testExpirationTime), - providerName = Some(testProviderName) + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) ) assertEquals(credentials.accessKeyId, testAccessKeyId) @@ -375,15 +375,15 @@ class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: test("handle empty strings") { val credentials = DefaultAwsCredentialsIdentity( - accessKeyId = "", + accessKeyId = "", secretAccessKey = "", - accountId = Some(""), - expirationTime = None, - providerName = Some("") + accountId = Some(""), + expirationTime = None, + providerName = Some("") ) assertEquals(credentials.accessKeyId, "") assertEquals(credentials.secretAccessKey, "") assertEquals(credentials.accountId, Some("")) assertEquals(credentials.providerName, Some("")) - } \ No newline at end of file + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala index 47f9e6c61..b0f346f81 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala @@ -11,13 +11,13 @@ import munit.CatsEffectSuite class SimpleXmlParserTest extends CatsEffectSuite: test("decodeXmlEntities should decode all standard XML entities") { - val input = "&<>"'" + val input = "&<>"'" val expected = "&<>\"'" assertEquals(SimpleXmlParser.decodeXmlEntities(input), expected) } test("decodeXmlEntities should handle mixed content with entities") { - val input = "Hello & welcome to <XML> parsing "test"" + val input = "Hello & welcome to <XML> parsing "test"" val expected = "Hello & welcome to parsing \"test\"" assertEquals(SimpleXmlParser.decodeXmlEntities(input), expected) } @@ -32,37 +32,37 @@ class SimpleXmlParserTest extends CatsEffectSuite: } test("extractTagContent should extract simple tag content") { - val xml = "John Doe" + val xml = "John Doe" val result = SimpleXmlParser.extractTagContent("name", xml) assertEquals(result, Some("John Doe")) } test("extractTagContent should extract content with whitespace trimmed") { - val xml = " John Doe " + val xml = " John Doe " val result = SimpleXmlParser.extractTagContent("name", xml) assertEquals(result, Some("John Doe")) } test("extractTagContent should extract content with XML entities decoded") { - val xml = "Hello & welcome to <XML>" + val xml = "Hello & welcome to <XML>" val result = SimpleXmlParser.extractTagContent("message", xml) assertEquals(result, Some("Hello & welcome to ")) } test("extractTagContent should return None for non-existent tag") { - val xml = "John Doe" + val xml = "John Doe" val result = SimpleXmlParser.extractTagContent("email", xml) assertEquals(result, None) } test("extractTagContent should return None for malformed XML (missing end tag)") { - val xml = "John Doe" + val xml = "John Doe" val result = SimpleXmlParser.extractTagContent("name", xml) assertEquals(result, None) } test("extractTagContent should extract from complex XML structure") { - val xml = """ + val xml = """ John Doe @@ -70,21 +70,21 @@ class SimpleXmlParserTest extends CatsEffectSuite: """ - val nameResult = SimpleXmlParser.extractTagContent("name", xml) + val nameResult = SimpleXmlParser.extractTagContent("name", xml) val emailResult = SimpleXmlParser.extractTagContent("email", xml) assertEquals(nameResult, Some("John Doe")) assertEquals(emailResult, Some("john@example.com")) } test("extractTagContent should handle nested tags with same name") { - val xml = "Outer NameInner Name" + val xml = "Outer NameInner Name" val result = SimpleXmlParser.extractTagContent("name", xml) // Should extract the first occurrence assertEquals(result, Some("Outer Name")) } test("extractSection should extract complete XML section") { - val xml = """ + val xml = """ John Doe @@ -103,19 +103,19 @@ class SimpleXmlParserTest extends CatsEffectSuite: } test("extractSection should return None for non-existent section") { - val xml = "success" + val xml = "success" val result = SimpleXmlParser.extractSection("user", xml) assertEquals(result, None) } test("extractSection should return None for malformed XML section") { - val xml = "John Doe" + val xml = "John Doe" val result = SimpleXmlParser.extractSection("user", xml) assertEquals(result, None) } test("extractSection should handle nested sections") { - val xml = """ + val xml = """ test @@ -129,7 +129,7 @@ class SimpleXmlParserTest extends CatsEffectSuite: } test("requireTag should return content for existing tag") { - val xml = "AKIAIOSFODNN7EXAMPLE" + val xml = "AKIAIOSFODNN7EXAMPLE" val result = SimpleXmlParser.requireTag("AccessKeyId", xml, "AccessKeyId not found") assertEquals(result, "AKIAIOSFODNN7EXAMPLE") } @@ -156,7 +156,7 @@ class SimpleXmlParserTest extends CatsEffectSuite: } test("requireTag should handle valid content with entities") { - val xml = "Hello & welcome" + val xml = "Hello & welcome" val result = SimpleXmlParser.requireTag("message", xml, "Message not found") assertEquals(result, "Hello & welcome") } @@ -187,11 +187,11 @@ class SimpleXmlParserTest extends CatsEffectSuite: assert(credentialsSection.isDefined) // Extract individual credential fields - val credentials = credentialsSection.get - val accessKeyId = SimpleXmlParser.extractTagContent("AccessKeyId", credentials) + val credentials = credentialsSection.get + val accessKeyId = SimpleXmlParser.extractTagContent("AccessKeyId", credentials) val secretAccessKey = SimpleXmlParser.extractTagContent("SecretAccessKey", credentials) - val sessionToken = SimpleXmlParser.extractTagContent("SessionToken", credentials) - val expiration = SimpleXmlParser.extractTagContent("Expiration", credentials) + val sessionToken = SimpleXmlParser.extractTagContent("SessionToken", credentials) + val expiration = SimpleXmlParser.extractTagContent("Expiration", credentials) assertEquals(accessKeyId, Some("ASIAIOSFODNN7EXAMPLE")) assertEquals(secretAccessKey, Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY")) @@ -217,9 +217,9 @@ class SimpleXmlParserTest extends CatsEffectSuite: val errorSection = SimpleXmlParser.extractSection("Error", errorResponse) assert(errorSection.isDefined) - val error = errorSection.get - val errorType = SimpleXmlParser.extractTagContent("Type", error) - val errorCode = SimpleXmlParser.extractTagContent("Code", error) + val error = errorSection.get + val errorType = SimpleXmlParser.extractTagContent("Type", error) + val errorCode = SimpleXmlParser.extractTagContent("Code", error) val errorMessage = SimpleXmlParser.extractTagContent("Message", error) assertEquals(errorType, Some("Sender")) @@ -228,13 +228,13 @@ class SimpleXmlParserTest extends CatsEffectSuite: } test("handle XML with special characters and entities") { - val xml = """Data contains <brackets> & "quotes" 'apostrophes'""" + val xml = """Data contains <brackets> & "quotes" 'apostrophes'""" val result = SimpleXmlParser.extractTagContent("message", xml) assertEquals(result, Some("Data contains & \"quotes\" 'apostrophes'")) } test("handle empty XML documents") { - val xml = "" + val xml = "" val result = SimpleXmlParser.extractTagContent("any", xml) assertEquals(result, None) } @@ -244,4 +244,4 @@ class SimpleXmlParserTest extends CatsEffectSuite: // Note: This simple parser doesn't handle CDATA, but should still extract content val result = SimpleXmlParser.extractTagContent("data", xml) assertEquals(result, Some(" and & entities]]>")) - } \ No newline at end of file + } From 45278378dc1a7d16878bb7917ede68e373d10ecc Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 16:30:46 +0900 Subject: [PATCH 123/215] Replace region property --- .../WebIdentityTokenFileCredentialsProvider.scala | 4 ++-- .../internal/WebIdentityCredentialsUtils.scala | 9 +++++---- .../main/scala/ldbc/amazon/client/StsClient.scala | 14 +++++++------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 17525aa96..e9bb01384 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -144,7 +144,7 @@ object WebIdentityTokenFileCredentialsProvider: ): F[WebIdentityTokenFileCredentialsProvider[F]] = for httpClient <- createDefaultHttpClient[F]() - webIdentityUtils = WebIdentityCredentialsUtils.default[F] + webIdentityUtils = WebIdentityCredentialsUtils.default[F](region) yield new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) /** @@ -159,7 +159,7 @@ object WebIdentityTokenFileCredentialsProvider: httpClient: HttpClient[F], region: String = "us-east-1" ): WebIdentityTokenFileCredentialsProvider[F] = - val webIdentityUtils = WebIdentityCredentialsUtils.default[F] + val webIdentityUtils = WebIdentityCredentialsUtils.default[F](region) new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index 6513ad49b..ed26ae3d3 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -65,7 +65,7 @@ object WebIdentityCredentialsUtils: webIdentityToken = token, roleSessionName = config.roleSessionName ) - stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest, region, httpClient) + stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest, httpClient) credentials = convertStsResponseToCredentials(stsResponse) yield credentials @@ -142,12 +142,13 @@ object WebIdentityCredentialsUtils: /** * Creates a default implementation of WebIdentityCredentialsUtils. - * + * + * @param region The AWS region for STS endpoint * @tparam F The effect type * @return A WebIdentityCredentialsUtils instance */ - def default[F[_]: Files: UUIDGen: Concurrent]: WebIdentityCredentialsUtils[F] = - val stsClient = StsClient.default[F] + def default[F[_]: Files: UUIDGen: Concurrent](region: String): WebIdentityCredentialsUtils[F] = + val stsClient = StsClient.default[F](region) Impl[F](stsClient) /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 9879d2ce2..f65c0b275 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -31,13 +31,11 @@ trait StsClient[F[_]]: * Performs AssumeRoleWithWebIdentity operation. * * @param request The STS request parameters - * @param region The AWS region for STS endpoint * @param httpClient HTTP client for making requests * @return STS response with temporary credentials */ def assumeRoleWithWebIdentity( request: StsClient.AssumeRoleWithWebIdentityRequest, - region: String, httpClient: HttpClient[F] ): F[StsClient.AssumeRoleWithWebIdentityResponse] @@ -75,11 +73,13 @@ object StsClient: assumedRoleArn: String ) - private case class Impl[F[_]: UUIDGen: Concurrent]() extends StsClient[F]: + private case class Impl[F[_]: UUIDGen: Concurrent]( + region: String, + stsEndpoint: String + ) extends StsClient[F]: def assumeRoleWithWebIdentity( request: AssumeRoleWithWebIdentityRequest, - region: String, httpClient: HttpClient[F] ): F[AssumeRoleWithWebIdentityResponse] = for @@ -90,7 +90,6 @@ object StsClient: duration = request.durationSeconds.getOrElse(3600) // Build STS request - stsEndpoint = s"https://sts.$region.amazonaws.com/" requestBody = buildRequestBody( request.copy( roleSessionName = Some(sessionName), @@ -112,11 +111,12 @@ object StsClient: /** * Creates a default implementation of StsClient. - * + * + * @param region The AWS region for STS endpoint * @tparam F The effect type * @return A StsClient instance */ - def default[F[_]: UUIDGen: Concurrent]: StsClient[F] = Impl[F]() + def default[F[_]: UUIDGen: Concurrent](region: String): StsClient[F] = Impl[F](region, s"https://sts.$region.amazonaws.com/") /** * Builds the STS request body in AWS Query format. From e998ebe88eff5bd0c5dc4181e63385ac83d20c0e Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:11:02 +0900 Subject: [PATCH 124/215] Replace httpClient property --- ...IdentityTokenFileCredentialsProvider.scala | 18 +++++---------- .../WebIdentityCredentialsUtils.scala | 15 +++++-------- .../scala/ldbc/amazon/client/HttpClient.scala | 2 ++ .../ldbc/amazon/client/SimpleHttpClient.scala | 13 +++++++++++ .../scala/ldbc/amazon/client/StsClient.scala | 22 +++++++++---------- 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index e9bb01384..6ff900c5a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -61,8 +61,6 @@ import ldbc.amazon.util.SdkSystemSetting */ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: Concurrent]( webIdentityUtils: WebIdentityCredentialsUtils[F], - httpClient: HttpClient[F], - region: String = "us-east-1" ) extends AwsCredentialsProvider[F]: override def resolveCredentials(): F[AwsCredentials] = @@ -78,7 +76,7 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: ) ) case Some(webIdentityConfig) => - webIdentityUtils.assumeRoleWithWebIdentity(webIdentityConfig, region, httpClient) + webIdentityUtils.assumeRoleWithWebIdentity(webIdentityConfig) } yield credentials @@ -144,8 +142,8 @@ object WebIdentityTokenFileCredentialsProvider: ): F[WebIdentityTokenFileCredentialsProvider[F]] = for httpClient <- createDefaultHttpClient[F]() - webIdentityUtils = WebIdentityCredentialsUtils.default[F](region) - yield new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) + webIdentityUtils = WebIdentityCredentialsUtils.default[F](region, httpClient) + yield new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) /** * Creates a new Web Identity Token File credentials provider with custom HTTP client. @@ -159,24 +157,20 @@ object WebIdentityTokenFileCredentialsProvider: httpClient: HttpClient[F], region: String = "us-east-1" ): WebIdentityTokenFileCredentialsProvider[F] = - val webIdentityUtils = WebIdentityCredentialsUtils.default[F](region) - new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) + val webIdentityUtils = WebIdentityCredentialsUtils.default[F](region, httpClient) + new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) /** * Creates a new Web Identity Token File credentials provider with custom WebIdentityCredentialsUtils. * * @param webIdentityUtils Custom Web Identity credentials utility - * @param httpClient The HTTP client for STS requests - * @param region The AWS region for STS endpoint * @tparam F The effect type * @return A new WebIdentityTokenFileCredentialsProvider instance */ def create[F[_]: Env: SystemProperties: Concurrent]( webIdentityUtils: WebIdentityCredentialsUtils[F], - httpClient: HttpClient[F], - region: String = "us-east-1" ): WebIdentityTokenFileCredentialsProvider[F] = - new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils, httpClient, region) + new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) /** * Creates a default HTTP client for STS operations. diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index ed26ae3d3..bc817417c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -36,14 +36,10 @@ trait WebIdentityCredentialsUtils[F[_]]: * a Web Identity Token (JWT) for temporary AWS credentials. * * @param config The Web Identity Token configuration - * @param region The AWS region for STS endpoint (default: us-east-1) - * @param httpClient HTTP client for making STS requests * @return AWS credentials with session token */ def assumeRoleWithWebIdentity( config: WebIdentityTokenCredentialProperties, - region: String, - httpClient: HttpClient[F] ): F[AwsCredentials] object WebIdentityCredentialsUtils: @@ -52,10 +48,8 @@ object WebIdentityCredentialsUtils: stsClient: StsClient[F] ) extends WebIdentityCredentialsUtils[F]: - def assumeRoleWithWebIdentity( + override def assumeRoleWithWebIdentity( config: WebIdentityTokenCredentialProperties, - region: String, - httpClient: HttpClient[F] ): F[AwsCredentials] = for token <- readTokenFromFile(config.webIdentityTokenFile) @@ -65,7 +59,7 @@ object WebIdentityCredentialsUtils: webIdentityToken = token, roleSessionName = config.roleSessionName ) - stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest, httpClient) + stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest) credentials = convertStsResponseToCredentials(stsResponse) yield credentials @@ -144,11 +138,12 @@ object WebIdentityCredentialsUtils: * Creates a default implementation of WebIdentityCredentialsUtils. * * @param region The AWS region for STS endpoint + * @param httpClient HTTP client for making requests * @tparam F The effect type * @return A WebIdentityCredentialsUtils instance */ - def default[F[_]: Files: UUIDGen: Concurrent](region: String): WebIdentityCredentialsUtils[F] = - val stsClient = StsClient.default[F](region) + def default[F[_]: Files: UUIDGen: Concurrent](region: String, httpClient: HttpClient[F]): WebIdentityCredentialsUtils[F] = + val stsClient = StsClient.build[F](s"https://sts.$region.amazonaws.com/", httpClient) Impl[F](stsClient) /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala index 3a6140dc7..9141d9d72 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala @@ -13,3 +13,5 @@ trait HttpClient[F[_]]: def get(uri: URI, headers: Map[String, String]): F[HttpResponse] def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] + + def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index a48e1dfdf..3b0dd5071 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -110,6 +110,19 @@ class SimpleHttpClient[F[_]: Network: Async]( response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) yield response + override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + yield response + private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = for h <- ev.fromOption(Host.fromString(host), new SdkClientException("Invalid host")) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index f65c0b275..3101a3a26 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -31,12 +31,10 @@ trait StsClient[F[_]]: * Performs AssumeRoleWithWebIdentity operation. * * @param request The STS request parameters - * @param httpClient HTTP client for making requests * @return STS response with temporary credentials */ def assumeRoleWithWebIdentity( - request: StsClient.AssumeRoleWithWebIdentityRequest, - httpClient: HttpClient[F] + request: StsClient.AssumeRoleWithWebIdentityRequest ): F[StsClient.AssumeRoleWithWebIdentityResponse] object StsClient: @@ -74,13 +72,12 @@ object StsClient: ) private case class Impl[F[_]: UUIDGen: Concurrent]( - region: String, - stsEndpoint: String + stsEndpoint: String, + httpClient: HttpClient[F] ) extends StsClient[F]: - def assumeRoleWithWebIdentity( + override def assumeRoleWithWebIdentity( request: AssumeRoleWithWebIdentityRequest, - httpClient: HttpClient[F] ): F[AssumeRoleWithWebIdentityResponse] = for timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) @@ -99,24 +96,25 @@ object StsClient: // Make HTTP request headers = Map( - "Content-Type" -> "application/x-amz-json-1.0", + "Content-Type" -> "application/x-www-form-urlencoded", "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRoleWithWebIdentity", "X-Amz-Date" -> timestamp ) - response <- httpClient.get(URI.create(stsEndpoint), headers) + response <- httpClient.post(URI.create(stsEndpoint), headers, requestBody) _ <- validateHttpResponse(response) stsResponse <- parseAssumeRoleResponse(response.body) yield stsResponse /** - * Creates a default implementation of StsClient. + * Creates implementation of StsClient. * - * @param region The AWS region for STS endpoint + * @param endpoint STS Endpoint + * @param httpClient HTTP client for making requests * @tparam F The effect type * @return A StsClient instance */ - def default[F[_]: UUIDGen: Concurrent](region: String): StsClient[F] = Impl[F](region, s"https://sts.$region.amazonaws.com/") + def build[F[_]: UUIDGen: Concurrent](endpoint: String, httpClient: HttpClient[F]): StsClient[F] = Impl[F](endpoint, httpClient) /** * Builds the STS request body in AWS Query format. From bdd161049076734c274910777819e3e394190fc8 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:11:29 +0900 Subject: [PATCH 125/215] Action sbt scalafmtAll --- .../WebIdentityTokenFileCredentialsProvider.scala | 4 ++-- .../internal/WebIdentityCredentialsUtils.scala | 9 ++++++--- .../scala/ldbc/amazon/client/SimpleHttpClient.scala | 2 +- .../main/scala/ldbc/amazon/client/StsClient.scala | 13 +++++++------ 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 6ff900c5a..60cb89161 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -60,7 +60,7 @@ import ldbc.amazon.util.SdkSystemSetting * @tparam F The effect type */ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: Concurrent]( - webIdentityUtils: WebIdentityCredentialsUtils[F], + webIdentityUtils: WebIdentityCredentialsUtils[F] ) extends AwsCredentialsProvider[F]: override def resolveCredentials(): F[AwsCredentials] = @@ -168,7 +168,7 @@ object WebIdentityTokenFileCredentialsProvider: * @return A new WebIdentityTokenFileCredentialsProvider instance */ def create[F[_]: Env: SystemProperties: Concurrent]( - webIdentityUtils: WebIdentityCredentialsUtils[F], + webIdentityUtils: WebIdentityCredentialsUtils[F] ): WebIdentityTokenFileCredentialsProvider[F] = new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index bc817417c..15c54a37e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -39,7 +39,7 @@ trait WebIdentityCredentialsUtils[F[_]]: * @return AWS credentials with session token */ def assumeRoleWithWebIdentity( - config: WebIdentityTokenCredentialProperties, + config: WebIdentityTokenCredentialProperties ): F[AwsCredentials] object WebIdentityCredentialsUtils: @@ -49,7 +49,7 @@ object WebIdentityCredentialsUtils: ) extends WebIdentityCredentialsUtils[F]: override def assumeRoleWithWebIdentity( - config: WebIdentityTokenCredentialProperties, + config: WebIdentityTokenCredentialProperties ): F[AwsCredentials] = for token <- readTokenFromFile(config.webIdentityTokenFile) @@ -142,7 +142,10 @@ object WebIdentityCredentialsUtils: * @tparam F The effect type * @return A WebIdentityCredentialsUtils instance */ - def default[F[_]: Files: UUIDGen: Concurrent](region: String, httpClient: HttpClient[F]): WebIdentityCredentialsUtils[F] = + def default[F[_]: Files: UUIDGen: Concurrent]( + region: String, + httpClient: HttpClient[F] + ): WebIdentityCredentialsUtils[F] = val stsClient = StsClient.build[F](s"https://sts.$region.amazonaws.com/", httpClient) Impl[F](stsClient) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 3b0dd5071..78a3d4cf7 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -118,7 +118,7 @@ class SimpleHttpClient[F[_]: Network: Async]( port = getDefaultPort(uri) isSecure = isHttps(uri) path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) yield response diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 3101a3a26..9273b8a1a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -34,7 +34,7 @@ trait StsClient[F[_]]: * @return STS response with temporary credentials */ def assumeRoleWithWebIdentity( - request: StsClient.AssumeRoleWithWebIdentityRequest + request: StsClient.AssumeRoleWithWebIdentityRequest ): F[StsClient.AssumeRoleWithWebIdentityResponse] object StsClient: @@ -72,12 +72,12 @@ object StsClient: ) private case class Impl[F[_]: UUIDGen: Concurrent]( - stsEndpoint: String, - httpClient: HttpClient[F] - ) extends StsClient[F]: + stsEndpoint: String, + httpClient: HttpClient[F] + ) extends StsClient[F]: override def assumeRoleWithWebIdentity( - request: AssumeRoleWithWebIdentityRequest, + request: AssumeRoleWithWebIdentityRequest ): F[AssumeRoleWithWebIdentityResponse] = for timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) @@ -114,7 +114,8 @@ object StsClient: * @tparam F The effect type * @return A StsClient instance */ - def build[F[_]: UUIDGen: Concurrent](endpoint: String, httpClient: HttpClient[F]): StsClient[F] = Impl[F](endpoint, httpClient) + def build[F[_]: UUIDGen: Concurrent](endpoint: String, httpClient: HttpClient[F]): StsClient[F] = + Impl[F](endpoint, httpClient) /** * Builds the STS request body in AWS Query format. From 03dfcd8c038c3212f117bc20b80b4bb26138c303 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:32:47 +0900 Subject: [PATCH 126/215] Delete unused --- ...IdentityTokenFileCredentialsProvider.scala | 33 +------------------ .../WebIdentityCredentialsUtils.scala | 1 - 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 60cb89161..3f4041a74 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -6,7 +6,6 @@ package ldbc.amazon.auth.credentials -import scala.concurrent.duration.* import cats.syntax.all.* @@ -14,10 +13,9 @@ import cats.effect.* import cats.effect.std.{ Env, SystemProperties, UUIDGen } import fs2.io.file.{ Files, Path } -import fs2.io.net.* import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils -import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } +import ldbc.amazon.client.HttpClient import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.util.SdkSystemSetting @@ -130,21 +128,6 @@ case class WebIdentityTokenCredentialProperties( object WebIdentityTokenFileCredentialsProvider: - /** - * Creates a new Web Identity Token File credentials provider with default settings. - * - * @param region The AWS region for STS endpoint (default: us-east-1) - * @tparam F The effect type - * @return A new WebIdentityTokenFileCredentialsProvider instance - */ - def apply[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( - region: String = "us-east-1" - ): F[WebIdentityTokenFileCredentialsProvider[F]] = - for - httpClient <- createDefaultHttpClient[F]() - webIdentityUtils = WebIdentityCredentialsUtils.default[F](region, httpClient) - yield new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) - /** * Creates a new Web Identity Token File credentials provider with custom HTTP client. * @@ -172,20 +155,6 @@ object WebIdentityTokenFileCredentialsProvider: ): WebIdentityTokenFileCredentialsProvider[F] = new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) - /** - * Creates a default HTTP client for STS operations. - * - * @tparam F The effect type - * @return A configured HTTP client - */ - private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = - Async[F].pure( - new SimpleHttpClient[F]( - connectTimeout = 30.seconds, - readTimeout = 30.seconds - ) - ) - /** * Checks if Web Identity Token authentication is available by verifying * the presence of required environment variables or system properties. diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index 15c54a37e..c7025722d 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -107,7 +107,6 @@ object WebIdentityCredentialsUtils: * Converts STS response to AWS credentials. * * @param stsResponse The STS AssumeRoleWithWebIdentity response - * @param config The Web Identity Token configuration * @return AWS session credentials */ private def convertStsResponseToCredentials( From 9fa0e6a52ff57051ce4ea43fa44048460c6d99fa Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:37:25 +0900 Subject: [PATCH 127/215] Delete unused --- .../credentials/WebIdentityTokenFileCredentialsProvider.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 3f4041a74..36ce08951 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -53,8 +53,6 @@ import ldbc.amazon.util.SdkSystemSetting * - AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token * * @param webIdentityUtils Web Identity credentials utility for STS operations - * @param httpClient HTTP client for STS requests - * @param region AWS region for STS endpoint * @tparam F The effect type */ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: Concurrent]( From 4db1f2178890e083b6315f0b21c6dd27d1c21dba Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:38:08 +0900 Subject: [PATCH 128/215] Create /StsClientTest --- .../ldbc/amazon/client/StsClientTest.scala | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala new file mode 100644 index 000000000..f8d260cf7 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala @@ -0,0 +1,135 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.time.Instant + +import scala.concurrent.duration.* + +import munit.CatsEffectSuite + +import cats.effect.IO +import cats.effect.std.UUIDGen + +import fs2.io.net.Network + +import ldbc.amazon.exception.{ SdkClientException, StsException } + +class StsClientTest extends CatsEffectSuite: + + // LocalStack STS endpoint (from docker-compose.yml) + private val localStackEndpoint = "http://localhost:4566" + + private val testRoleArn = "arn:aws:iam::000000000000:role/localstack-role" + private val roleSessionName = "test-session" + private val testAssumedRoleArn = s"arn:aws:sts::000000000000:assumed-role/localstack-role/$roleSessionName" + private val testWebIdentityToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + + // HTTP client for testing + private def httpClient: SimpleHttpClient[IO] = + new SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) + + // STS client configured for LocalStack + private def localStackStsClient: StsClient[IO] = StsClient.build(localStackEndpoint, httpClient) + + test("assumeRoleWithWebIdentity with LocalStack".flaky) { + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = testRoleArn, + webIdentityToken = testWebIdentityToken, + roleSessionName = Some(roleSessionName), + durationSeconds = Some(1800) + ) + + localStackStsClient.assumeRoleWithWebIdentity(request).map { response => + assert(response.accessKeyId.nonEmpty) + assert(response.secretAccessKey.nonEmpty) + assert(response.sessionToken.nonEmpty) + assert(response.expiration.isAfter(Instant.now())) + assertEquals(response.assumedRoleArn, testAssumedRoleArn) + } + } + + test("handle invalid role ARN") { + val client = localStackStsClient + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = "invalid-role-arn", + webIdentityToken = testWebIdentityToken + ) + + assertIOBoolean(client.assumeRoleWithWebIdentity(request).attempt.map(_.isLeft)) + } + + test("assumeRoleWithWebIdentity with auto-generated session name") { + val client = localStackStsClient + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = testRoleArn, + webIdentityToken = testWebIdentityToken + // No roleSessionName - should be auto-generated + ) + + client.assumeRoleWithWebIdentity(request).attempt.map { result => + // Test should handle auto-generation of session name + result.fold( + error => { + // Expected for LocalStack without proper STS setup + assert(error.isInstanceOf[StsException] || error.isInstanceOf[SdkClientException]) + }, + response => { + assert(response.accessKeyId.nonEmpty) + assert(response.secretAccessKey.nonEmpty) + assert(response.sessionToken.nonEmpty) + } + ) + } + } + + test("AssumeRoleWithWebIdentityResponse validation") { + val now = Instant.now() + val response = StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = "ASIAIOSFODNN7EXAMPLE", + secretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + sessionToken = "session-token", + expiration = now.plusSeconds(3600), + assumedRoleArn = testRoleArn + ) + + assertEquals(response.accessKeyId, "ASIAIOSFODNN7EXAMPLE") + assertEquals(response.secretAccessKey, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + assertEquals(response.sessionToken, "session-token") + assertEquals(response.expiration, now.plusSeconds(3600)) + assertEquals(response.assumedRoleArn, testRoleArn) + } + + test("buildRequestBody format") { + // Test the query parameter formatting + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = "arn:aws:iam::123456789012:role/TestRole", + webIdentityToken = "test-token", + roleSessionName = Some("test-session"), + durationSeconds = Some(1800) + ) + + // We can't access the private method directly, but we can test the behavior + // by ensuring our LocalStack client formats parameters correctly + val queryParams = Map( + "Action" -> "AssumeRoleWithWebIdentity", + "Version" -> "2011-06-15", + "RoleArn" -> request.roleArn, + "WebIdentityToken" -> request.webIdentityToken, + "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), + "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString + ) + + queryParams.foreach { case (key, value) => + assert(key.nonEmpty) + assert(value.nonEmpty) + } + + assert(queryParams("Action") == "AssumeRoleWithWebIdentity") + assert(queryParams("Version") == "2011-06-15") + assert(queryParams("RoleArn") == request.roleArn) + } From bf8aa161ea5290f0b42cf7b2c4584ffdc57c02ce Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:41:56 +0900 Subject: [PATCH 129/215] Added localstack container settings --- docker-compose.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index b94143052..421d2cbff 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,6 +18,22 @@ services: timeout: 20s retries: 10 + localstack: + image: localstack/localstack:latest + platform: linux/amd64 + container_name: ldbc-localstack + ports: + - "4566:4566" + - "4571:4571" + environment: + - SERVICES=iam,sts + - DEBUG=1 + - DOCKER_HOST=unix:///var/run/docker.sock + - AWS_DEFAULT_REGION=ap-northeast-1 + volumes: + - ./localstack:/etc/localstack/init/ready.d + - /var/run/docker.sock:/var/run/docker.sock + verdaccio: image: verdaccio/verdaccio:nightly-master container_name: verdaccio From 9fe12ad86ec1ebabcb020e4c8433812ec23c06db Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:42:09 +0900 Subject: [PATCH 130/215] Create localstack init.sh --- localstack/init.sh | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 localstack/init.sh diff --git a/localstack/init.sh b/localstack/init.sh new file mode 100644 index 000000000..7336f7e73 --- /dev/null +++ b/localstack/init.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Wait for LocalStack to be ready +echo "Waiting for LocalStack to be ready..." +sleep 5 + +# Set AWS credentials for LocalStack +export AWS_ACCESS_KEY_ID=test +export AWS_SECRET_ACCESS_KEY=test + +aws --endpoint-url=http://localhost:4566 iam create-role \ + --role-name localstack-role \ + --assume-role-policy-document '{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"arn:aws:iam::000000000000:root"},"Action":"sts:AssumeRole"}]}' From 120d1e4fd9f3174320695b7a8579775d7b24416f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:42:24 +0900 Subject: [PATCH 131/215] Added validateRoleArn --- .../main/scala/ldbc/amazon/client/StsClient.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 9273b8a1a..a743991fe 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -80,6 +80,7 @@ object StsClient: request: AssumeRoleWithWebIdentityRequest ): F[AssumeRoleWithWebIdentityResponse] = for + _ <- validateRoleArn(request.roleArn) timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) sessionName <- request.roleSessionName.fold( UUIDGen[F].randomUUID.map(uuid => s"ldbc-session-$uuid") @@ -106,6 +107,17 @@ object StsClient: stsResponse <- parseAssumeRoleResponse(response.body) yield stsResponse + private def validateRoleArn(roleArn: String): F[Unit] = + val roleArnPattern = """^arn:aws:iam::\d{12}:role/[\w+=,.@-]+$""".r + roleArnPattern.findFirstIn(roleArn) match { + case Some(_) => Concurrent[F].unit + case None => Concurrent[F].raiseError( + new IllegalArgumentException( + s"An error occurred (ValidationError) when calling the AssumeRole operation: $roleArn is invalid" + ) + ) + } + /** * Creates implementation of StsClient. * From 6b1d7d3ddba4bcda21c88119f90a57742d8328bd Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:42:43 +0900 Subject: [PATCH 132/215] Action sbt scalafmtAll --- ...IdentityTokenFileCredentialsProvider.scala | 1 - .../scala/ldbc/amazon/client/StsClient.scala | 11 ++-- .../ldbc/amazon/client/StsClientTest.scala | 64 ++++++++++--------- 3 files changed, 39 insertions(+), 37 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 36ce08951..4f1f22788 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -6,7 +6,6 @@ package ldbc.amazon.auth.credentials - import cats.syntax.all.* import cats.effect.* diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index a743991fe..0ae3e890d 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -80,7 +80,7 @@ object StsClient: request: AssumeRoleWithWebIdentityRequest ): F[AssumeRoleWithWebIdentityResponse] = for - _ <- validateRoleArn(request.roleArn) + _ <- validateRoleArn(request.roleArn) timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) sessionName <- request.roleSessionName.fold( UUIDGen[F].randomUUID.map(uuid => s"ldbc-session-$uuid") @@ -111,11 +111,12 @@ object StsClient: val roleArnPattern = """^arn:aws:iam::\d{12}:role/[\w+=,.@-]+$""".r roleArnPattern.findFirstIn(roleArn) match { case Some(_) => Concurrent[F].unit - case None => Concurrent[F].raiseError( - new IllegalArgumentException( - s"An error occurred (ValidationError) when calling the AssumeRole operation: $roleArn is invalid" + case None => + Concurrent[F].raiseError( + new IllegalArgumentException( + s"An error occurred (ValidationError) when calling the AssumeRole operation: $roleArn is invalid" + ) ) - ) } /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala index f8d260cf7..4490aeb52 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala @@ -10,13 +10,13 @@ import java.time.Instant import scala.concurrent.duration.* -import munit.CatsEffectSuite - -import cats.effect.IO import cats.effect.std.UUIDGen +import cats.effect.IO import fs2.io.net.Network +import munit.CatsEffectSuite + import ldbc.amazon.exception.{ SdkClientException, StsException } class StsClientTest extends CatsEffectSuite: @@ -24,13 +24,14 @@ class StsClientTest extends CatsEffectSuite: // LocalStack STS endpoint (from docker-compose.yml) private val localStackEndpoint = "http://localhost:4566" - private val testRoleArn = "arn:aws:iam::000000000000:role/localstack-role" - private val roleSessionName = "test-session" - private val testAssumedRoleArn = s"arn:aws:sts::000000000000:assumed-role/localstack-role/$roleSessionName" - private val testWebIdentityToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + private val testRoleArn = "arn:aws:iam::000000000000:role/localstack-role" + private val roleSessionName = "test-session" + private val testAssumedRoleArn = s"arn:aws:sts::000000000000:assumed-role/localstack-role/$roleSessionName" + private val testWebIdentityToken = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" // HTTP client for testing - private def httpClient: SimpleHttpClient[IO] = + private def httpClient: SimpleHttpClient[IO] = new SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) // STS client configured for LocalStack @@ -38,10 +39,10 @@ class StsClientTest extends CatsEffectSuite: test("assumeRoleWithWebIdentity with LocalStack".flaky) { val request = StsClient.AssumeRoleWithWebIdentityRequest( - roleArn = testRoleArn, + roleArn = testRoleArn, webIdentityToken = testWebIdentityToken, - roleSessionName = Some(roleSessionName), - durationSeconds = Some(1800) + roleSessionName = Some(roleSessionName), + durationSeconds = Some(1800) ) localStackStsClient.assumeRoleWithWebIdentity(request).map { response => @@ -54,9 +55,9 @@ class StsClientTest extends CatsEffectSuite: } test("handle invalid role ARN") { - val client = localStackStsClient + val client = localStackStsClient val request = StsClient.AssumeRoleWithWebIdentityRequest( - roleArn = "invalid-role-arn", + roleArn = "invalid-role-arn", webIdentityToken = testWebIdentityToken ) @@ -64,9 +65,9 @@ class StsClientTest extends CatsEffectSuite: } test("assumeRoleWithWebIdentity with auto-generated session name") { - val client = localStackStsClient + val client = localStackStsClient val request = StsClient.AssumeRoleWithWebIdentityRequest( - roleArn = testRoleArn, + roleArn = testRoleArn, webIdentityToken = testWebIdentityToken // No roleSessionName - should be auto-generated ) @@ -88,13 +89,13 @@ class StsClientTest extends CatsEffectSuite: } test("AssumeRoleWithWebIdentityResponse validation") { - val now = Instant.now() + val now = Instant.now() val response = StsClient.AssumeRoleWithWebIdentityResponse( - accessKeyId = "ASIAIOSFODNN7EXAMPLE", + accessKeyId = "ASIAIOSFODNN7EXAMPLE", secretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", - sessionToken = "session-token", - expiration = now.plusSeconds(3600), - assumedRoleArn = testRoleArn + sessionToken = "session-token", + expiration = now.plusSeconds(3600), + assumedRoleArn = testRoleArn ) assertEquals(response.accessKeyId, "ASIAIOSFODNN7EXAMPLE") @@ -107,26 +108,27 @@ class StsClientTest extends CatsEffectSuite: test("buildRequestBody format") { // Test the query parameter formatting val request = StsClient.AssumeRoleWithWebIdentityRequest( - roleArn = "arn:aws:iam::123456789012:role/TestRole", + roleArn = "arn:aws:iam::123456789012:role/TestRole", webIdentityToken = "test-token", - roleSessionName = Some("test-session"), - durationSeconds = Some(1800) + roleSessionName = Some("test-session"), + durationSeconds = Some(1800) ) // We can't access the private method directly, but we can test the behavior // by ensuring our LocalStack client formats parameters correctly val queryParams = Map( - "Action" -> "AssumeRoleWithWebIdentity", - "Version" -> "2011-06-15", - "RoleArn" -> request.roleArn, + "Action" -> "AssumeRoleWithWebIdentity", + "Version" -> "2011-06-15", + "RoleArn" -> request.roleArn, "WebIdentityToken" -> request.webIdentityToken, - "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), - "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString + "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), + "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString ) - queryParams.foreach { case (key, value) => - assert(key.nonEmpty) - assert(value.nonEmpty) + queryParams.foreach { + case (key, value) => + assert(key.nonEmpty) + assert(value.nonEmpty) } assert(queryParams("Action") == "AssumeRoleWithWebIdentity") From 5f472569a7e781c73272f0d3c4eeb81595fc5fec Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 21:48:10 +0900 Subject: [PATCH 133/215] Split SimpleHttpClient to multi platform --- .../ldbc/amazon/client/SimpleHttpClient.scala | 207 +++++++++++++++++ .../ldbc/amazon/client/SimpleHttpClient.scala | 18 +- .../ldbc/amazon/client/SimpleHttpClient.scala | 212 ++++++++++++++++++ 3 files changed, 425 insertions(+), 12 deletions(-) create mode 100644 module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala rename module/ldbc-aws-authentication-plugin/{shared => jvm}/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala (99%) create mode 100644 module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala new file mode 100644 index 000000000..039ea3394 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -0,0 +1,207 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import cats.MonadThrow +import cats.effect.* +import cats.effect.syntax.all.* +import cats.syntax.all.* +import com.comcast.ip4s.* +import fs2.* +import fs2.io.net.* +import fs2.io.net.tls.* +import ldbc.amazon.exception.* + +import java.net.URI +import scala.concurrent.duration.* + +/** + * Secure HTTP client that supports both HTTP and HTTPS protocols. + * + * Security Features: + * - Validates URI schemes and rejects unsupported protocols + * - Uses TLS for HTTPS connections with proper certificate validation + * - Defaults to secure ports (443 for HTTPS, 80 for HTTP) + * - Prevents credentials from being sent over cleartext connections + * + * This addresses the security vulnerability where AWS credentials + * could be sent over unencrypted HTTP connections. + */ +class SimpleHttpClient[F[_]: Network: Async]( + connectTimeout: Duration, + readTimeout: Duration +)(using ev: MonadThrow[F]) + extends HttpClient[F]: + + private def isHttps(uri: URI): Boolean = + uri.getScheme != null && uri.getScheme.toLowerCase == "https" + + private def getDefaultPort(uri: URI): Int = + if uri.getPort > 0 then uri.getPort + else if isHttps(uri) then 443 + else 80 + + private def validateScheme(uri: URI): F[Unit] = + uri.getScheme match + case null => ev.raiseError(new SdkClientException("URI scheme is required")) + case scheme if scheme.toLowerCase == "http" => + // Log warning for HTTP usage, but allow it for non-sensitive endpoints + ev.unit + case scheme if scheme.toLowerCase == "https" => ev.unit + case unsupported => + ev.raiseError( + new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") + ) + + private def validateSecurityRequirements(uri: URI): F[Unit] = + // AWS endpoints should always use HTTPS + if uri.getHost != null && uri.getHost.contains(".amazonaws.com") && !isHttps(uri) then + ev.raiseError( + new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") + ) + else ev.unit + + private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + if isSecure then + for + socket <- Network[F].client(address) + tlsContext <- Network[F].tlsContext.systemResource + tlsSocket <- tlsContext + .clientBuilder(socket) + .withParameters(TLSParameters(servername = Some(host))) + .build + yield tlsSocket + else Network[F].client(address) + + override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) + yield response + + override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + yield response + + override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + yield response + + private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = + for + h <- ev.fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- ev.fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + yield SocketAddress(h, p) + + private def sendRequest( + socket: Socket[F], + method: String, + host: String, + port: Int, + isSecure: Boolean, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[Unit] = + val defaultPort = if isSecure then 443 else 80 + val hostHeader = if port == defaultPort then host else s"$host:$port" + val contentHeaders = body match { + case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) + case None => Map.empty + } + val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") + + val requestLine = s"$method $path HTTP/1.1\r\n" + val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString + val requestWithHeaders = requestLine + headerLines + "\r\n" + val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) + + Stream + .emit(fullRequest) + .through(text.utf8.encode) + .through(socket.writes) + .compile + .drain + + private def parseStatusLine(line: String): F[Int] = + // "HTTP/1.1 200 OK" -> 200 + line.split(" ").toList match + case _ :: code :: _ => + code.toIntOption match + case Some(c) => ev.pure(c) + case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => ev.raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + + private def parseHeaderLine(line: String): Option[(String, String)] = + line.split(": ", 2).toList match + case key :: value :: Nil => Some(key -> value) + case _ => None + + private def parseHttpResponse(raw: String): F[HttpResponse] = + val lines = raw.split("\r\n").toList + + lines match + case statusLine :: rest => + parseStatusLine(statusLine).flatMap: statusCode => + val (headerLines, bodyLines) = rest.span(_.nonEmpty) + val headers = headerLines.flatMap(parseHeaderLine).toMap + val body = bodyLines.drop(1).mkString("\r\n") // drop empty line + + ev.pure(HttpResponse(statusCode, headers, body)) + case _ => + ev.raiseError(CredentialsFetchError("Empty response")) + + private def receiveResponse(socket: Socket[F]): F[HttpResponse] = + socket.reads + .through(text.utf8.decode) + .compile + .string + .flatMap(parseHttpResponse) + + private def makeRequest( + address: SocketAddress[Host], + host: String, + port: Int, + isSecure: Boolean, + method: String, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[HttpResponse] = + createSocket(address, isSecure, host) + .use { socket => + for + _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) + response <- receiveResponse(socket) + yield response + } + .timeout(connectTimeout + readTimeout) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala similarity index 99% rename from module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala rename to module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 78a3d4cf7..4cdf35c13 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -6,26 +6,20 @@ package ldbc.amazon.client -import java.net.URI - -import javax.net.ssl.SNIHostName - -import scala.concurrent.duration.* - -import com.comcast.ip4s.* - -import cats.syntax.all.* import cats.MonadThrow - import cats.effect.* import cats.effect.syntax.all.* - +import cats.syntax.all.* +import com.comcast.ip4s.* import fs2.* import fs2.io.net.* import fs2.io.net.tls.* - import ldbc.amazon.exception.* +import java.net.URI +import javax.net.ssl.SNIHostName +import scala.concurrent.duration.* + /** * Secure HTTP client that supports both HTTP and HTTPS protocols. * diff --git a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala new file mode 100644 index 000000000..e3bdc99d6 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -0,0 +1,212 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.syntax.all.* +import cats.MonadThrow + +import cats.effect.* +import cats.effect.syntax.all.* + +import fs2.* +import fs2.io.net.* +import fs2.io.net.tls.* + +import ldbc.amazon.exception.* + +/** + * Secure HTTP client that supports both HTTP and HTTPS protocols. + * + * Security Features: + * - Validates URI schemes and rejects unsupported protocols + * - Uses TLS for HTTPS connections with proper certificate validation + * - Defaults to secure ports (443 for HTTPS, 80 for HTTP) + * - Prevents credentials from being sent over cleartext connections + * + * This addresses the security vulnerability where AWS credentials + * could be sent over unencrypted HTTP connections. + */ +class SimpleHttpClient[F[_]: Network: Async]( + connectTimeout: Duration, + readTimeout: Duration +)(using ev: MonadThrow[F]) + extends HttpClient[F]: + + private def isHttps(uri: URI): Boolean = + uri.getScheme != null && uri.getScheme.toLowerCase == "https" + + private def getDefaultPort(uri: URI): Int = + if uri.getPort > 0 then uri.getPort + else if isHttps(uri) then 443 + else 80 + + private def validateScheme(uri: URI): F[Unit] = + uri.getScheme match + case null => ev.raiseError(new SdkClientException("URI scheme is required")) + case scheme if scheme.toLowerCase == "http" => + // Log warning for HTTP usage, but allow it for non-sensitive endpoints + ev.unit + case scheme if scheme.toLowerCase == "https" => ev.unit + case unsupported => + ev.raiseError( + new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") + ) + + private def validateSecurityRequirements(uri: URI): F[Unit] = + // AWS endpoints should always use HTTPS + if uri.getHost != null && uri.getHost.contains(".amazonaws.com") && !isHttps(uri) then + ev.raiseError( + new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") + ) + else ev.unit + + private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + if isSecure then + for + socket <- Network[F].client(address) + tlsContext <- Network[F].tlsContext.systemResource + tlsSocket <- tlsContext + .clientBuilder(socket) + .withParameters(TLSParameters(serverName = Some(host))) + .build + yield tlsSocket + else Network[F].client(address) + + override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) + yield response + + override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + yield response + + override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + yield response + + private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = + for + h <- ev.fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- ev.fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + yield SocketAddress(h, p) + + private def sendRequest( + socket: Socket[F], + method: String, + host: String, + port: Int, + isSecure: Boolean, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[Unit] = + val defaultPort = if isSecure then 443 else 80 + val hostHeader = if port == defaultPort then host else s"$host:$port" + val contentHeaders = body match { + case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) + case None => Map.empty + } + val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") + + val requestLine = s"$method $path HTTP/1.1\r\n" + val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString + val requestWithHeaders = requestLine + headerLines + "\r\n" + val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) + + Stream + .emit(fullRequest) + .through(text.utf8.encode) + .through(socket.writes) + .compile + .drain + + private def parseStatusLine(line: String): F[Int] = + // "HTTP/1.1 200 OK" -> 200 + line.split(" ").toList match + case _ :: code :: _ => + code.toIntOption match + case Some(c) => ev.pure(c) + case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => ev.raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + + private def parseHeaderLine(line: String): Option[(String, String)] = + line.split(": ", 2).toList match + case key :: value :: Nil => Some(key -> value) + case _ => None + + private def parseHttpResponse(raw: String): F[HttpResponse] = + val lines = raw.split("\r\n").toList + + lines match + case statusLine :: rest => + parseStatusLine(statusLine).flatMap: statusCode => + val (headerLines, bodyLines) = rest.span(_.nonEmpty) + val headers = headerLines.flatMap(parseHeaderLine).toMap + val body = bodyLines.drop(1).mkString("\r\n") // drop empty line + + ev.pure(HttpResponse(statusCode, headers, body)) + case _ => + ev.raiseError(CredentialsFetchError("Empty response")) + + private def receiveResponse(socket: Socket[F]): F[HttpResponse] = + socket.reads + .through(text.utf8.decode) + .compile + .string + .flatMap(parseHttpResponse) + + private def makeRequest( + address: SocketAddress[Host], + host: String, + port: Int, + isSecure: Boolean, + method: String, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[HttpResponse] = + createSocket(address, isSecure, host) + .use { socket => + for + _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) + response <- receiveResponse(socket) + yield response + } + .timeout(connectTimeout + readTimeout) From 80571f84b8ed2c62bf822c654db8b4d0d120f0bb Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 22:06:07 +0900 Subject: [PATCH 134/215] Action sbt scalafmtAll --- .../ldbc/amazon/client/SimpleHttpClient.scala | 15 ++++++++++----- .../ldbc/amazon/client/SimpleHttpClient.scala | 18 ++++++++++++------ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 039ea3394..6728fb0a4 100644 --- a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -6,18 +6,23 @@ package ldbc.amazon.client +import java.net.URI + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.syntax.all.* import cats.MonadThrow + import cats.effect.* import cats.effect.syntax.all.* -import cats.syntax.all.* -import com.comcast.ip4s.* + import fs2.* import fs2.io.net.* import fs2.io.net.tls.* -import ldbc.amazon.exception.* -import java.net.URI -import scala.concurrent.duration.* +import ldbc.amazon.exception.* /** * Secure HTTP client that supports both HTTP and HTTPS protocols. diff --git a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 4cdf35c13..78a3d4cf7 100644 --- a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -6,19 +6,25 @@ package ldbc.amazon.client +import java.net.URI + +import javax.net.ssl.SNIHostName + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.syntax.all.* import cats.MonadThrow + import cats.effect.* import cats.effect.syntax.all.* -import cats.syntax.all.* -import com.comcast.ip4s.* + import fs2.* import fs2.io.net.* import fs2.io.net.tls.* -import ldbc.amazon.exception.* -import java.net.URI -import javax.net.ssl.SNIHostName -import scala.concurrent.duration.* +import ldbc.amazon.exception.* /** * Secure HTTP client that supports both HTTP and HTTPS protocols. From 5110f696e346a4dc2b5ae50b253b191e5e932380 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 22:11:24 +0900 Subject: [PATCH 135/215] Create RdsIamAuthTokenGeneratorTest --- .../token/RdsIamAuthTokenGeneratorTest.scala | 495 ++++++++++++++++++ 1 file changed, 495 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala new file mode 100644 index 000000000..29a4c20cf --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala @@ -0,0 +1,495 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.token + +import java.time.Instant +import java.net.URLDecoder + +import munit.CatsEffectSuite + +import cats.effect.IO + +import fs2.hashing.Hashing + +import ldbc.amazon.auth.credentials.{ AwsBasicCredentials, AwsSessionCredentials } + +class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: + + // Test fixtures + private val testHostname = "my-db-instance.ap-northeast-1.rds.amazonaws.com" + private val testPort = 3306 + private val testUsername = "db_user" + private val testRegion = "ap-northeast-1" + + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + + private def createGenerator(): RdsIamAuthTokenGenerator[IO] = + new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = testUsername, + region = testRegion + ) + + test("generateToken with basic credentials") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = Some("test-provider"), + accountId = Some("123456789012"), + expirationTime = Some(Instant.now().plusSeconds(3600)) + ) + + generator.generateToken(credentials).map { token => + // Verify token structure + assert(token.startsWith(s"$testHostname:$testPort/?")) + assert(token.contains("Action=connect")) + assert(token.contains(s"DBUser=$testUsername")) + assert(token.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256")) + assert(token.contains(s"X-Amz-Credential=$testAccessKeyId")) + assert(token.contains("X-Amz-Date=")) + assert(token.contains("X-Amz-Expires=900")) + assert(token.contains("X-Amz-SignedHeaders=host")) + assert(token.contains("X-Amz-Signature=")) + + // Should NOT contain session token for basic credentials + assert(!token.contains("X-Amz-Security-Token")) + } + } + + test("generateToken with session credentials") { + val generator = createGenerator() + val credentials = AwsSessionCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + validateCredentials = true, + providerName = Some("test-provider"), + accountId = Some("123456789012"), + expirationTime = Some(Instant.now().plusSeconds(3600)) + ) + + generator.generateToken(credentials).map { token => + // Verify token structure + assert(token.startsWith(s"$testHostname:$testPort/?")) + assert(token.contains("Action=connect")) + assert(token.contains(s"DBUser=$testUsername")) + assert(token.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256")) + assert(token.contains(s"X-Amz-Credential=$testAccessKeyId")) + assert(token.contains("X-Amz-Date=")) + assert(token.contains("X-Amz-Expires=900")) + assert(token.contains("X-Amz-SignedHeaders=host")) + assert(token.contains("X-Amz-Signature=")) + + // Should contain session token for session credentials + assert(token.contains("X-Amz-Security-Token=")) + assert(token.contains(java.net.URLEncoder.encode(testSessionToken, "UTF-8"))) + } + } + + test("generateToken produces consistent output structure") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator.generateToken(credentials) + token2 <- generator.generateToken(credentials) + } yield { + // Tokens may differ due to timestamp, but structure should be consistent + assert(token1.startsWith(s"$testHostname:$testPort/?")) + assert(token2.startsWith(s"$testHostname:$testPort/?")) + + // Both should contain the same parameter names + val params1 = token1.split("\\?", 2)(1).split("&").map(_.split("=")(0)).toSet + val params2 = token2.split("\\?", 2)(1).split("&").map(_.split("=")(0)).toSet + assertEquals(params1, params2) + } + } + + test("token format validation") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Parse the token URL + val parts = token.split("\\?", 2) + assertEquals(parts.length, 2) + + val hostPart = parts(0) + val queryPart = parts(1) + + assertEquals(hostPart, s"$testHostname:$testPort/") + + // Parse query parameters + val queryParams = queryPart.split("&").map { param => + val kv = param.split("=", 2) + kv(0) -> URLDecoder.decode(kv(1), "UTF-8") + }.toMap + + // Verify required parameters + assertEquals(queryParams("Action"), "connect") + assertEquals(queryParams("DBUser"), testUsername) + assertEquals(queryParams("X-Amz-Algorithm"), "AWS4-HMAC-SHA256") + assertEquals(queryParams("X-Amz-Expires"), "900") + assertEquals(queryParams("X-Amz-SignedHeaders"), "host") + + // Verify credential format + val credential = queryParams("X-Amz-Credential") + assert(credential.startsWith(testAccessKeyId)) + assert(credential.contains(testRegion)) + assert(credential.contains("rds-db")) + assert(credential.contains("aws4_request")) + + // Verify date format (should be ISO 8601 basic format) + val date = queryParams("X-Amz-Date") + assert(date.matches("""\d{8}T\d{6}Z""")) + + // Verify signature is present and non-empty + val signature = queryParams("X-Amz-Signature") + assert(signature.nonEmpty) + assert(signature.matches("[0-9a-f]+")) // Should be lowercase hex + } + } + + test("token parameter ordering") { + val generator = createGenerator() + val credentials = AwsSessionCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Extract query string + val queryString = token.split("\\?", 2)(1) + val paramsBefore = queryString.split("&").filterNot(_.startsWith("X-Amz-Signature=")).toList + val paramsSorted = paramsBefore.sortBy(_.split("=")(0)) + + // Parameters should be in alphabetical order (required for AWS SigV4) + assertEquals(paramsBefore, paramsSorted) + } + } + + test("different credentials produce different signatures") { + val generator = createGenerator() + + val credentials1 = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + val credentials2 = AwsBasicCredentials( + accessKeyId = "AKIADIFFERENTKEY", + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator.generateToken(credentials1) + token2 <- generator.generateToken(credentials2) + } yield { + assert(token1 != token2) + + // Extract signatures + val sig1 = token1.split("X-Amz-Signature=")(1) + val sig2 = token2.split("X-Amz-Signature=")(1) + assert(sig1 != sig2) + } + } + + test("different regions produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = testUsername, + region = "us-east-1" + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = testUsername, + region = "us-west-2" + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + + // Verify region is included in credential scope + assert(token1.contains("us-east-1")) + assert(token2.contains("us-west-2")) + } + } + + test("different hostnames produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = "host1.region.rds.amazonaws.com", + port = testPort, + username = testUsername, + region = testRegion + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = "host2.region.rds.amazonaws.com", + port = testPort, + username = testUsername, + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + assert(token1.startsWith("host1.region.rds.amazonaws.com")) + assert(token2.startsWith("host2.region.rds.amazonaws.com")) + } + } + + test("different ports produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = 3306, + username = testUsername, + region = testRegion + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = 5432, + username = testUsername, + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + assert(token1.contains(":3306/")) + assert(token2.contains(":5432/")) + } + } + + test("different usernames produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = "user1", + region = testRegion + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = "user2", + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + assert(token1.contains("DBUser=user1")) + assert(token2.contains("DBUser=user2")) + } + } + + test("URL encoding is applied correctly") { + val generator = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = "user with spaces@domain.com", + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Username should be URL encoded + assert(token.contains("DBUser=user%20with%20spaces%40domain.com")) + + // Verify AWS-specific encoding rules + assert(!token.contains("+")) // Spaces should be %20, not + + + // Extract all parameters to verify proper encoding + val queryString = token.split("\\?", 2)(1) + val params = queryString.split("&") + + params.foreach { param => + val parts = param.split("=", 2) + if (parts.length == 2) { + val value = parts(1) + + // Verify no invalid characters remain unencoded + assert(!value.contains(" ")) + assert(!value.contains("@")) + } + } + } + } + + test("signature calculation consistency") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + val signature = token.split("X-Amz-Signature=")(1) + + // Signature should be 64 characters (SHA-256 hex) + assertEquals(signature.length, 64) + + // Signature should be lowercase hex + assert(signature.forall(c => c.isDigit || ('a' <= c && c <= 'f'))) + } + } + + test("implements AuthTokenGenerator trait") { + val generator: AuthTokenGenerator[IO] = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + assert(token.nonEmpty) + assert(token.contains(testHostname)) + } + } + + test("token contains correct service and terminator") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Extract credential scope from the credential parameter + val queryString = token.split("\\?", 2)(1) + val credentialParam = queryString.split("&") + .find(_.startsWith("X-Amz-Credential=")) + .getOrElse(fail("X-Amz-Credential parameter not found")) + + val credentialValue = URLDecoder.decode(credentialParam.split("=")(1), "UTF-8") + + // Credential format: accessKeyId/date/region/service/terminator + val parts = credentialValue.split("/") + assertEquals(parts.length, 5) + assertEquals(parts(0), testAccessKeyId) + assertEquals(parts(2), testRegion) + assertEquals(parts(3), "rds-db") // SERVICE constant + assertEquals(parts(4), "aws4_request") // TERMINATOR constant + } + } + + test("token expiration is set to 900 seconds") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + assert(token.contains("X-Amz-Expires=900")) + } + } \ No newline at end of file From 9efb5933b65bdae1b13db100d33f3e7e584d3002 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 22:26:09 +0900 Subject: [PATCH 136/215] Added scala java time dependencies --- build.sbt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/build.sbt b/build.sbt index 0ae94d85b..20f9c0d08 100644 --- a/build.sbt +++ b/build.sbt @@ -148,9 +148,10 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP .module("aws-authentication-plugin", "") .settings( libraryDependencies ++= Seq( - "co.fs2" %%% "fs2-core" % "3.12.2", - "co.fs2" %%% "fs2-io" % "3.12.2", - "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test + "co.fs2" %%% "fs2-core" % "3.12.2", + "co.fs2" %%% "fs2-io" % "3.12.2", + "io.github.cquiroz" %%% "scala-java-time" % "2.5.0", + "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test ) ) .jsSettings( From a8a12dbb5231227c4772b193e63d86ec50d4a3d6 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 22:26:33 +0900 Subject: [PATCH 137/215] Delete unused --- .../ContainerCredentialsProvider.scala | 29 ++----------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala index 161c0a461..583fd7a1e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -9,17 +9,15 @@ package ldbc.amazon.auth.credentials import java.net.URI import java.time.Instant -import scala.concurrent.duration.* import cats.syntax.all.* -import cats.effect.{ Async, Concurrent } +import cats.effect.Concurrent import cats.effect.std.Env import fs2.io.file.{ Files, Path } -import fs2.io.net.* -import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } +import ldbc.amazon.client.HttpClient import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.useragent.BusinessMetricFeatureId @@ -237,15 +235,6 @@ private object ContainerCredentialsResponse: object ContainerCredentialsProvider: - /** - * Creates a new Container credentials provider with default settings. - * - * @tparam F The effect type - * @return A new ContainerCredentialsProvider instance - */ - def apply[F[_]: Files: Env: Network: Async](): F[ContainerCredentialsProvider[F]] = - createDefaultHttpClient[F]().map(httpClient => new ContainerCredentialsProvider[F](httpClient)) - /** * Creates a new Container credentials provider with custom HTTP client. * @@ -258,20 +247,6 @@ object ContainerCredentialsProvider: ): ContainerCredentialsProvider[F] = new ContainerCredentialsProvider[F](httpClient) - /** - * Creates a default HTTP client for container credential operations. - * - * @tparam F The effect type - * @return A configured HTTP client - */ - private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = - Async[F].pure( - new SimpleHttpClient[F]( - connectTimeout = 5.seconds, - readTimeout = 5.seconds - ) - ) - /** * Checks if Container credentials are available by verifying * the presence of required environment variables. From bdfb009f10e759c4a06d266ffb5246e41a408bc7 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sat, 6 Dec 2025 22:27:00 +0900 Subject: [PATCH 138/215] Action sbt scalafmtAll --- .../ContainerCredentialsProvider.scala | 3 +- .../token/RdsIamAuthTokenGeneratorTest.scala | 326 +++++++++--------- 2 files changed, 166 insertions(+), 163 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala index 583fd7a1e..8ba167183 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -9,11 +9,10 @@ package ldbc.amazon.auth.credentials import java.net.URI import java.time.Instant - import cats.syntax.all.* -import cats.effect.Concurrent import cats.effect.std.Env +import cats.effect.Concurrent import fs2.io.file.{ Files, Path } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala index 29a4c20cf..d06129852 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala @@ -6,46 +6,46 @@ package ldbc.amazon.auth.token -import java.time.Instant import java.net.URLDecoder - -import munit.CatsEffectSuite +import java.time.Instant import cats.effect.IO import fs2.hashing.Hashing +import munit.CatsEffectSuite + import ldbc.amazon.auth.credentials.{ AwsBasicCredentials, AwsSessionCredentials } class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: // Test fixtures private val testHostname = "my-db-instance.ap-northeast-1.rds.amazonaws.com" - private val testPort = 3306 + private val testPort = 3306 private val testUsername = "db_user" - private val testRegion = "ap-northeast-1" + private val testRegion = "ap-northeast-1" - private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" - private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" private def createGenerator(): RdsIamAuthTokenGenerator[IO] = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = testPort, + port = testPort, username = testUsername, - region = testRegion + region = testRegion ) test("generateToken with basic credentials") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = Some("test-provider"), - accountId = Some("123456789012"), - expirationTime = Some(Instant.now().plusSeconds(3600)) + providerName = Some("test-provider"), + accountId = Some("123456789012"), + expirationTime = Some(Instant.now().plusSeconds(3600)) ) generator.generateToken(credentials).map { token => @@ -59,22 +59,22 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: assert(token.contains("X-Amz-Expires=900")) assert(token.contains("X-Amz-SignedHeaders=host")) assert(token.contains("X-Amz-Signature=")) - + // Should NOT contain session token for basic credentials assert(!token.contains("X-Amz-Security-Token")) } } test("generateToken with session credentials") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsSessionCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, - sessionToken = testSessionToken, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, validateCredentials = true, - providerName = Some("test-provider"), - accountId = Some("123456789012"), - expirationTime = Some(Instant.now().plusSeconds(3600)) + providerName = Some("test-provider"), + accountId = Some("123456789012"), + expirationTime = Some(Instant.now().plusSeconds(3600)) ) generator.generateToken(credentials).map { token => @@ -88,7 +88,7 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: assert(token.contains("X-Amz-Expires=900")) assert(token.contains("X-Amz-SignedHeaders=host")) assert(token.contains("X-Amz-Signature=")) - + // Should contain session token for session credentials assert(token.contains("X-Amz-Security-Token=")) assert(token.contains(java.net.URLEncoder.encode(testSessionToken, "UTF-8"))) @@ -96,14 +96,14 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: } test("generateToken produces consistent output structure") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) for { @@ -113,7 +113,7 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: // Tokens may differ due to timestamp, but structure should be consistent assert(token1.startsWith(s"$testHostname:$testPort/?")) assert(token2.startsWith(s"$testHostname:$testPort/?")) - + // Both should contain the same parameter names val params1 = token1.split("\\?", 2)(1).split("&").map(_.split("=")(0)).toSet val params2 = token2.split("\\?", 2)(1).split("&").map(_.split("=")(0)).toSet @@ -122,50 +122,53 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: } test("token format validation") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) generator.generateToken(credentials).map { token => // Parse the token URL val parts = token.split("\\?", 2) assertEquals(parts.length, 2) - - val hostPart = parts(0) + + val hostPart = parts(0) val queryPart = parts(1) - + assertEquals(hostPart, s"$testHostname:$testPort/") - + // Parse query parameters - val queryParams = queryPart.split("&").map { param => - val kv = param.split("=", 2) - kv(0) -> URLDecoder.decode(kv(1), "UTF-8") - }.toMap - + val queryParams = queryPart + .split("&") + .map { param => + val kv = param.split("=", 2) + kv(0) -> URLDecoder.decode(kv(1), "UTF-8") + } + .toMap + // Verify required parameters assertEquals(queryParams("Action"), "connect") assertEquals(queryParams("DBUser"), testUsername) assertEquals(queryParams("X-Amz-Algorithm"), "AWS4-HMAC-SHA256") assertEquals(queryParams("X-Amz-Expires"), "900") assertEquals(queryParams("X-Amz-SignedHeaders"), "host") - + // Verify credential format val credential = queryParams("X-Amz-Credential") assert(credential.startsWith(testAccessKeyId)) assert(credential.contains(testRegion)) assert(credential.contains("rds-db")) assert(credential.contains("aws4_request")) - + // Verify date format (should be ISO 8601 basic format) val date = queryParams("X-Amz-Date") assert(date.matches("""\d{8}T\d{6}Z""")) - + // Verify signature is present and non-empty val signature = queryParams("X-Amz-Signature") assert(signature.nonEmpty) @@ -174,23 +177,23 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: } test("token parameter ordering") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsSessionCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, - sessionToken = testSessionToken, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) generator.generateToken(credentials).map { token => // Extract query string - val queryString = token.split("\\?", 2)(1) + val queryString = token.split("\\?", 2)(1) val paramsBefore = queryString.split("&").filterNot(_.startsWith("X-Amz-Signature=")).toList val paramsSorted = paramsBefore.sortBy(_.split("=")(0)) - + // Parameters should be in alphabetical order (required for AWS SigV4) assertEquals(paramsBefore, paramsSorted) } @@ -198,23 +201,23 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: test("different credentials produce different signatures") { val generator = createGenerator() - + val credentials1 = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) - + val credentials2 = AwsBasicCredentials( - accessKeyId = "AKIADIFFERENTKEY", - secretAccessKey = testSecretAccessKey, + accessKeyId = "AKIADIFFERENTKEY", + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) for { @@ -222,7 +225,7 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: token2 <- generator.generateToken(credentials2) } yield { assert(token1 != token2) - + // Extract signatures val sig1 = token1.split("X-Amz-Signature=")(1) val sig2 = token2.split("X-Amz-Signature=")(1) @@ -233,25 +236,25 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: test("different regions produce different tokens") { val generator1 = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = testPort, + port = testPort, username = testUsername, - region = "us-east-1" + region = "us-east-1" ) - + val generator2 = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = testPort, + port = testPort, username = testUsername, - region = "us-west-2" + region = "us-west-2" ) - + val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) for { @@ -259,7 +262,7 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: token2 <- generator2.generateToken(credentials) } yield { assert(token1 != token2) - + // Verify region is included in credential scope assert(token1.contains("us-east-1")) assert(token2.contains("us-west-2")) @@ -269,25 +272,25 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: test("different hostnames produce different tokens") { val generator1 = new RdsIamAuthTokenGenerator[IO]( hostname = "host1.region.rds.amazonaws.com", - port = testPort, + port = testPort, username = testUsername, - region = testRegion + region = testRegion ) - + val generator2 = new RdsIamAuthTokenGenerator[IO]( hostname = "host2.region.rds.amazonaws.com", - port = testPort, + port = testPort, username = testUsername, - region = testRegion + region = testRegion ) - + val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) for { @@ -303,25 +306,25 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: test("different ports produce different tokens") { val generator1 = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = 3306, + port = 3306, username = testUsername, - region = testRegion + region = testRegion ) - + val generator2 = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = 5432, + port = 5432, username = testUsername, - region = testRegion + region = testRegion ) - + val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) for { @@ -337,25 +340,25 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: test("different usernames produce different tokens") { val generator1 = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = testPort, + port = testPort, username = "user1", - region = testRegion + region = testRegion ) - + val generator2 = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = testPort, + port = testPort, username = "user2", - region = testRegion + region = testRegion ) - + val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) for { @@ -371,36 +374,36 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: test("URL encoding is applied correctly") { val generator = new RdsIamAuthTokenGenerator[IO]( hostname = testHostname, - port = testPort, + port = testPort, username = "user with spaces@domain.com", - region = testRegion + region = testRegion ) - + val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) generator.generateToken(credentials).map { token => // Username should be URL encoded assert(token.contains("DBUser=user%20with%20spaces%40domain.com")) - + // Verify AWS-specific encoding rules assert(!token.contains("+")) // Spaces should be %20, not + - + // Extract all parameters to verify proper encoding val queryString = token.split("\\?", 2)(1) - val params = queryString.split("&") - + val params = queryString.split("&") + params.foreach { param => val parts = param.split("=", 2) - if (parts.length == 2) { + if parts.length == 2 then { val value = parts(1) - + // Verify no invalid characters remain unencoded assert(!value.contains(" ")) assert(!value.contains("@")) @@ -410,22 +413,22 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: } test("signature calculation consistency") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) generator.generateToken(credentials).map { token => val signature = token.split("X-Amz-Signature=")(1) - + // Signature should be 64 characters (SHA-256 hex) assertEquals(signature.length, 64) - + // Signature should be lowercase hex assert(signature.forall(c => c.isDigit || ('a' <= c && c <= 'f'))) } @@ -434,12 +437,12 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: test("implements AuthTokenGenerator trait") { val generator: AuthTokenGenerator[IO] = createGenerator() val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) generator.generateToken(credentials).map { token => @@ -449,47 +452,48 @@ class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: } test("token contains correct service and terminator") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) generator.generateToken(credentials).map { token => // Extract credential scope from the credential parameter - val queryString = token.split("\\?", 2)(1) - val credentialParam = queryString.split("&") + val queryString = token.split("\\?", 2)(1) + val credentialParam = queryString + .split("&") .find(_.startsWith("X-Amz-Credential=")) .getOrElse(fail("X-Amz-Credential parameter not found")) - + val credentialValue = URLDecoder.decode(credentialParam.split("=")(1), "UTF-8") - + // Credential format: accessKeyId/date/region/service/terminator val parts = credentialValue.split("/") assertEquals(parts.length, 5) assertEquals(parts(0), testAccessKeyId) assertEquals(parts(2), testRegion) - assertEquals(parts(3), "rds-db") // SERVICE constant + assertEquals(parts(3), "rds-db") // SERVICE constant assertEquals(parts(4), "aws4_request") // TERMINATOR constant } } test("token expiration is set to 900 seconds") { - val generator = createGenerator() + val generator = createGenerator() val credentials = AwsBasicCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, validateCredentials = true, - providerName = None, - accountId = None, - expirationTime = None + providerName = None, + accountId = None, + expirationTime = None ) generator.generateToken(credentials).map { token => assert(token.contains("X-Amz-Expires=900")) } - } \ No newline at end of file + } From e6c89530d386944c77de81b48aa5901c3d340312 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 14:48:19 +0900 Subject: [PATCH 139/215] Fixed bytesToHex for js --- .../scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 39a07cb01..6ba784f53 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -196,7 +196,7 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( * @return Lowercase hexadecimal string representation */ private def bytesToHex(bytes: Array[Byte]): String = - bytes.map("%02x".format(_)).mkString + bytes.map(b => "%02x".format(b & 0xFF)).mkString /** * Computes the SHA-256 hash of a string and returns it as a lowercase hex string. From 9c845f619a80ffad33dc5830e73d60f6704c7e89 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 14:58:38 +0900 Subject: [PATCH 140/215] Action sbt scalafmtAll --- .../scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala index 6ba784f53..3d1fb8c8e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -196,7 +196,7 @@ class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( * @return Lowercase hexadecimal string representation */ private def bytesToHex(bytes: Array[Byte]): String = - bytes.map(b => "%02x".format(b & 0xFF)).mkString + bytes.map(b => "%02x".format(b & 0xff)).mkString /** * Computes the SHA-256 hash of a string and returns it as a lowercase hex string. From 7295f898017cf1c758f905976879c9e74499af8b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 15:25:41 +0900 Subject: [PATCH 141/215] Create ContainerCredentialsProvider Test --- .../ContainerCredentialsProviderTest.scala | 535 ++++++++++++++++++ 1 file changed, 535 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala new file mode 100644 index 000000000..c076212e1 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala @@ -0,0 +1,535 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import cats.effect.std.Env +import cats.effect.{ IO, Ref } + +import fs2.io.file.{ Files, Path } + +import munit.CatsEffectSuite + +import ldbc.amazon.auth.credentials.AwsSessionCredentials +import ldbc.amazon.client.{ HttpClient, HttpResponse } +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials + +class ContainerCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val validJsonResponse = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z", + "RoleArn": "arn:aws:iam::123456789012:role/test-role" + }""" + + private val minimalJsonResponse = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z" + }""" + + private val invalidJsonResponse = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE" + }""" + + private val eksEndpoint = "http://169.254.170.23/v1/credentials" + private val authToken = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + + // Mock HTTP client + private def mockHttpClient( + responseBody: String, + statusCode: Int = 200, + captureRequest: Option[Ref[IO, Option[MockRequest]]] = None + ): HttpClient[IO] = + new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + for + _ <- captureRequest match + case Some(ref) => ref.set(Some(MockRequest(uri, headers))) + case None => IO.unit + yield HttpResponse( + statusCode = statusCode, + headers = Map.empty, + body = responseBody + ) + + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("POST not supported in mock")) + + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("PUT not supported in mock")) + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + IO.pure(envVars) + + // Use real Files instance and mock file operations differently + + case class MockRequest(uri: URI, headers: Map[String, String]) + + test("resolveCredentials with ECS relative URI") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + // Verify credentials + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, "ASIAIOSFODNN7EXAMPLE") + assertEquals(session.secretAccessKey, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + assertEquals(session.sessionToken, "session-token-123") + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("1")) + assertEquals(session.accountId, Some("123456789012")) + assertEquals(session.expirationTime, Some(Instant.parse("2024-12-31T23:59:59Z"))) + case _ => fail("Expected AwsSessionCredentials") + + // Verify HTTP request + captured match + case Some(request) => + assertEquals(request.uri.toString, "http://169.254.170.2/v2/credentials/test-uuid") + assertEquals(request.headers("Authorization"), authToken) + assertEquals(request.headers("Accept"), "application/json") + assertEquals(request.headers("User-Agent"), "aws-sdk-scala/ldbc") + case None => fail("Expected captured request") + } + + test("resolveCredentials with EKS full URI") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint, + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(minimalJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + // Verify credentials without RoleArn + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, "ASIAIOSFODNN7EXAMPLE") + assertEquals(session.secretAccessKey, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + assertEquals(session.sessionToken, "session-token-123") + assertEquals(session.accountId, None) // No RoleArn means no account ID + case _ => fail("Expected AwsSessionCredentials") + + // Verify HTTP request + captured match + case Some(request) => + assertEquals(request.uri.toString, eksEndpoint) + assertEquals(request.headers("Authorization"), authToken) + case None => fail("Expected captured request") + } + + test("resolveCredentials with token from direct environment variable") { + // This test works in both JVM and JavaScript environments + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> s" $authToken \n" // With whitespace + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify token was trimmed and used (ContainerCredentialsProvider trims the token) + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), authToken.trim) + case None => fail("Expected captured request") + } + + test("resolveCredentials without authorization token") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + // No authorization token + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify no Authorization header + captured match + case Some(request) => + assert(!request.headers.contains("Authorization")) + case None => fail("Expected captured request") + } + + test("fail when no container credentials environment variables are set") { + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(validJsonResponse) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load container credentials")) + assert(exception.getMessage.contains("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")) + assert(exception.getMessage.contains("AWS_CONTAINER_CREDENTIALS_FULL_URI")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when HTTP request fails") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient("Internal Server Error", statusCode = 500) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Container credentials request failed")) + assert(exception.getMessage.contains("500")) + assert(exception.getMessage.contains("Internal Server Error")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when JSON response is invalid") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient(invalidJsonResponse) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to parse container credentials response")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when token file does not exist") { + val tokenFilePath = Path("/tmp/non-existent-token") + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> tokenFilePath.toString + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient(validJsonResponse) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + yield + // Should succeed without token (Authorization header optional) + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + } + + test("handle empty authorization token") { + // Test with empty environment variable + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> " \n " // Only whitespace + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify no Authorization header for empty token + captured match + case Some(request) => + assert(!request.headers.contains("Authorization")) + case None => fail("Expected captured request") + } + + test("prefer direct token over token file path") { + // Test priority without actual file operations + val directToken = "direct-token" + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> directToken, + "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> "/some/file/path" // File path provided but direct token should take precedence + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify direct token was used, not file token + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), directToken) + case None => fail("Expected captured request") + } + + test("prefer relative URI over full URI") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify relative URI was used (ECS endpoint) + captured match + case Some(request) => + assertEquals(request.uri.toString, "http://169.254.170.2/v2/credentials/test-uuid") + case None => fail("Expected captured request") + } + + test("extractAccountIdFromRoleArn extracts account ID correctly") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val responseWithValidArn = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z", + "RoleArn": "arn:aws:iam::987654321098:role/test-role" + }""" + + val httpClient = mockHttpClient(responseWithValidArn) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accountId, Some("987654321098")) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("extractAccountIdFromRoleArn handles invalid ARN") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val responseWithInvalidArn = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z", + "RoleArn": "invalid-arn-format" + }""" + + val httpClient = mockHttpClient(responseWithInvalidArn) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accountId, None) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("isAvailable returns true when relative URI is set") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, true) + } + } + + test("isAvailable returns true when full URI is set") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint + ) + + given Env[IO] = mockEnv(envVars) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, true) + } + } + + test("isAvailable returns false when no URIs are set") { + given Env[IO] = mockEnv(Map.empty) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, false) + } + } + + test("isAvailable returns false when URIs are empty") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> " ", + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> "" + ) + + given Env[IO] = mockEnv(envVars) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, false) + } + } + + test("implements AwsCredentialsProvider trait") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient(validJsonResponse) + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().map { credentials => + assert(credentials.isInstanceOf[AwsCredentials]) + } + } + + // JavaScript-compatible tests for file-based functionality + test("resolveCredentials with token file path (JavaScript)") { + // In JavaScript environment, test with direct token instead of file + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify token was used + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), authToken) + case None => fail("Expected captured request") + } + + test("token preference logic (JavaScript)") { + // Test that direct token takes precedence over file token path + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> "direct-token", + "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> "/non/existent/file" + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify direct token was used + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), "direct-token") + case None => fail("Expected captured request") + } \ No newline at end of file From 006aaa97ca8320c10749b0111509915a72d40c1d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 15:36:12 +0900 Subject: [PATCH 142/215] Change method put to post --- .../js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala | 2 +- .../src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala | 2 +- .../src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 6728fb0a4..c72ca51be 100644 --- a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -118,7 +118,7 @@ class SimpleHttpClient[F[_]: Network: Async]( path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) yield response private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = diff --git a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 78a3d4cf7..d2eba08ae 100644 --- a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -120,7 +120,7 @@ class SimpleHttpClient[F[_]: Network: Async]( path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) yield response private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = diff --git a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index e3bdc99d6..151cf203f 100644 --- a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -118,7 +118,7 @@ class SimpleHttpClient[F[_]: Network: Async]( path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) yield response private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = From e3d9b5cc6ab80f0e4a61080582bd17abac201ca5 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 15:41:47 +0900 Subject: [PATCH 143/215] Create EnvironmentVariableCredentialsProvider Test --- ...nmentVariableCredentialsProviderTest.scala | 378 ++++++++++++++++++ 1 file changed, 378 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala new file mode 100644 index 000000000..f7e084e4f --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala @@ -0,0 +1,378 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.effect.std.Env +import cats.effect.IO + +import munit.CatsEffectSuite + +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials +import ldbc.amazon.util.SdkSystemSetting + +class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + IO.pure(envVars) + + test("resolveCredentials with basic credentials") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + assertEquals(basic.validateCredentials, false) + assertEquals(basic.providerName, Some("g")) + assertEquals(basic.accountId, None) + assertEquals(basic.expirationTime, None) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("resolveCredentials with session credentials") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> testSessionToken + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("g")) + assertEquals(session.accountId, None) + assertEquals(session.expirationTime, None) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("fail when AWS_ACCESS_KEY_ID is missing") { + val envVars = Map( + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when AWS_SECRET_ACCESS_KEY is missing") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Secret key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when no environment variables are set") { + given Env[IO] = mockEnv(Map.empty) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + test("handle empty environment variable values") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> "", + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Empty string is accepted + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("handle whitespace-only environment variable values") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> " ", + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Trimmed to empty string + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("handle empty session token") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> "" + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Empty string, not None + case _ => fail("Expected AwsSessionCredentials even with empty session token") + } + } + + test("handle whitespace-only session token") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> " \n " + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Trimmed to empty string + case _ => fail("Expected AwsSessionCredentials even with whitespace session token") + } + } + + test("trim whitespace from credentials") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> s" $testAccessKeyId ", + "AWS_SECRET_ACCESS_KEY" -> s"\n$testSecretAccessKey\t", + "AWS_SESSION_TOKEN" -> s" $testSessionToken \n" + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("provider name is correct") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assertEquals(credentials.providerName, Some("g")) // BusinessMetricFeatureId.CREDENTIALS_ENV_VARS.code + } + } + + test("loadSetting method reads from environment") { + val testValue = "test-setting-value" + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testValue + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, Some(testValue)) + } + } + + test("loadSetting returns None for missing environment variable") { + given Env[IO] = mockEnv(Map.empty) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, None) + } + } + + test("implements AwsCredentialsProvider trait") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assert(credentials.isInstanceOf[AwsCredentials]) + } + } + + test("consistent behavior across multiple calls") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> testSessionToken + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + for + credentials1 <- provider.resolveCredentials() + credentials2 <- provider.resolveCredentials() + yield + assertEquals(credentials1.accessKeyId, credentials2.accessKeyId) + assertEquals(credentials1.secretAccessKey, credentials2.secretAccessKey) + credentials1 match + case session1: AwsSessionCredentials => + credentials2 match + case session2: AwsSessionCredentials => + assertEquals(session1.sessionToken, session2.sessionToken) + case _ => fail("Both credentials should be AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") + } + + test("validates credentials format") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + // Verify access key format (starts with AKIA for IAM users) + assert(credentials.accessKeyId.startsWith("AKIA")) + assert(credentials.accessKeyId.length >= 16) + + // Verify secret key is not empty and has reasonable length + assert(credentials.secretAccessKey.nonEmpty) + assert(credentials.secretAccessKey.length >= 20) + } + } + + test("case sensitivity of environment variable names") { + val envVars = Map( + "aws_access_key_id" -> testAccessKeyId, // lowercase + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + // Should fail because AWS expects exact case + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load")) + case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") + case Right(_) => fail("Should fail with incorrect case environment variable") + } + } + + test("alternative environment variable names are not supported") { + val envVars = Map( + "AWS_ACCESS_KEY" -> testAccessKeyId, // Not AWS_ACCESS_KEY_ID + "AWS_SECRET_KEY" -> testSecretAccessKey // Not AWS_SECRET_ACCESS_KEY + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load")) + case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") + case Right(_) => fail("Should fail with non-standard environment variable names") + } + } + + test("credentials validation is disabled by default") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.validateCredentials, false) + case session: AwsSessionCredentials => + assertEquals(session.validateCredentials, false) + } + } \ No newline at end of file From ff129917039f95558c1cc7d82126c9e495bd5f08 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 15:42:09 +0900 Subject: [PATCH 144/215] Action sbt scalafmtAll --- .../ContainerCredentialsProviderTest.scala | 168 +++++++++--------- ...nmentVariableCredentialsProviderTest.scala | 94 +++++----- 2 files changed, 130 insertions(+), 132 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala index c076212e1..fc102bf32 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala @@ -9,8 +9,8 @@ package ldbc.amazon.auth.credentials import java.net.URI import java.time.Instant -import cats.effect.std.Env import cats.effect.{ IO, Ref } +import cats.effect.std.Env import fs2.io.file.{ Files, Path } @@ -44,29 +44,28 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: }""" private val eksEndpoint = "http://169.254.170.23/v1/credentials" - private val authToken = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + private val authToken = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." // Mock HTTP client private def mockHttpClient( - responseBody: String, - statusCode: Int = 200, + responseBody: String, + statusCode: Int = 200, captureRequest: Option[Ref[IO, Option[MockRequest]]] = None ): HttpClient[IO] = new HttpClient[IO]: override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = - for - _ <- captureRequest match - case Some(ref) => ref.set(Some(MockRequest(uri, headers))) - case None => IO.unit + for _ <- captureRequest match + case Some(ref) => ref.set(Some(MockRequest(uri, headers))) + case None => IO.unit yield HttpResponse( statusCode = statusCode, - headers = Map.empty, - body = responseBody + headers = Map.empty, + body = responseBody ) override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = IO.raiseError(new UnsupportedOperationException("POST not supported in mock")) - + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = IO.raiseError(new UnsupportedOperationException("PUT not supported in mock")) @@ -85,18 +84,18 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with ECS relative URI") { val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", - "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield // Verify credentials credentials match @@ -123,18 +122,18 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with EKS full URI") { val envVars = Map( "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint, - "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(minimalJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(minimalJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield // Verify credentials without RoleArn credentials match @@ -157,22 +156,22 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: // This test works in both JVM and JavaScript environments val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", - "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> s" $authToken \n" // With whitespace + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> s" $authToken \n" // With whitespace ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") // Verify token was trimmed and used (ContainerCredentialsProvider trims the token) captured match @@ -186,20 +185,20 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" // No authorization token ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") // Verify no Authorization header captured match @@ -212,7 +211,7 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: given Env[IO] = mockEnv(Map.empty) val httpClient = mockHttpClient(validJsonResponse) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val provider = ContainerCredentialsProvider.create[IO](httpClient) provider.resolveCredentials().attempt.map { case Left(exception: SdkClientException) => @@ -227,11 +226,11 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" ) - + given Env[IO] = mockEnv(envVars) val httpClient = mockHttpClient("Internal Server Error", statusCode = 500) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val provider = ContainerCredentialsProvider.create[IO](httpClient) provider.resolveCredentials().attempt.map { case Left(exception: SdkClientException) => @@ -246,11 +245,11 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" ) - + given Env[IO] = mockEnv(envVars) val httpClient = mockHttpClient(invalidJsonResponse) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val provider = ContainerCredentialsProvider.create[IO](httpClient) provider.resolveCredentials().attempt.map { case Left(exception: SdkClientException) => @@ -261,45 +260,44 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: test("fail when token file does not exist") { val tokenFilePath = Path("/tmp/non-existent-token") - val envVars = Map( + val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> tokenFilePath.toString ) - + given Env[IO] = mockEnv(envVars) val httpClient = mockHttpClient(validJsonResponse) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val provider = ContainerCredentialsProvider.create[IO](httpClient) - for - credentials <- provider.resolveCredentials() + for credentials <- provider.resolveCredentials() yield // Should succeed without token (Authorization header optional) credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") } test("handle empty authorization token") { // Test with empty environment variable val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", - "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> " \n " // Only whitespace + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> " \n " // Only whitespace ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") // Verify no Authorization header for empty token captured match @@ -311,25 +309,25 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: test("prefer direct token over token file path") { // Test priority without actual file operations val directToken = "direct-token" - val envVars = Map( + val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", - "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> directToken, + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> directToken, "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> "/some/file/path" // File path provided but direct token should take precedence ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") // Verify direct token was used, not file token captured match @@ -341,22 +339,22 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: test("prefer relative URI over full URI") { val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", - "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") // Verify relative URI was used (ECS endpoint) captured match @@ -369,7 +367,7 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" ) - + given Env[IO] = mockEnv(envVars) val responseWithValidArn = """{ @@ -381,7 +379,7 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: }""" val httpClient = mockHttpClient(responseWithValidArn) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val provider = ContainerCredentialsProvider.create[IO](httpClient) provider.resolveCredentials().map { case session: AwsSessionCredentials => @@ -394,7 +392,7 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" ) - + given Env[IO] = mockEnv(envVars) val responseWithInvalidArn = """{ @@ -406,7 +404,7 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: }""" val httpClient = mockHttpClient(responseWithInvalidArn) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val provider = ContainerCredentialsProvider.create[IO](httpClient) provider.resolveCredentials().map { case session: AwsSessionCredentials => @@ -419,7 +417,7 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" ) - + given Env[IO] = mockEnv(envVars) ContainerCredentialsProvider.isAvailable[IO]().map { available => @@ -431,7 +429,7 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint ) - + given Env[IO] = mockEnv(envVars) ContainerCredentialsProvider.isAvailable[IO]().map { available => @@ -450,9 +448,9 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: test("isAvailable returns false when URIs are empty") { val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> " ", - "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> "" + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> "" ) - + given Env[IO] = mockEnv(envVars) ContainerCredentialsProvider.isAvailable[IO]().map { available => @@ -464,11 +462,11 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" ) - + given Env[IO] = mockEnv(envVars) val httpClient = mockHttpClient(validJsonResponse) - val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = ContainerCredentialsProvider.create[IO](httpClient) provider.resolveCredentials().map { credentials => @@ -481,22 +479,22 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: // In JavaScript environment, test with direct token instead of file val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", - "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") // Verify token was used captured match @@ -509,27 +507,27 @@ class ContainerCredentialsProviderTest extends CatsEffectSuite: // Test that direct token takes precedence over file token path val envVars = Map( "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", - "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> "direct-token", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> "direct-token", "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> "/non/existent/file" ) - + given Env[IO] = mockEnv(envVars) val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) - val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) - val provider = ContainerCredentialsProvider.create[IO](httpClient) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) for credentials <- provider.resolveCredentials() - captured <- requestCapture.get + captured <- requestCapture.get yield credentials match case _: AwsSessionCredentials => // Success - case _ => fail("Expected AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") // Verify direct token was used captured match case Some(request) => assertEquals(request.headers("Authorization"), "direct-token") case None => fail("Expected captured request") - } \ No newline at end of file + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala index f7e084e4f..5e574bf2b 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala @@ -18,9 +18,9 @@ import ldbc.amazon.util.SdkSystemSetting class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: // Test fixtures - private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" - private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" // Mock environment private def mockEnv(envVars: Map[String, String]): Env[IO] = @@ -32,10 +32,10 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with basic credentials") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -54,11 +54,11 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with session credentials") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, - "AWS_SESSION_TOKEN" -> testSessionToken + "AWS_SESSION_TOKEN" -> testSessionToken ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -80,7 +80,7 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -97,7 +97,7 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: val envVars = Map( "AWS_ACCESS_KEY_ID" -> testAccessKeyId ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -125,10 +125,10 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("handle empty environment variable values") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> "", + "AWS_ACCESS_KEY_ID" -> "", "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -143,10 +143,10 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("handle whitespace-only environment variable values") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> " ", + "AWS_ACCESS_KEY_ID" -> " ", "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -161,11 +161,11 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("handle empty session token") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, - "AWS_SESSION_TOKEN" -> "" + "AWS_SESSION_TOKEN" -> "" ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -181,11 +181,11 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("handle whitespace-only session token") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, - "AWS_SESSION_TOKEN" -> " \n " + "AWS_SESSION_TOKEN" -> " \n " ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -201,11 +201,11 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("trim whitespace from credentials") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> s" $testAccessKeyId ", + "AWS_ACCESS_KEY_ID" -> s" $testAccessKeyId ", "AWS_SECRET_ACCESS_KEY" -> s"\n$testSecretAccessKey\t", - "AWS_SESSION_TOKEN" -> s" $testSessionToken \n" + "AWS_SESSION_TOKEN" -> s" $testSessionToken \n" ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -221,10 +221,10 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("provider name is correct") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -236,10 +236,10 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("loadSetting method reads from environment") { val testValue = "test-setting-value" - val envVars = Map( + val envVars = Map( "AWS_ACCESS_KEY_ID" -> testValue ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -261,13 +261,13 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("implements AwsCredentialsProvider trait") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) - val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = new EnvironmentVariableCredentialsProvider[IO] provider.resolveCredentials().map { credentials => @@ -277,11 +277,11 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("consistent behavior across multiple calls") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, - "AWS_SESSION_TOKEN" -> testSessionToken + "AWS_SESSION_TOKEN" -> testSessionToken ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -303,10 +303,10 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("validates credentials format") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -315,7 +315,7 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: // Verify access key format (starts with AKIA for IAM users) assert(credentials.accessKeyId.startsWith("AKIA")) assert(credentials.accessKeyId.length >= 16) - + // Verify secret key is not empty and has reasonable length assert(credentials.secretAccessKey.nonEmpty) assert(credentials.secretAccessKey.length >= 20) @@ -324,10 +324,10 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: test("case sensitivity of environment variable names") { val envVars = Map( - "aws_access_key_id" -> testAccessKeyId, // lowercase + "aws_access_key_id" -> testAccessKeyId, // lowercase "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -336,17 +336,17 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: provider.resolveCredentials().attempt.map { case Left(exception: SdkClientException) => assert(exception.getMessage.contains("Unable to load")) - case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") - case Right(_) => fail("Should fail with incorrect case environment variable") + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with incorrect case environment variable") } } test("alternative environment variable names are not supported") { val envVars = Map( - "AWS_ACCESS_KEY" -> testAccessKeyId, // Not AWS_ACCESS_KEY_ID - "AWS_SECRET_KEY" -> testSecretAccessKey // Not AWS_SECRET_ACCESS_KEY + "AWS_ACCESS_KEY" -> testAccessKeyId, // Not AWS_ACCESS_KEY_ID + "AWS_SECRET_KEY" -> testSecretAccessKey // Not AWS_SECRET_ACCESS_KEY ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -354,17 +354,17 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: provider.resolveCredentials().attempt.map { case Left(exception: SdkClientException) => assert(exception.getMessage.contains("Unable to load")) - case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") - case Right(_) => fail("Should fail with non-standard environment variable names") + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with non-standard environment variable names") } } test("credentials validation is disabled by default") { val envVars = Map( - "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey ) - + given Env[IO] = mockEnv(envVars) val provider = new EnvironmentVariableCredentialsProvider[IO] @@ -375,4 +375,4 @@ class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: case session: AwsSessionCredentials => assertEquals(session.validateCredentials, false) } - } \ No newline at end of file + } From de3d3e29cd802eb3253480e94c6d65f194987bc4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 15:50:23 +0900 Subject: [PATCH 145/215] Create SystemPropertyCredentialsProvider Test --- ...ystemPropertyCredentialsProviderTest.scala | 471 ++++++++++++++++++ 1 file changed, 471 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala new file mode 100644 index 000000000..c3e1f2c03 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala @@ -0,0 +1,471 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.effect.std.SystemProperties +import cats.effect.IO + +import munit.CatsEffectSuite + +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials +import ldbc.amazon.util.SdkSystemSetting + +class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + + // Mock system properties + private def mockSystemProperties(sysProps: Map[String, String]): SystemProperties[IO] = + new SystemProperties[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(sysProps.get(name)) + override def clear(key: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("clear not supported in mock")) + override def set(key: String, value: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("set not supported in mock")) + + test("resolveCredentials with basic credentials") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + assertEquals(basic.validateCredentials, false) + assertEquals(basic.providerName, Some("f")) + assertEquals(basic.accountId, None) + assertEquals(basic.expirationTime, None) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("resolveCredentials with session credentials") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("f")) + assertEquals(session.accountId, None) + assertEquals(session.expirationTime, None) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("resolveCredentials with account ID") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.accountId" -> "123456789012" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + assertEquals(basic.accountId, Some("123456789012")) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("fail when aws.accessKeyId is missing") { + val sysProps = Map( + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when aws.secretAccessKey is missing") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Secret key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when no system properties are set") { + given SystemProperties[IO] = mockSystemProperties(Map.empty) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + test("handle empty system property values") { + val sysProps = Map( + "aws.accessKeyId" -> "", + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Empty string is accepted + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("handle whitespace-only system property values") { + val sysProps = Map( + "aws.accessKeyId" -> " ", + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Trimmed to empty string + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("handle empty session token") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> "" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Empty string, not None + case _ => fail("Expected AwsSessionCredentials even with empty session token") + } + } + + test("handle whitespace-only session token") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> " \n " + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Trimmed to empty string + case _ => fail("Expected AwsSessionCredentials even with whitespace session token") + } + } + + test("trim whitespace from credentials") { + val sysProps = Map( + "aws.accessKeyId" -> s" $testAccessKeyId ", + "aws.secretAccessKey" -> s"\n$testSecretAccessKey\t", + "aws.sessionToken" -> s" $testSessionToken \n" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("provider name is correct") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assertEquals(credentials.providerName, Some("f")) // BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code + } + } + + test("loadSetting method reads from system properties") { + val testValue = "test-setting-value" + val sysProps = Map( + "aws.accessKeyId" -> testValue + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, Some(testValue)) + } + } + + test("loadSetting returns None for missing system property") { + given SystemProperties[IO] = mockSystemProperties(Map.empty) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, None) + } + } + + test("implements AwsCredentialsProvider trait") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assert(credentials.isInstanceOf[AwsCredentials]) + } + } + + test("consistent behavior across multiple calls") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + for + credentials1 <- provider.resolveCredentials() + credentials2 <- provider.resolveCredentials() + yield + assertEquals(credentials1.accessKeyId, credentials2.accessKeyId) + assertEquals(credentials1.secretAccessKey, credentials2.secretAccessKey) + credentials1 match + case session1: AwsSessionCredentials => + credentials2 match + case session2: AwsSessionCredentials => + assertEquals(session1.sessionToken, session2.sessionToken) + case _ => fail("Both credentials should be AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") + } + + test("validates credentials format") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + // Verify access key format (starts with AKIA for IAM users) + assert(credentials.accessKeyId.startsWith("AKIA")) + assert(credentials.accessKeyId.length >= 16) + + // Verify secret key is not empty and has reasonable length + assert(credentials.secretAccessKey.nonEmpty) + assert(credentials.secretAccessKey.length >= 20) + } + } + + test("system property keys are correct") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken, + "aws.accountId" -> "123456789012" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + for + accessKey <- provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID) + secretKey <- provider.loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY) + sessionToken <- provider.loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN) + accountId <- provider.loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID) + yield + assertEquals(accessKey, Some(testAccessKeyId)) + assertEquals(secretKey, Some(testSecretAccessKey)) + assertEquals(sessionToken, Some(testSessionToken)) + assertEquals(accountId, Some("123456789012")) + } + + test("different from environment variables - uses system properties") { + // This test ensures that SystemPropertyCredentialsProvider only reads from system properties, + // not environment variables, even if the system property names are similar + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, // System property format + "aws.secretAccessKey" -> testSecretAccessKey, + // Note: NO AWS_ACCESS_KEY_ID (environment variable format) + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("case sensitivity of system property names") { + // System properties are case-sensitive in Java + val sysProps = Map( + "AWS.ACCESSKEYID" -> testAccessKeyId, // Wrong case + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") + case Right(_) => fail("Should fail with incorrect case system property") + } + } + + test("alternative system property names are not supported") { + val sysProps = Map( + "aws.access.key.id" -> testAccessKeyId, // Not aws.accessKeyId + "aws.secret.access.key" -> testSecretAccessKey // Not aws.secretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") + case Right(_) => fail("Should fail with non-standard system property names") + } + } + + test("credentials validation is disabled by default") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.validateCredentials, false) + case session: AwsSessionCredentials => + assertEquals(session.validateCredentials, false) + } + } + + test("handles all supported AWS system properties") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken, + "aws.accountId" -> "123456789012" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.accountId, Some("123456789012")) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("f")) + assert(session.expirationTime.isEmpty) + case _ => fail("Expected AwsSessionCredentials") + } + } \ No newline at end of file From 87a1479956ef1fd8716d80a939144f99edad0dd7 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 16:04:22 +0900 Subject: [PATCH 146/215] Create ProfileCredentialsProvider Test --- .../ProfileCredentialsProviderTest.scala | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala new file mode 100644 index 000000000..9d02d70f7 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala @@ -0,0 +1,153 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.effect.std.SystemProperties +import cats.effect.IO + +import fs2.io.file.Files + +import munit.CatsEffectSuite + +import ldbc.amazon.exception.SdkClientException + +class ProfileCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testHomeDir = "/home/test" + + // Mock system properties + private def mockSystemProperties(homeDir: Option[String] = Some(testHomeDir)): SystemProperties[IO] = + new SystemProperties[IO]: + override def get(name: String): IO[Option[String]] = + name match + case "user.home" => IO.pure(homeDir) + case _ => IO.pure(None) + override def clear(key: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("clear not supported in mock")) + override def set(key: String, value: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("set not supported in mock")) + + test("ProfileCredentialsProvider creation succeeds with default parameters") { + given SystemProperties[IO] = mockSystemProperties() + given Files[IO] = Files.forIO + + for + provider <- ProfileCredentialsProvider.default[IO]() + yield + // Basic test - provider was created successfully + assert(provider != null) + } + + test("ProfileCredentialsProvider creation succeeds with named profile") { + given SystemProperties[IO] = mockSystemProperties() + given Files[IO] = Files.forIO + + for + provider <- ProfileCredentialsProvider.default[IO]("dev") + yield + // Basic test - provider was created successfully + assert(provider != null) + } + + test("ProfileCredentialsProvider fails when user.home is missing") { + given SystemProperties[IO] = mockSystemProperties(homeDir = None) + given Files[IO] = Files.forIO + + for + provider <- ProfileCredentialsProvider.default[IO]() + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + // Should fail because user.home is not available + assert(true) // Test passed + case _ => fail("Expected SdkClientException") + } + + test("ProfileCredentialsProvider companion object methods work") { + val profile = ProfileCredentialsProvider.Profile("test", Map("key" -> "value")) + assertEquals(profile.name, "test") + assertEquals(profile.properties, Map("key" -> "value")) + } + + test("ProfileFile case class creation works") { + import java.time.Instant + import ProfileCredentialsProvider.* + + val profiles = Map("default" -> Profile("default", Map("aws_access_key_id" -> "test"))) + val instant = Instant.now() + val profileFile = ProfileFile(profiles, instant) + + assertEquals(profileFile.profiles, profiles) + assertEquals(profileFile.lastModified, instant) + } + + test("Profile case class creation works") { + import ProfileCredentialsProvider.* + + val properties = Map( + "aws_access_key_id" -> "AKIAIOSFODNN7EXAMPLE", + "aws_secret_access_key" -> "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + ) + val profile = Profile("production", properties) + + assertEquals(profile.name, "production") + assertEquals(profile.properties, properties) + } + + // Test basic functionality that doesn't require file system access + test("ProfileCredentialsProvider implements AwsCredentialsProvider trait") { + given SystemProperties[IO] = mockSystemProperties() + given Files[IO] = Files.forIO + + for + provider <- ProfileCredentialsProvider.default[IO]() + providerTrait: ldbc.amazon.identity.AwsCredentialsProvider[IO] = provider + yield + // Type check passes - provider implements the trait correctly + assert(providerTrait != null) + } + + test("ProfileCredentialsProvider default factory creates provider with correct profile name") { + given SystemProperties[IO] = mockSystemProperties() + given Files[IO] = Files.forIO + + for + defaultProvider <- ProfileCredentialsProvider.default[IO]() + namedProvider <- ProfileCredentialsProvider.default[IO]("custom") + yield + // Both providers should be created successfully + assert(defaultProvider != null) + assert(namedProvider != null) + } + + // Test thread safety of provider creation + test("ProfileCredentialsProvider factory is thread-safe") { + given SystemProperties[IO] = mockSystemProperties() + given Files[IO] = Files.forIO + + for + providers <- IO.parSequenceN(10)((1 to 10).map(_ => ProfileCredentialsProvider.default[IO]()).toList) + yield + // All providers should be created successfully + providers.foreach(provider => assert(provider != null)) + } + + test("ProfileCredentialsProvider handles various profile names correctly") { + given SystemProperties[IO] = mockSystemProperties() + given Files[IO] = Files.forIO + + val profileNames = List("default", "dev", "staging", "production", "test-profile", "profile_with_underscores") + + for + providers <- IO.traverse(profileNames)(ProfileCredentialsProvider.default[IO](_)) + yield + // All providers should be created successfully regardless of profile name format + providers.foreach(provider => assert(provider != null)) + assertEquals(providers.length, profileNames.length) + } \ No newline at end of file From e75cdfb24d25685d422873ca964bd1ffa4899f66 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 16:11:56 +0900 Subject: [PATCH 147/215] Create InstanceProfileCredentialsProvider Test --- ...stanceProfileCredentialsProviderTest.scala | 535 ++++++++++++++++++ 1 file changed, 535 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala new file mode 100644 index 000000000..4fa3b6c3b --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala @@ -0,0 +1,535 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import cats.effect.std.Env +import cats.effect.{ IO, Ref } + +import munit.CatsEffectSuite + +import ldbc.amazon.client.{ HttpClient, HttpResponse } +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials + +class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testRoleName = "test-role" + private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val testMetadataToken = "AQAEAFUsWMxSzWnKOL7wJKWJHFGDL+gUSCMd6FKPYa8wOYwR=" + private val futureExpiration = Instant.now().plusSeconds(3600) + + // Sample IMDS JSON response + private val validCredentialsResponse = s"""{ + "Code": "Success", + "LastUpdated": "2024-01-01T12:00:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "$testAccessKeyId", + "SecretAccessKey": "$testSecretAccessKey", + "Token": "$testSessionToken", + "Expiration": "${futureExpiration.toString}" + }""" + + private val failedCredentialsResponse = s"""{ + "Code": "Failed", + "LastUpdated": "2024-01-01T12:00:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "", + "SecretAccessKey": "", + "Token": "", + "Expiration": "" + }""" + + // Mock HTTP client + case class MockRequest(uri: URI, headers: Map[String, String], method: String, body: Option[String]) + + private def mockHttpClient( + responses: Map[String, HttpResponse], + captureRequest: Option[Ref[IO, List[MockRequest]]] = None + ): HttpClient[IO] = + new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + for + _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "GET", None)) + case None => IO.unit + yield responses.getOrElse( + uri.toString, + HttpResponse(404, Map.empty, "Not Found") + ) + + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + for + _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "POST", Some(body))) + case None => IO.unit + yield responses.getOrElse( + uri.toString, + HttpResponse(404, Map.empty, "Not Found") + ) + + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + for + _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "PUT", Some(body))) + case None => IO.unit + yield responses.getOrElse( + uri.toString, + HttpResponse(200, Map.empty, testMetadataToken) + ) + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + IO.pure(envVars) + + test("resolveCredentials with successful IMDSv2 flow") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map.empty) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("0")) // BusinessMetricFeatureId.CREDENTIALS_IMDS.code + assertEquals(session.accountId, None) + assertEquals(session.expirationTime, Some(futureExpiration)) + case _ => fail("Expected AwsSessionCredentials") + + // Verify the request flow + val tokenRequest = requests.find(_.method == "PUT") + assert(tokenRequest.isDefined) + assert(tokenRequest.get.headers.contains("X-aws-ec2-metadata-token-ttl-seconds")) + + val roleRequest = requests.find(_.uri.toString.endsWith("security-credentials/")) + assert(roleRequest.isDefined) + assert(roleRequest.get.headers.contains("X-aws-ec2-metadata-token")) + + val credRequest = requests.find(_.uri.toString.endsWith(s"security-credentials/$testRoleName")) + assert(credRequest.isDefined) + assert(credRequest.get.headers.contains("X-aws-ec2-metadata-token")) + } + + test("resolveCredentials with IMDSv1 fallback when token acquisition fails") { + val responses = Map( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + IO.pure(responses.getOrElse(uri.toString, HttpResponse(404, Map.empty, "Not Found"))) + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("POST not supported")) + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + // Simulate token acquisition failure + IO.raiseError(new Exception("Token acquisition failed")) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + case _ => fail("Expected AwsSessionCredentials") + } + + test("fail when EC2 metadata is disabled") { + given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_DISABLED" -> "true")) + + val httpClient = mockHttpClient(Map.empty) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("EC2 metadata service is disabled")) + case _ => fail("Expected SdkClientException") + } + + test("fail when no IAM roles are available") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, "") // Empty response + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("No IAM roles found")) + case _ => fail("Expected SdkClientException") + } + + test("fail when instance profile is not attached (403)") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(403, Map.empty, "Forbidden") + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Forbidden (403)")) + assert(exception.getMessage.contains("No instance profile attached")) + case _ => fail("Expected SdkClientException") + } + + test("fail when metadata service is not available (404)") { + val responses = Map( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(404, Map.empty, "Not Found") + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Not Found (404)")) + assert(exception.getMessage.contains("Instance metadata not available")) + case _ => fail("Expected SdkClientException") + } + + test("fail when metadata token is invalid (401)") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(401, Map.empty, "Unauthorized") + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unauthorized (401)")) + assert(exception.getMessage.contains("Invalid or expired metadata token")) + case _ => fail("Expected SdkClientException") + } + + test("fail when credentials response has failed status") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, failedCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to retrieve credentials")) + assert(exception.getMessage.contains("Failed")) + case _ => fail("Expected SdkClientException") + } + + test("fail when credentials response is malformed JSON") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, "invalid json") + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield + result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to parse")) + case _ => fail("Expected SdkClientException") + } + + test("use custom IMDS endpoint when environment variable is set") { + val customEndpoint = "http://169.254.169.254:8080" + val responses = Map( + s"$customEndpoint/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + s"$customEndpoint/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"$customEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_SERVICE_ENDPOINT" -> customEndpoint)) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify custom endpoint was used + assert(requests.exists(_.uri.toString.startsWith(customEndpoint))) + } + + test("strip trailing slash from custom endpoint") { + val customEndpoint = "http://169.254.169.254:8080/" + val expectedEndpoint = "http://169.254.169.254:8080" + val responses = Map( + s"$expectedEndpoint/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + s"$expectedEndpoint/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"$expectedEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_SERVICE_ENDPOINT" -> customEndpoint)) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify trailing slash was stripped + assert(requests.exists(_.uri.toString.startsWith(expectedEndpoint))) + assert(!requests.exists(_.uri.toString.contains("//latest"))) + } + + test("caching works correctly - multiple calls return cached credentials") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map.empty) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials1 <- provider.resolveCredentials() + credentials2 <- provider.resolveCredentials() + requests <- requestCapture.get + yield + // Both calls should return the same credentials + assertEquals(credentials1.accessKeyId, credentials2.accessKeyId) + assertEquals(credentials1.secretAccessKey, credentials2.secretAccessKey) + + // Only one set of requests should have been made (first call) + val tokenRequests = requests.filter(_.method == "PUT") + assertEquals(tokenRequests.length, 1) + } + + test("credentials are refreshed when close to expiration") { + val shortExpirationTime = Instant.now().plusSeconds(180) // 3 minutes from now (less than 4 minute buffer) + val shortExpirationResponse = s"""{ + "Code": "Success", + "LastUpdated": "2024-01-01T12:00:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "$testAccessKeyId", + "SecretAccessKey": "$testSecretAccessKey", + "Token": "$testSessionToken", + "Expiration": "${shortExpirationTime.toString}" + }""" + + val longExpirationTime = Instant.now().plusSeconds(3600) // 1 hour from now + val newAccessKeyId = "ASIANEWACCESSKEY123" + val refreshedResponse = s"""{ + "Code": "Success", + "LastUpdated": "2024-01-01T12:30:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "$newAccessKeyId", + "SecretAccessKey": "$testSecretAccessKey", + "Token": "$testSessionToken", + "Expiration": "${longExpirationTime.toString}" + }""" + + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName) + ) + + // First response with short expiration, second with long expiration + val callCount = Ref.unsafe[IO, Int](0) + val httpClient = new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + if uri.toString.endsWith(s"security-credentials/$testRoleName") then + callCount.updateAndGet(_ + 1).map { count => + if count == 1 then + HttpResponse(200, Map.empty, shortExpirationResponse) + else + HttpResponse(200, Map.empty, refreshedResponse) + } + else + IO.pure(responses.getOrElse(uri.toString, HttpResponse(404, Map.empty, "Not Found"))) + + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("POST not supported")) + + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.pure(HttpResponse(200, Map.empty, testMetadataToken)) + + given Env[IO] = mockEnv(Map.empty) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials1 <- provider.resolveCredentials() // Gets short expiration credentials + credentials2 <- provider.resolveCredentials() // Should refresh due to short expiration + count <- callCount.get + yield + // First credentials should have short expiration + assertEquals(credentials1.expirationTime, Some(shortExpirationTime)) + + // Second credentials should be refreshed with new access key + assertEquals(credentials2.accessKeyId, newAccessKeyId) + assertEquals(credentials2.expirationTime, Some(longExpirationTime)) + + // Should have made 2 credential requests due to refresh + assertEquals(count, 2) + } + + test("implements AwsCredentialsProvider trait") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + providerTrait: ldbc.amazon.identity.AwsCredentialsProvider[IO] = provider + credentials <- providerTrait.resolveCredentials() + yield + assert(credentials.isInstanceOf[AwsCredentials]) + } + + test("handles multiple IAM roles by selecting the first one") { + val multipleRoles = s"$testRoleName\nrole2\nrole3" + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, multipleRoles), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + } + + test("request headers are set correctly") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + ) + + given Env[IO] = mockEnv(Map.empty) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify PUT request headers for token acquisition + val tokenRequest = requests.find(_.method == "PUT") + assert(tokenRequest.isDefined) + assertEquals(tokenRequest.get.headers("X-aws-ec2-metadata-token-ttl-seconds"), "21600") + + // Verify GET request headers include token and user agent + val getRequests = requests.filter(_.method == "GET") + getRequests.foreach { request => + assertEquals(request.headers("Accept"), "application/json") + assertEquals(request.headers("User-Agent"), "aws-sdk-scala/ldbc") + assertEquals(request.headers("X-aws-ec2-metadata-token"), testMetadataToken) + } + } \ No newline at end of file From e44b33edbc96fd2fe04342a016384f90256506b9 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 16:12:13 +0900 Subject: [PATCH 148/215] Action sbt scalafmtAll --- ...stanceProfileCredentialsProviderTest.scala | 300 ++++++++++-------- .../ProfileCredentialsProviderTest.scala | 57 ++-- ...ystemPropertyCredentialsProviderTest.scala | 133 ++++---- 3 files changed, 262 insertions(+), 228 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala index 4fa3b6c3b..e57db2970 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala @@ -9,8 +9,8 @@ package ldbc.amazon.auth.credentials import java.net.URI import java.time.Instant -import cats.effect.std.Env import cats.effect.{ IO, Ref } +import cats.effect.std.Env import munit.CatsEffectSuite @@ -21,12 +21,12 @@ import ldbc.amazon.identity.AwsCredentials class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: // Test fixtures - private val testRoleName = "test-role" - private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" + private val testRoleName = "test-role" + private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" - private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" - private val testMetadataToken = "AQAEAFUsWMxSzWnKOL7wJKWJHFGDL+gUSCMd6FKPYa8wOYwR=" - private val futureExpiration = Instant.now().plusSeconds(3600) + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val testMetadataToken = "AQAEAFUsWMxSzWnKOL7wJKWJHFGDL+gUSCMd6FKPYa8wOYwR=" + private val futureExpiration = Instant.now().plusSeconds(3600) // Sample IMDS JSON response private val validCredentialsResponse = s"""{ @@ -36,7 +36,7 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: "AccessKeyId": "$testAccessKeyId", "SecretAccessKey": "$testSecretAccessKey", "Token": "$testSessionToken", - "Expiration": "${futureExpiration.toString}" + "Expiration": "${ futureExpiration.toString }" }""" private val failedCredentialsResponse = s"""{ @@ -53,35 +53,32 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: case class MockRequest(uri: URI, headers: Map[String, String], method: String, body: Option[String]) private def mockHttpClient( - responses: Map[String, HttpResponse], + responses: Map[String, HttpResponse], captureRequest: Option[Ref[IO, List[MockRequest]]] = None ): HttpClient[IO] = new HttpClient[IO]: override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = - for - _ <- captureRequest match - case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "GET", None)) - case None => IO.unit + for _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "GET", None)) + case None => IO.unit yield responses.getOrElse( uri.toString, HttpResponse(404, Map.empty, "Not Found") ) override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = - for - _ <- captureRequest match - case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "POST", Some(body))) - case None => IO.unit + for _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "POST", Some(body))) + case None => IO.unit yield responses.getOrElse( uri.toString, HttpResponse(404, Map.empty, "Not Found") ) override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = - for - _ <- captureRequest match - case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "PUT", Some(body))) - case None => IO.unit + for _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "PUT", Some(body))) + case None => IO.unit yield responses.getOrElse( uri.toString, HttpResponse(200, Map.empty, testMetadataToken) @@ -99,18 +96,22 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map.empty) val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) - val httpClient = mockHttpClient(responses, Some(requestCapture)) + val httpClient = mockHttpClient(responses, Some(requestCapture)) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials <- provider.resolveCredentials() - requests <- requestCapture.get + requests <- requestCapture.get yield credentials match case session: AwsSessionCredentials => @@ -140,7 +141,11 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with IMDSv1 fallback when token acquisition fails") { val responses = Map( "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map.empty) @@ -155,15 +160,14 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: IO.raiseError(new Exception("Token acquisition failed")) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials <- provider.resolveCredentials() - yield - credentials match - case session: AwsSessionCredentials => - assertEquals(session.accessKeyId, testAccessKeyId) - assertEquals(session.secretAccessKey, testSecretAccessKey) - assertEquals(session.sessionToken, testSessionToken) - case _ => fail("Expected AwsSessionCredentials") + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + case _ => fail("Expected AwsSessionCredentials") } test("fail when EC2 metadata is disabled") { @@ -173,18 +177,21 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: for provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("EC2 metadata service is disabled")) - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("EC2 metadata service is disabled")) + case _ => fail("Expected SdkClientException") } test("fail when no IAM roles are available") { val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), - "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, "") // Empty response + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse( + 200, + Map.empty, + "" + ) // Empty response ) given Env[IO] = mockEnv(Map.empty) @@ -193,12 +200,11 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: for provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("No IAM roles found")) - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("No IAM roles found")) + case _ => fail("Expected SdkClientException") } test("fail when instance profile is not attached (403)") { @@ -213,13 +219,12 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: for provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("Forbidden (403)")) - assert(exception.getMessage.contains("No instance profile attached")) - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Forbidden (403)")) + assert(exception.getMessage.contains("No instance profile attached")) + case _ => fail("Expected SdkClientException") } test("fail when metadata service is not available (404)") { @@ -233,19 +238,22 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: for provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("Not Found (404)")) - assert(exception.getMessage.contains("Instance metadata not available")) - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Not Found (404)")) + assert(exception.getMessage.contains("Instance metadata not available")) + case _ => fail("Expected SdkClientException") } test("fail when metadata token is invalid (401)") { val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), - "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(401, Map.empty, "Unauthorized") + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse( + 401, + Map.empty, + "Unauthorized" + ) ) given Env[IO] = mockEnv(Map.empty) @@ -254,20 +262,23 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: for provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("Unauthorized (401)")) - assert(exception.getMessage.contains("Invalid or expired metadata token")) - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unauthorized (401)")) + assert(exception.getMessage.contains("Invalid or expired metadata token")) + case _ => fail("Expected SdkClientException") } test("fail when credentials response has failed status") { val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, failedCredentialsResponse) + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + failedCredentialsResponse + ) ) given Env[IO] = mockEnv(Map.empty) @@ -276,20 +287,23 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: for provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("Failed to retrieve credentials")) - assert(exception.getMessage.contains("Failed")) - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to retrieve credentials")) + assert(exception.getMessage.contains("Failed")) + case _ => fail("Expected SdkClientException") } test("fail when credentials response is malformed JSON") { val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, "invalid json") + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + "invalid json" + ) ) given Env[IO] = mockEnv(Map.empty) @@ -298,31 +312,34 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: for provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("Failed to parse")) - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to parse")) + case _ => fail("Expected SdkClientException") } test("use custom IMDS endpoint when environment variable is set") { val customEndpoint = "http://169.254.169.254:8080" - val responses = Map( - s"$customEndpoint/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + val responses = Map( + s"$customEndpoint/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), s"$customEndpoint/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"$customEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + s"$customEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_SERVICE_ENDPOINT" -> customEndpoint)) val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) - val httpClient = mockHttpClient(responses, Some(requestCapture)) + val httpClient = mockHttpClient(responses, Some(requestCapture)) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials <- provider.resolveCredentials() - requests <- requestCapture.get + requests <- requestCapture.get yield credentials match case session: AwsSessionCredentials => @@ -334,23 +351,27 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: } test("strip trailing slash from custom endpoint") { - val customEndpoint = "http://169.254.169.254:8080/" + val customEndpoint = "http://169.254.169.254:8080/" val expectedEndpoint = "http://169.254.169.254:8080" - val responses = Map( + val responses = Map( s"$expectedEndpoint/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), s"$expectedEndpoint/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"$expectedEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + s"$expectedEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_SERVICE_ENDPOINT" -> customEndpoint)) val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) - val httpClient = mockHttpClient(responses, Some(requestCapture)) + val httpClient = mockHttpClient(responses, Some(requestCapture)) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials <- provider.resolveCredentials() - requests <- requestCapture.get + requests <- requestCapture.get yield credentials match case session: AwsSessionCredentials => @@ -366,19 +387,23 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map.empty) val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) - val httpClient = mockHttpClient(responses, Some(requestCapture)) + val httpClient = mockHttpClient(responses, Some(requestCapture)) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials1 <- provider.resolveCredentials() credentials2 <- provider.resolveCredentials() - requests <- requestCapture.get + requests <- requestCapture.get yield // Both calls should return the same credentials assertEquals(credentials1.accessKeyId, credentials2.accessKeyId) @@ -390,7 +415,7 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: } test("credentials are refreshed when close to expiration") { - val shortExpirationTime = Instant.now().plusSeconds(180) // 3 minutes from now (less than 4 minute buffer) + val shortExpirationTime = Instant.now().plusSeconds(180) // 3 minutes from now (less than 4 minute buffer) val shortExpirationResponse = s"""{ "Code": "Success", "LastUpdated": "2024-01-01T12:00:00Z", @@ -398,19 +423,19 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: "AccessKeyId": "$testAccessKeyId", "SecretAccessKey": "$testSecretAccessKey", "Token": "$testSessionToken", - "Expiration": "${shortExpirationTime.toString}" + "Expiration": "${ shortExpirationTime.toString }" }""" val longExpirationTime = Instant.now().plusSeconds(3600) // 1 hour from now - val newAccessKeyId = "ASIANEWACCESSKEY123" - val refreshedResponse = s"""{ + val newAccessKeyId = "ASIANEWACCESSKEY123" + val refreshedResponse = s"""{ "Code": "Success", "LastUpdated": "2024-01-01T12:30:00Z", "Type": "AWS-HMAC", "AccessKeyId": "$newAccessKeyId", "SecretAccessKey": "$testSecretAccessKey", "Token": "$testSessionToken", - "Expiration": "${longExpirationTime.toString}" + "Expiration": "${ longExpirationTime.toString }" }""" val responses = Map( @@ -419,18 +444,15 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: ) // First response with short expiration, second with long expiration - val callCount = Ref.unsafe[IO, Int](0) + val callCount = Ref.unsafe[IO, Int](0) val httpClient = new HttpClient[IO]: override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = if uri.toString.endsWith(s"security-credentials/$testRoleName") then callCount.updateAndGet(_ + 1).map { count => - if count == 1 then - HttpResponse(200, Map.empty, shortExpirationResponse) - else - HttpResponse(200, Map.empty, refreshedResponse) + if count == 1 then HttpResponse(200, Map.empty, shortExpirationResponse) + else HttpResponse(200, Map.empty, refreshedResponse) } - else - IO.pure(responses.getOrElse(uri.toString, HttpResponse(404, Map.empty, "Not Found"))) + else IO.pure(responses.getOrElse(uri.toString, HttpResponse(404, Map.empty, "Not Found"))) override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = IO.raiseError(new UnsupportedOperationException("POST not supported")) @@ -441,18 +463,18 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: given Env[IO] = mockEnv(Map.empty) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials1 <- provider.resolveCredentials() // Gets short expiration credentials credentials2 <- provider.resolveCredentials() // Should refresh due to short expiration - count <- callCount.get + count <- callCount.get yield // First credentials should have short expiration assertEquals(credentials1.expirationTime, Some(shortExpirationTime)) - + // Second credentials should be refreshed with new access key assertEquals(credentials2.accessKeyId, newAccessKeyId) assertEquals(credentials2.expirationTime, Some(longExpirationTime)) - + // Should have made 2 credential requests due to refresh assertEquals(count, 2) } @@ -461,7 +483,11 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map.empty) @@ -472,16 +498,23 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) providerTrait: ldbc.amazon.identity.AwsCredentialsProvider[IO] = provider credentials <- providerTrait.resolveCredentials() - yield - assert(credentials.isInstanceOf[AwsCredentials]) + yield assert(credentials.isInstanceOf[AwsCredentials]) } test("handles multiple IAM roles by selecting the first one") { val multipleRoles = s"$testRoleName\nrole2\nrole3" - val responses = Map( + val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), - "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, multipleRoles), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse( + 200, + Map.empty, + multipleRoles + ), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map.empty) @@ -489,31 +522,34 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: val httpClient = mockHttpClient(responses) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials <- provider.resolveCredentials() - yield - credentials match - case session: AwsSessionCredentials => - assertEquals(session.accessKeyId, testAccessKeyId) - case _ => fail("Expected AwsSessionCredentials") + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") } test("request headers are set correctly") { val responses = Map( "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), - s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse(200, Map.empty, validCredentialsResponse) + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) ) given Env[IO] = mockEnv(Map.empty) val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) - val httpClient = mockHttpClient(responses, Some(requestCapture)) + val httpClient = mockHttpClient(responses, Some(requestCapture)) for - provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) credentials <- provider.resolveCredentials() - requests <- requestCapture.get + requests <- requestCapture.get yield credentials match case session: AwsSessionCredentials => @@ -532,4 +568,4 @@ class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: assertEquals(request.headers("User-Agent"), "aws-sdk-scala/ldbc") assertEquals(request.headers("X-aws-ec2-metadata-token"), testMetadataToken) } - } \ No newline at end of file + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala index 9d02d70f7..146e625f3 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala @@ -34,10 +34,9 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: test("ProfileCredentialsProvider creation succeeds with default parameters") { given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO + given Files[IO] = Files.forIO - for - provider <- ProfileCredentialsProvider.default[IO]() + for provider <- ProfileCredentialsProvider.default[IO]() yield // Basic test - provider was created successfully assert(provider != null) @@ -45,10 +44,9 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: test("ProfileCredentialsProvider creation succeeds with named profile") { given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO + given Files[IO] = Files.forIO - for - provider <- ProfileCredentialsProvider.default[IO]("dev") + for provider <- ProfileCredentialsProvider.default[IO]("dev") yield // Basic test - provider was created successfully assert(provider != null) @@ -56,17 +54,16 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: test("ProfileCredentialsProvider fails when user.home is missing") { given SystemProperties[IO] = mockSystemProperties(homeDir = None) - given Files[IO] = Files.forIO + given Files[IO] = Files.forIO for provider <- ProfileCredentialsProvider.default[IO]() - result <- provider.resolveCredentials().attempt - yield - result match - case Left(exception: SdkClientException) => - // Should fail because user.home is not available - assert(true) // Test passed - case _ => fail("Expected SdkClientException") + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + // Should fail because user.home is not available + assert(true) // Test passed + case _ => fail("Expected SdkClientException") } test("ProfileCredentialsProvider companion object methods work") { @@ -78,24 +75,24 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: test("ProfileFile case class creation works") { import java.time.Instant import ProfileCredentialsProvider.* - - val profiles = Map("default" -> Profile("default", Map("aws_access_key_id" -> "test"))) - val instant = Instant.now() + + val profiles = Map("default" -> Profile("default", Map("aws_access_key_id" -> "test"))) + val instant = Instant.now() val profileFile = ProfileFile(profiles, instant) - + assertEquals(profileFile.profiles, profiles) assertEquals(profileFile.lastModified, instant) } test("Profile case class creation works") { import ProfileCredentialsProvider.* - + val properties = Map( - "aws_access_key_id" -> "AKIAIOSFODNN7EXAMPLE", + "aws_access_key_id" -> "AKIAIOSFODNN7EXAMPLE", "aws_secret_access_key" -> "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" ) val profile = Profile("production", properties) - + assertEquals(profile.name, "production") assertEquals(profile.properties, properties) } @@ -103,7 +100,7 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: // Test basic functionality that doesn't require file system access test("ProfileCredentialsProvider implements AwsCredentialsProvider trait") { given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO + given Files[IO] = Files.forIO for provider <- ProfileCredentialsProvider.default[IO]() @@ -115,11 +112,11 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: test("ProfileCredentialsProvider default factory creates provider with correct profile name") { given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO + given Files[IO] = Files.forIO for defaultProvider <- ProfileCredentialsProvider.default[IO]() - namedProvider <- ProfileCredentialsProvider.default[IO]("custom") + namedProvider <- ProfileCredentialsProvider.default[IO]("custom") yield // Both providers should be created successfully assert(defaultProvider != null) @@ -129,10 +126,9 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: // Test thread safety of provider creation test("ProfileCredentialsProvider factory is thread-safe") { given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO + given Files[IO] = Files.forIO - for - providers <- IO.parSequenceN(10)((1 to 10).map(_ => ProfileCredentialsProvider.default[IO]()).toList) + for providers <- IO.parSequenceN(10)((1 to 10).map(_ => ProfileCredentialsProvider.default[IO]()).toList) yield // All providers should be created successfully providers.foreach(provider => assert(provider != null)) @@ -140,14 +136,13 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: test("ProfileCredentialsProvider handles various profile names correctly") { given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO + given Files[IO] = Files.forIO val profileNames = List("default", "dev", "staging", "production", "test-profile", "profile_with_underscores") - for - providers <- IO.traverse(profileNames)(ProfileCredentialsProvider.default[IO](_)) + for providers <- IO.traverse(profileNames)(ProfileCredentialsProvider.default[IO](_)) yield // All providers should be created successfully regardless of profile name format providers.foreach(provider => assert(provider != null)) assertEquals(providers.length, profileNames.length) - } \ No newline at end of file + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala index c3e1f2c03..6e4de9e81 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala @@ -18,9 +18,9 @@ import ldbc.amazon.util.SdkSystemSetting class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: // Test fixtures - private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" - private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" // Mock system properties private def mockSystemProperties(sysProps: Map[String, String]): SystemProperties[IO] = @@ -34,10 +34,10 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with basic credentials") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -56,11 +56,11 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with session credentials") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey, - "aws.sessionToken" -> testSessionToken + "aws.sessionToken" -> testSessionToken ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -80,11 +80,11 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("resolveCredentials with account ID") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey, - "aws.accountId" -> "123456789012" + "aws.accountId" -> "123456789012" ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -102,7 +102,7 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: val sysProps = Map( "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -119,7 +119,7 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: val sysProps = Map( "aws.accessKeyId" -> testAccessKeyId ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -147,10 +147,10 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("handle empty system property values") { val sysProps = Map( - "aws.accessKeyId" -> "", + "aws.accessKeyId" -> "", "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -165,10 +165,10 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("handle whitespace-only system property values") { val sysProps = Map( - "aws.accessKeyId" -> " ", + "aws.accessKeyId" -> " ", "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -183,11 +183,11 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("handle empty session token") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey, - "aws.sessionToken" -> "" + "aws.sessionToken" -> "" ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -203,11 +203,11 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("handle whitespace-only session token") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey, - "aws.sessionToken" -> " \n " + "aws.sessionToken" -> " \n " ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -223,11 +223,11 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("trim whitespace from credentials") { val sysProps = Map( - "aws.accessKeyId" -> s" $testAccessKeyId ", + "aws.accessKeyId" -> s" $testAccessKeyId ", "aws.secretAccessKey" -> s"\n$testSecretAccessKey\t", - "aws.sessionToken" -> s" $testSessionToken \n" + "aws.sessionToken" -> s" $testSessionToken \n" ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -243,25 +243,28 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("provider name is correct") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] provider.resolveCredentials().map { credentials => - assertEquals(credentials.providerName, Some("f")) // BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code + assertEquals( + credentials.providerName, + Some("f") + ) // BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code } } test("loadSetting method reads from system properties") { val testValue = "test-setting-value" - val sysProps = Map( + val sysProps = Map( "aws.accessKeyId" -> testValue ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -283,13 +286,13 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("implements AwsCredentialsProvider trait") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) - val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = new SystemPropertyCredentialsProvider[IO] provider.resolveCredentials().map { credentials => @@ -299,11 +302,11 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("consistent behavior across multiple calls") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey, - "aws.sessionToken" -> testSessionToken + "aws.sessionToken" -> testSessionToken ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -325,10 +328,10 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("validates credentials format") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -337,7 +340,7 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: // Verify access key format (starts with AKIA for IAM users) assert(credentials.accessKeyId.startsWith("AKIA")) assert(credentials.accessKeyId.length >= 16) - + // Verify secret key is not empty and has reasonable length assert(credentials.secretAccessKey.nonEmpty) assert(credentials.secretAccessKey.length >= 20) @@ -346,21 +349,21 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("system property keys are correct") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey, - "aws.sessionToken" -> testSessionToken, - "aws.accountId" -> "123456789012" + "aws.sessionToken" -> testSessionToken, + "aws.accountId" -> "123456789012" ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] for - accessKey <- provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID) - secretKey <- provider.loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY) + accessKey <- provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID) + secretKey <- provider.loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY) sessionToken <- provider.loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN) - accountId <- provider.loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID) + accountId <- provider.loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID) yield assertEquals(accessKey, Some(testAccessKeyId)) assertEquals(secretKey, Some(testSecretAccessKey)) @@ -372,11 +375,11 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: // This test ensures that SystemPropertyCredentialsProvider only reads from system properties, // not environment variables, even if the system property names are similar val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, // System property format - "aws.secretAccessKey" -> testSecretAccessKey, + "aws.accessKeyId" -> testAccessKeyId, // System property format + "aws.secretAccessKey" -> testSecretAccessKey // Note: NO AWS_ACCESS_KEY_ID (environment variable format) ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -392,10 +395,10 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("case sensitivity of system property names") { // System properties are case-sensitive in Java val sysProps = Map( - "AWS.ACCESSKEYID" -> testAccessKeyId, // Wrong case + "AWS.ACCESSKEYID" -> testAccessKeyId, // Wrong case "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -404,17 +407,17 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: case Left(exception: SdkClientException) => assert(exception.getMessage.contains("Unable to load credentials from system settings")) assert(exception.getMessage.contains("Access key")) - case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") - case Right(_) => fail("Should fail with incorrect case system property") + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with incorrect case system property") } } test("alternative system property names are not supported") { val sysProps = Map( - "aws.access.key.id" -> testAccessKeyId, // Not aws.accessKeyId - "aws.secret.access.key" -> testSecretAccessKey // Not aws.secretAccessKey + "aws.access.key.id" -> testAccessKeyId, // Not aws.accessKeyId + "aws.secret.access.key" -> testSecretAccessKey // Not aws.secretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -422,17 +425,17 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: provider.resolveCredentials().attempt.map { case Left(exception: SdkClientException) => assert(exception.getMessage.contains("Unable to load credentials from system settings")) - case Left(other) => fail(s"Expected SdkClientException, got ${other.getClass.getSimpleName}") - case Right(_) => fail("Should fail with non-standard system property names") + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with non-standard system property names") } } test("credentials validation is disabled by default") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -447,12 +450,12 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: test("handles all supported AWS system properties") { val sysProps = Map( - "aws.accessKeyId" -> testAccessKeyId, + "aws.accessKeyId" -> testAccessKeyId, "aws.secretAccessKey" -> testSecretAccessKey, - "aws.sessionToken" -> testSessionToken, - "aws.accountId" -> "123456789012" + "aws.sessionToken" -> testSessionToken, + "aws.accountId" -> "123456789012" ) - + given SystemProperties[IO] = mockSystemProperties(sysProps) val provider = new SystemPropertyCredentialsProvider[IO] @@ -468,4 +471,4 @@ class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: assert(session.expirationTime.isEmpty) case _ => fail("Expected AwsSessionCredentials") } - } \ No newline at end of file + } From aa71091be58c44242bebe2123140d0b777ac98cc Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 18:04:48 +0900 Subject: [PATCH 149/215] Create WebIdentityTokenFileCredentialsProvider --- ...tityTokenFileCredentialsProviderTest.scala | 372 ++++++++++++++++++ 1 file changed, 372 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala new file mode 100644 index 000000000..ee58dec4a --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala @@ -0,0 +1,372 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import cats.effect.{ IO, Ref } +import cats.effect.std.{ Env, SystemProperties } + +import munit.CatsEffectSuite + +import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.auth.credentials.AwsSessionCredentials +import ldbc.amazon.identity.AwsCredentials + +class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testRoleArn = "arn:aws:iam::123456789012:role/test-role" + private val testTokenFile = "/var/run/secrets/eks.amazonaws.com/serviceaccount/token" + private val testSessionName = "test-session" + private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val futureExpiration = Instant.now().plusSeconds(3600) + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + IO.pure(envVars) + + // Mock system properties + private def mockSystemProperties(sysProps: Map[String, String] = Map.empty): SystemProperties[IO] = + new SystemProperties[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(sysProps.get(name)) + override def clear(key: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("clear not supported in mock")) + override def set(key: String, value: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("set not supported in mock")) + + // Mock WebIdentityCredentialsUtils + private def mockWebIdentityUtils( + response: Option[AwsCredentials] = None, + shouldFail: Boolean = false, + errorMessage: String = "Mock STS error", + captureRequestsTo: Option[Ref[IO, List[WebIdentityTokenCredentialProperties]]] = None + ): WebIdentityCredentialsUtils[IO] = + (config: WebIdentityTokenCredentialProperties) => for + _ <- captureRequestsTo match + case Some(ref) => ref.update(_ :+ config) + case None => IO.unit + yield + if shouldFail then throw new SdkClientException(errorMessage) + else + response.getOrElse( + AwsSessionCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + validateCredentials = false, + providerName = None, + accountId = Some("123456789012"), + expirationTime = Some(futureExpiration) + ) + ) + + test("resolveCredentials with successful Web Identity Token authentication") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn, + "AWS_ROLE_SESSION_NAME" -> testSessionName + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.accountId, Some("123456789012")) + assertEquals(session.expirationTime, Some(futureExpiration)) + case _ => fail("Expected AwsSessionCredentials") + + // Verify the correct configuration was passed + assertEquals(requests.length, 1) + val request = requests.head + assertEquals(request.webIdentityTokenFile.toString, testTokenFile) + assertEquals(request.roleArn, testRoleArn) + assertEquals(request.roleSessionName, Some(testSessionName)) + } + + test("resolveCredentials with system properties instead of environment variables") { + val sysProps = Map( + "aws.webIdentityTokenFile" -> testTokenFile, + "aws.roleArn" -> testRoleArn, + "aws.roleSessionName" -> testSessionName + ) + + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("environment variables take precedence over system properties") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + val sysProps = Map( + "aws.webIdentityTokenFile" -> "/wrong/path", + "aws.roleArn" -> "arn:aws:iam::999999999999:role/wrong-role" + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => assert(true) + case _ => fail("Expected AwsSessionCredentials") + + // Verify environment variables were used + val request = requests.head + assertEquals(request.webIdentityTokenFile.toString, testTokenFile) + assertEquals(request.roleArn, testRoleArn) + } + + test("resolveCredentials without role session name uses None") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => assert(true) + case _ => fail("Expected AwsSessionCredentials") + + // Verify no session name was provided + val request = requests.head + assertEquals(request.roleSessionName, None) + } + + test("fail when token file is not specified") { + val envVars = Map( + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) + assert(exception.getMessage.contains("AWS_WEB_IDENTITY_TOKEN_FILE")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when role ARN is not specified") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) + assert(exception.getMessage.contains("AWS_ROLE_ARN")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when both token file and role ARN are missing") { + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when WebIdentityCredentialsUtils throws an exception") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils(shouldFail = true, errorMessage = "STS service unavailable") + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assertEquals(exception.getMessage, "STS service unavailable") + case _ => fail("Expected SdkClientException") + } + } + + test("implements AwsCredentialsProvider trait") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + // Type check passes - provider implements the trait correctly + val providerTrait: ldbc.amazon.identity.AwsCredentialsProvider[IO] = provider + + assertIOBoolean(providerTrait.resolveCredentials().map(_.isInstanceOf[AwsCredentials])) + } + + test("isAvailable returns true when required configuration is present") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIOBoolean( + WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), + ) + } + + test("isAvailable returns true with system properties") { + val sysProps = Map( + "aws.webIdentityTokenFile" -> testTokenFile, + "aws.roleArn" -> testRoleArn + ) + + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties(sysProps) + + assertIOBoolean(WebIdentityTokenFileCredentialsProvider.isAvailable[IO]()) + } + + test("isAvailable returns false when token file is missing") { + val envVars = Map( + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("isAvailable returns false when role ARN is missing") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("isAvailable returns false when both are missing") { + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("isAvailable returns false when values are empty or whitespace") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> " ", + "AWS_ROLE_ARN" -> "" + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("handles whitespace trimming in configuration values") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> s" $testTokenFile ", + "AWS_ROLE_ARN" -> s" $testRoleArn ", + "AWS_ROLE_SESSION_NAME" -> s" $testSessionName " + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => assert(true) + case _ => fail("Expected AwsSessionCredentials") + + // Verify trimmed values were used + val request = requests.head + assertEquals(request.webIdentityTokenFile.toString, testTokenFile) + assertEquals(request.roleArn, testRoleArn) + assertEquals(request.roleSessionName, Some(testSessionName)) + } \ No newline at end of file From 2b84d98c621c79d94b6e0ac6dfbcd9a51f525a4a Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 18:05:08 +0900 Subject: [PATCH 150/215] Action sbt scalafmtAll --- ...tityTokenFileCredentialsProviderTest.scala | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala index ee58dec4a..31955ae26 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala @@ -14,8 +14,8 @@ import cats.effect.std.{ Env, SystemProperties } import munit.CatsEffectSuite import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils -import ldbc.amazon.exception.SdkClientException import ldbc.amazon.auth.credentials.AwsSessionCredentials +import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.AwsCredentials class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: @@ -49,29 +49,29 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: // Mock WebIdentityCredentialsUtils private def mockWebIdentityUtils( - response: Option[AwsCredentials] = None, - shouldFail: Boolean = false, - errorMessage: String = "Mock STS error", - captureRequestsTo: Option[Ref[IO, List[WebIdentityTokenCredentialProperties]]] = None + response: Option[AwsCredentials] = None, + shouldFail: Boolean = false, + errorMessage: String = "Mock STS error", + captureRequestsTo: Option[Ref[IO, List[WebIdentityTokenCredentialProperties]]] = None ): WebIdentityCredentialsUtils[IO] = - (config: WebIdentityTokenCredentialProperties) => for - _ <- captureRequestsTo match - case Some(ref) => ref.update(_ :+ config) - case None => IO.unit - yield - if shouldFail then throw new SdkClientException(errorMessage) - else - response.getOrElse( - AwsSessionCredentials( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, - sessionToken = testSessionToken, - validateCredentials = false, - providerName = None, - accountId = Some("123456789012"), - expirationTime = Some(futureExpiration) + (config: WebIdentityTokenCredentialProperties) => + for _ <- captureRequestsTo match + case Some(ref) => ref.update(_ :+ config) + case None => IO.unit + yield + if shouldFail then throw new SdkClientException(errorMessage) + else + response.getOrElse( + AwsSessionCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + validateCredentials = false, + providerName = None, + accountId = Some("123456789012"), + expirationTime = Some(futureExpiration) + ) ) - ) test("resolveCredentials with successful Web Identity Token authentication") { val envVars = Map( @@ -122,9 +122,9 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) provider.resolveCredentials().map { - case session: AwsSessionCredentials => - assertEquals(session.accessKeyId, testAccessKeyId) - case _ => fail("Expected AwsSessionCredentials") + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") } } @@ -133,7 +133,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, "AWS_ROLE_ARN" -> testRoleArn ) - + val sysProps = Map( "aws.webIdentityTokenFile" -> "/wrong/path", "aws.roleArn" -> "arn:aws:iam::999999999999:role/wrong-role" @@ -166,7 +166,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: "AWS_ROLE_ARN" -> testRoleArn ) - given Env[IO] = mockEnv(envVars) + given Env[IO] = mockEnv(envVars) given SystemProperties[IO] = mockSystemProperties() val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) @@ -191,7 +191,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: "AWS_ROLE_ARN" -> testRoleArn ) - given Env[IO] = mockEnv(envVars) + given Env[IO] = mockEnv(envVars) given SystemProperties[IO] = mockSystemProperties() val webIdentityUtils = mockWebIdentityUtils() @@ -210,7 +210,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile ) - given Env[IO] = mockEnv(envVars) + given Env[IO] = mockEnv(envVars) given SystemProperties[IO] = mockSystemProperties() val webIdentityUtils = mockWebIdentityUtils() @@ -225,16 +225,16 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: } test("fail when both token file and role ARN are missing") { - given Env[IO] = mockEnv(Map.empty) + given Env[IO] = mockEnv(Map.empty) given SystemProperties[IO] = mockSystemProperties() val webIdentityUtils = mockWebIdentityUtils() val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) provider.resolveCredentials().attempt.map { - case Left(exception: SdkClientException) => - assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) - case _ => fail("Expected SdkClientException") + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) + case _ => fail("Expected SdkClientException") } } @@ -244,7 +244,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: "AWS_ROLE_ARN" -> testRoleArn ) - given Env[IO] = mockEnv(envVars) + given Env[IO] = mockEnv(envVars) given SystemProperties[IO] = mockSystemProperties() val webIdentityUtils = mockWebIdentityUtils(shouldFail = true, errorMessage = "STS service unavailable") @@ -263,7 +263,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: "AWS_ROLE_ARN" -> testRoleArn ) - given Env[IO] = mockEnv(envVars) + given Env[IO] = mockEnv(envVars) given SystemProperties[IO] = mockSystemProperties() val webIdentityUtils = mockWebIdentityUtils() @@ -285,7 +285,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: given SystemProperties[IO] = mockSystemProperties() assertIOBoolean( - WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), + WebIdentityTokenFileCredentialsProvider.isAvailable[IO]() ) } @@ -349,7 +349,7 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: "AWS_ROLE_SESSION_NAME" -> s" $testSessionName " ) - given Env[IO] = mockEnv(envVars) + given Env[IO] = mockEnv(envVars) given SystemProperties[IO] = mockSystemProperties() val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) @@ -369,4 +369,4 @@ class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: assertEquals(request.webIdentityTokenFile.toString, testTokenFile) assertEquals(request.roleArn, testRoleArn) assertEquals(request.roleSessionName, Some(testSessionName)) - } \ No newline at end of file + } From 1ed759ac7be05b52794dfa8888eb214339c3be46 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 19:06:00 +0900 Subject: [PATCH 151/215] Create WebIdentityCredentialsUtils Test --- .../WebIdentityCredentialsUtilsTest.scala | 353 ++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala new file mode 100644 index 000000000..88ef3e5f1 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala @@ -0,0 +1,353 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials.internal + +import java.net.URI +import java.time.Instant + +import cats.effect.{ IO, Ref } +import cats.effect.std.UUIDGen + +import fs2.io.file.Files + +import munit.CatsEffectSuite + +import ldbc.amazon.auth.credentials.{ AwsSessionCredentials, WebIdentityTokenCredentialProperties } +import ldbc.amazon.client.{ HttpClient, HttpResponse, StsClient } +import ldbc.amazon.exception.{ InvalidTokenException, StsException } + +class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: + + // Test fixtures + private val testRoleArn = "arn:aws:iam::123456789012:role/test-role" + private val testSessionName = "test-session" + private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val futureExpiration = Instant.now().plusSeconds(3600) + private val testAssumedRoleArn = "arn:aws:sts::123456789012:assumed-role/test-role/test-session" + + // Valid JWT token (simplified for testing) + private val validJwtToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEyMyJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMS5hbWF6b25hd3MuY29tL2lkLzEyMyIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0Om15LWFwcCJ9.signature" + + // For these tests, we'll focus on mocking the StsClient behavior since Files[IO] is sealed + // We'll create temporary files for file-based tests when needed + + // Mock StsClient + private def mockStsClient( + response: Option[StsClient.AssumeRoleWithWebIdentityResponse] = None, + shouldFail: Boolean = false, + errorMessage: String = "Mock STS error", + captureRequestsTo: Option[Ref[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]]] = None + ): StsClient[IO] = + (request: StsClient.AssumeRoleWithWebIdentityRequest) => for + _ <- captureRequestsTo match + case Some(ref) => ref.update(_ :+ request) + case None => IO.unit + yield + if shouldFail then throw new StsException(errorMessage) + else + response.getOrElse( + StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + expiration = futureExpiration, + assumedRoleArn = testAssumedRoleArn + ) + ) + + test("assumeRoleWithWebIdentity with successful flow") { + // Create a temporary file with the JWT token for testing + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) + val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) + requests <- requestCapture.get + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, None) + assertEquals(session.accountId, Some("123456789012")) + assertEquals(session.expirationTime, Some(futureExpiration)) + case _ => fail("Expected AwsSessionCredentials") + + // Verify STS request was made correctly + assertEquals(requests.length, 1) + val request = requests.head + assertEquals(request.roleArn, testRoleArn) + assertEquals(request.webIdentityToken, validJwtToken) + assertEquals(request.roleSessionName, Some(testSessionName)) + } + } + + // Focus on tests that can be validated without complex file mocking + // JWT validation tests use temporary files for realistic testing + + test("fail when JWT token has invalid format - too few parts") { + val invalidToken = "header.payload" // Missing signature part + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient() + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("Invalid JWT token format")) + assert(exception.getMessage.contains("Expected 3 parts, got 2")) + case _ => fail("Expected InvalidTokenException") + } + } + + test("fail when JWT token has invalid format - too many parts") { + val invalidToken = "header.payload.signature.extra" // Too many parts + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient() + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("Invalid JWT token format")) + assert(exception.getMessage.contains("Expected 3 parts, got 4")) + case _ => fail("Expected InvalidTokenException") + } + } + + test("fail when JWT token has empty parts") { + val invalidToken = "header..signature" // Empty payload + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient() + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("JWT token contains empty parts")) + case _ => fail("Expected InvalidTokenException") + } + } + + test("fail when STS client throws exception") { + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient(shouldFail = true, errorMessage = "STS service unavailable") + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + result match + case Left(exception: StsException) => + assertEquals(exception.getMessage, "STS service unavailable") + case _ => fail("Expected StsException") + } + } + + test("extract account ID from ARN correctly") { + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val customAssumedRoleArn = "arn:aws:sts::999888777666:assumed-role/my-role/my-session" + val customResponse = StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + expiration = futureExpiration, + assumedRoleArn = customAssumedRoleArn + ) + + val stsClient = mockStsClient(response = Some(customResponse)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accountId, Some("999888777666")) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("handle malformed ARN gracefully") { + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val malformedArn = "invalid:arn:format" + val customResponse = StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + expiration = futureExpiration, + assumedRoleArn = malformedArn + ) + + val stsClient = mockStsClient(response = Some(customResponse)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accountId, None) // Should be None for malformed ARN + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("default factory creates WebIdentityCredentialsUtils with proper STS client") { + val mockHttpClient: HttpClient[IO] = new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("HTTP requests not supported in test")) + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("HTTP requests not supported in test")) + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("HTTP requests not supported in test")) + + val webIdentityUtils = WebIdentityCredentialsUtils.default[IO]("us-east-1", mockHttpClient) + assert(webIdentityUtils != null) + } + + test("reads token from file with correct trimming") { + val tokenWithWhitespace = s" $validJwtToken \n\t" + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(tokenWithWhitespace)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) + val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) + requests <- requestCapture.get + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify the token was trimmed correctly + assertEquals(requests.length, 1) + val request = requests.head + assertEquals(request.webIdentityToken, validJwtToken) // Should be trimmed + } + } From 71643d384d66ee47f6be1017e178d74c9fa9a81c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 19:24:46 +0900 Subject: [PATCH 152/215] Action sbt scalafmtAll --- .../WebIdentityCredentialsUtilsTest.scala | 191 +++++++++--------- 1 file changed, 93 insertions(+), 98 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala index 88ef3e5f1..f62c02595 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala @@ -32,41 +32,42 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: private val testAssumedRoleArn = "arn:aws:sts::123456789012:assumed-role/test-role/test-session" // Valid JWT token (simplified for testing) - private val validJwtToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEyMyJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMS5hbWF6b25hd3MuY29tL2lkLzEyMyIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0Om15LWFwcCJ9.signature" + private val validJwtToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEyMyJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMS5hbWF6b25hd3MuY29tL2lkLzEyMyIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0Om15LWFwcCJ9.signature" // For these tests, we'll focus on mocking the StsClient behavior since Files[IO] is sealed // We'll create temporary files for file-based tests when needed // Mock StsClient private def mockStsClient( - response: Option[StsClient.AssumeRoleWithWebIdentityResponse] = None, - shouldFail: Boolean = false, - errorMessage: String = "Mock STS error", - captureRequestsTo: Option[Ref[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]]] = None + response: Option[StsClient.AssumeRoleWithWebIdentityResponse] = None, + shouldFail: Boolean = false, + errorMessage: String = "Mock STS error", + captureRequestsTo: Option[Ref[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]]] = None ): StsClient[IO] = - (request: StsClient.AssumeRoleWithWebIdentityRequest) => for - _ <- captureRequestsTo match - case Some(ref) => ref.update(_ :+ request) - case None => IO.unit - yield - if shouldFail then throw new StsException(errorMessage) - else - response.getOrElse( - StsClient.AssumeRoleWithWebIdentityResponse( - accessKeyId = testAccessKeyId, - secretAccessKey = testSecretAccessKey, - sessionToken = testSessionToken, - expiration = futureExpiration, - assumedRoleArn = testAssumedRoleArn + (request: StsClient.AssumeRoleWithWebIdentityRequest) => + for _ <- captureRequestsTo match + case Some(ref) => ref.update(_ :+ request) + case None => IO.unit + yield + if shouldFail then throw new StsException(errorMessage) + else + response.getOrElse( + StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + expiration = futureExpiration, + assumedRoleArn = testAssumedRoleArn + ) ) - ) test("assumeRoleWithWebIdentity with successful flow") { // Create a temporary file with the JWT token for testing val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -76,14 +77,14 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: roleSessionName = Some(testSessionName) ) - val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) - val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) + val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) + val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) for credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) requests <- requestCapture.get - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup yield credentials match case session: AwsSessionCredentials => @@ -110,11 +111,11 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: test("fail when JWT token has invalid format - too few parts") { val invalidToken = "header.payload" // Missing signature part - + val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -124,28 +125,27 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: roleSessionName = Some(testSessionName) ) - val stsClient = mockStsClient() + val stsClient = mockStsClient() val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) - for + for result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup - yield - result match - case Left(exception: InvalidTokenException) => - assert(exception.getMessage.contains("Invalid JWT token format")) - assert(exception.getMessage.contains("Expected 3 parts, got 2")) - case _ => fail("Expected InvalidTokenException") + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("Invalid JWT token format")) + assert(exception.getMessage.contains("Expected 3 parts, got 2")) + case _ => fail("Expected InvalidTokenException") } } test("fail when JWT token has invalid format - too many parts") { val invalidToken = "header.payload.signature.extra" // Too many parts - + val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -155,28 +155,27 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: roleSessionName = Some(testSessionName) ) - val stsClient = mockStsClient() + val stsClient = mockStsClient() val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) - for + for result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup - yield - result match - case Left(exception: InvalidTokenException) => - assert(exception.getMessage.contains("Invalid JWT token format")) - assert(exception.getMessage.contains("Expected 3 parts, got 4")) - case _ => fail("Expected InvalidTokenException") + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("Invalid JWT token format")) + assert(exception.getMessage.contains("Expected 3 parts, got 4")) + case _ => fail("Expected InvalidTokenException") } } test("fail when JWT token has empty parts") { val invalidToken = "header..signature" // Empty payload - + val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -186,25 +185,24 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: roleSessionName = Some(testSessionName) ) - val stsClient = mockStsClient() + val stsClient = mockStsClient() val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) - for + for result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup - yield - result match - case Left(exception: InvalidTokenException) => - assert(exception.getMessage.contains("JWT token contains empty parts")) - case _ => fail("Expected InvalidTokenException") + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("JWT token contains empty parts")) + case _ => fail("Expected InvalidTokenException") } } test("fail when STS client throws exception") { val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -214,25 +212,24 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: roleSessionName = Some(testSessionName) ) - val stsClient = mockStsClient(shouldFail = true, errorMessage = "STS service unavailable") + val stsClient = mockStsClient(shouldFail = true, errorMessage = "STS service unavailable") val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) - for + for result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup - yield - result match - case Left(exception: StsException) => - assertEquals(exception.getMessage, "STS service unavailable") - case _ => fail("Expected StsException") + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: StsException) => + assertEquals(exception.getMessage, "STS service unavailable") + case _ => fail("Expected StsException") } } test("extract account ID from ARN correctly") { val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -243,7 +240,7 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: ) val customAssumedRoleArn = "arn:aws:sts::999888777666:assumed-role/my-role/my-session" - val customResponse = StsClient.AssumeRoleWithWebIdentityResponse( + val customResponse = StsClient.AssumeRoleWithWebIdentityResponse( accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, sessionToken = testSessionToken, @@ -251,25 +248,24 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: assumedRoleArn = customAssumedRoleArn ) - val stsClient = mockStsClient(response = Some(customResponse)) + val stsClient = mockStsClient(response = Some(customResponse)) val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) - for + for credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup - yield - credentials match - case session: AwsSessionCredentials => - assertEquals(session.accountId, Some("999888777666")) - case _ => fail("Expected AwsSessionCredentials") + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accountId, Some("999888777666")) + case _ => fail("Expected AwsSessionCredentials") } } test("handle malformed ARN gracefully") { val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -279,7 +275,7 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: roleSessionName = Some(testSessionName) ) - val malformedArn = "invalid:arn:format" + val malformedArn = "invalid:arn:format" val customResponse = StsClient.AssumeRoleWithWebIdentityResponse( accessKeyId = testAccessKeyId, secretAccessKey = testSecretAccessKey, @@ -288,17 +284,16 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: assumedRoleArn = malformedArn ) - val stsClient = mockStsClient(response = Some(customResponse)) + val stsClient = mockStsClient(response = Some(customResponse)) val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) - for + for credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup - yield - credentials match - case session: AwsSessionCredentials => - assertEquals(session.accountId, None) // Should be None for malformed ARN - case _ => fail("Expected AwsSessionCredentials") + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accountId, None) // Should be None for malformed ARN + case _ => fail("Expected AwsSessionCredentials") } } @@ -317,11 +312,11 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: test("reads token from file with correct trimming") { val tokenWithWhitespace = s" $validJwtToken \n\t" - + val tempFile = for - tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) - _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(tokenWithWhitespace)).compile.drain + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(tokenWithWhitespace)).compile.drain yield tokenFile tempFile.flatMap { tokenFilePath => @@ -331,14 +326,14 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: roleSessionName = Some(testSessionName) ) - val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) - val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) + val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) + val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) for credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) requests <- requestCapture.get - _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup yield credentials match case session: AwsSessionCredentials => From 3b056818c1e13989f8f143c41eea2b0e84b63141 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 20:18:43 +0900 Subject: [PATCH 153/215] Added scaladoc --- .../credentials/AwsBasicCredentials.scala | 14 ++ .../credentials/AwsSessionCredentials.scala | 16 ++ .../ContainerCredentialsProvider.scala | 203 ++++++++++++++++ .../DefaultCredentialsProviderChain.scala | 81 ++++++ ...vironmentVariableCredentialsProvider.scala | 31 ++- .../InstanceProfileCredentialsProvider.scala | 95 ++++++++ .../ProfileCredentialsProvider.scala | 59 +++++ .../SystemPropertyCredentialsProvider.scala | 31 ++- ...IdentityTokenFileCredentialsProvider.scala | 118 +++++++++ .../SystemSettingsCredentialsProvider.scala | 113 +++++++++ .../WebIdentityCredentialsUtils.scala | 12 + .../scala/ldbc/amazon/client/HttpClient.scala | 102 ++++++++ .../ldbc/amazon/client/HttpResponse.scala | 46 ++++ .../scala/ldbc/amazon/client/StsClient.scala | 41 ++++ .../exception/CredentialsFetchError.scala | 51 ++++ .../exception/InvalidTokenException.scala | 36 ++- .../amazon/exception/SdkClientException.scala | 16 ++ .../ldbc/amazon/exception/StsException.scala | 17 +- .../exception/TokenFileAccessException.scala | 66 ++++- .../TokenFileNotFoundException.scala | 63 ++++- .../DefaultAwsCredentialsIdentity.scala | 44 ++++ .../useragent/BusinessMetricFeatureId.scala | 230 ++++++++++++++---- .../ldbc/amazon/util/SimpleJsonParser.scala | 38 ++- .../ldbc/amazon/util/SimpleXmlParser.scala | 57 +++++ .../ldbc/amazon/util/SystemSetting.scala | 39 +++ 25 files changed, 1526 insertions(+), 93 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala index 8e0566e37..ecba20f0c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala @@ -10,6 +10,20 @@ import java.time.Instant import ldbc.amazon.identity.AwsCredentials +/** + * Basic AWS credentials implementation that contains an access key ID and secret access key. + * + * This implementation is used for standard AWS access credentials that consist of a public + * access key ID and a private secret access key. Basic credentials do not include session + * tokens and are typically used for long-term access. + * + * @param accessKeyId The AWS access key ID used for authentication + * @param secretAccessKey The AWS secret access key used for signing requests + * @param validateCredentials Whether these credentials should be validated before use + * @param providerName Optional name of the credentials provider that created these credentials + * @param accountId Optional AWS account ID associated with these credentials + * @param expirationTime Optional expiration time for these credentials, if applicable + */ final case class AwsBasicCredentials( accessKeyId: String, secretAccessKey: String, diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala index 5e9e587cd..679ef3dee 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala @@ -10,6 +10,22 @@ import java.time.Instant import ldbc.amazon.identity.AwsCredentials +/** + * AWS session credentials implementation that includes a session token. + * + * Session credentials are temporary credentials that include an access key ID, + * secret access key, and a session token. These are typically obtained from + * AWS Security Token Service (STS) and have a limited lifetime. They are commonly + * used for temporary access, role-based access, or federated access scenarios. + * + * @param accessKeyId The AWS access key ID used for authentication + * @param secretAccessKey The AWS secret access key used for signing requests + * @param sessionToken The session token that must be included with requests + * @param validateCredentials Whether these credentials should be validated before use + * @param providerName Optional name of the credentials provider that created these credentials + * @param accountId Optional AWS account ID associated with these credentials + * @param expirationTime Optional expiration time for these credentials + */ final case class AwsSessionCredentials( accessKeyId: String, secretAccessKey: String, diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala index 8ba167183..22eb519c5 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -52,6 +52,21 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( httpClient: HttpClient[F] ) extends AwsCredentialsProvider[F]: + /** + * Resolves AWS credentials from the container credential provider endpoint. + * + * This method implements the core logic for retrieving temporary AWS credentials + * from container-based credential providers such as ECS Task Roles or EKS Pod Identity. + * + * The resolution process: + * 1. Load container credentials configuration from environment variables + * 2. Validate that required environment variables are present + * 3. Make HTTP request to the credential endpoint + * 4. Parse and return the temporary credentials + * + * @return F[AwsCredentials] The resolved AWS credentials with access key, secret key, and session token + * @throws SdkClientException if container credentials configuration is missing or credential retrieval fails + */ override def resolveCredentials(): F[AwsCredentials] = for config <- loadContainerCredentialsConfig() @@ -69,6 +84,25 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( } yield credentials + /** + * Loads container credentials configuration from environment variables. + * + * This method reads configuration from the standard AWS environment variables used + * by container credential providers and constructs the appropriate endpoint configuration. + * + * Environment variables checked (in priority order): + * - AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: Relative URI path (used with ECS metadata endpoint) + * - AWS_CONTAINER_CREDENTIALS_FULL_URI: Complete URI (used with custom endpoints like EKS Pod Identity) + * - AWS_CONTAINER_AUTHORIZATION_TOKEN: Direct authorization token for requests + * - AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE: Path to file containing authorization token + * + * When AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is present, the full endpoint URI is constructed + * as: http://169.254.170.2 + relative_uri (ECS metadata endpoint) + * + * When AWS_CONTAINER_CREDENTIALS_FULL_URI is present, it is used as-is (typically for EKS Pod Identity) + * + * @return F[Option[ContainerCredentialsConfig]] Container configuration if environment variables are properly set, None otherwise + */ private def loadContainerCredentialsConfig(): F[Option[ContainerCredentialsConfig]] = for relativeUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") @@ -94,6 +128,23 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( case _ => None } + /** + * Loads the authorization token from environment variables or file. + * + * This method resolves the authorization token needed for container credential requests + * by checking both direct token environment variable and token file path. + * + * Token resolution priority: + * 1. AWS_CONTAINER_AUTHORIZATION_TOKEN (direct token value) + * 2. AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE (path to file containing token) + * + * The authorization token is typically used in EKS environments where the token + * is provided by the Kubernetes service account system. + * + * @param directToken Optional token value from AWS_CONTAINER_AUTHORIZATION_TOKEN + * @param tokenFilePath Optional file path from AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE + * @return F[Option[String]] The authorization token if available, None otherwise + */ private def loadAuthorizationToken( directToken: Option[String], tokenFilePath: Option[String] @@ -107,6 +158,26 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( Concurrent[F].pure(None) } + /** + * Loads authorization token from a file path. + * + * This method safely reads the authorization token from the specified file path, + * handling potential file system errors gracefully. This is commonly used in + * Kubernetes environments where tokens are mounted as files in the pod. + * + * The method performs the following operations: + * 1. Check if the file exists + * 2. Read the file content as UTF-8 text + * 3. Trim whitespace and validate the token is non-empty + * 4. Handle any file I/O errors by returning None + * + * Common token file locations: + * - /var/run/secrets/eks.amazonaws.com/serviceaccount/token (EKS IRSA) + * - /var/run/secrets/kubernetes.io/serviceaccount/token (Standard Kubernetes SA token) + * + * @param tokenFilePath Path to the token file + * @return F[Option[String]] The token content if file exists and is readable, None otherwise + */ private def loadTokenFromFile(tokenFilePath: Path): F[Option[String]] = Files[F].exists(tokenFilePath).flatMap { exists => if exists then { @@ -124,6 +195,30 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( } } + /** + * Fetches AWS credentials from the configured container endpoint. + * + * This method performs the actual HTTP request to retrieve temporary AWS credentials + * from the container credential provider endpoint. The request includes appropriate + * headers and authorization tokens when required. + * + * Request flow: + * 1. Build HTTP headers with authorization token (if present) + * 2. Make HTTP GET request to the credential endpoint + * 3. Validate the HTTP response status + * 4. Parse the JSON response to extract credentials + * + * The endpoint returns a JSON response containing: + * - AccessKeyId: Temporary access key + * - SecretAccessKey: Temporary secret key + * - Token: Session token + * - Expiration: RFC3339 formatted expiration timestamp + * - RoleArn: ARN of the assumed role (optional) + * + * @param config Container credentials configuration with endpoint URI and auth token + * @return F[AwsCredentials] The temporary AWS credentials from the endpoint + * @throws SdkClientException if the HTTP request fails or response parsing fails + */ private def fetchCredentialsFromEndpoint(config: ContainerCredentialsConfig): F[AwsCredentials] = val headers = buildRequestHeaders(config.authorizationToken) for @@ -132,6 +227,23 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( credentials <- parseCredentialsResponse(response.body) yield credentials + /** + * Builds HTTP request headers for container credential requests. + * + * This method constructs the necessary HTTP headers for credential endpoint requests, + * including content type specification and authorization token when available. + * + * Standard headers included: + * - Accept: application/json (indicates we expect JSON response) + * - User-Agent: aws-sdk-scala/ldbc (identifies the client SDK) + * - Authorization: token value (when authorization token is available) + * + * The Authorization header is only included when an authorization token is provided, + * which is typically required for EKS Pod Identity but not for ECS Task Roles. + * + * @param authToken Optional authorization token for the request + * @return Map[String, String] HTTP headers for the credential request + */ private def buildRequestHeaders(authToken: Option[String]): Map[String, String] = val baseHeaders = Map( "Accept" -> "application/json", @@ -142,6 +254,25 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( case None => baseHeaders } + /** + * Validates the HTTP response from the container credential endpoint. + * + * This method checks that the credential endpoint returned a successful HTTP status code + * (2xx range) and raises an exception for any error responses. + * + * Common error scenarios: + * - 401 Unauthorized: Invalid or missing authorization token + * - 403 Forbidden: Insufficient permissions for the requested role + * - 404 Not Found: Invalid endpoint URI or credential path + * - 500 Internal Server Error: Credential service internal error + * + * The error message includes both the HTTP status code and response body + * to provide detailed information for debugging credential issues. + * + * @param response HTTP response from the credential endpoint + * @return F[Unit] Success if status code is 2xx, failure otherwise + * @throws SdkClientException if the HTTP status code indicates an error + */ private def validateHttpResponse(response: ldbc.amazon.client.HttpResponse): F[Unit] = if response.statusCode >= 200 && response.statusCode < 300 then { Concurrent[F].unit @@ -153,6 +284,35 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( ) } + /** + * Parses the JSON credentials response from the container endpoint. + * + * This method deserializes the JSON response containing temporary AWS credentials + * and converts it into an AwsSessionCredentials instance. The response follows + * the standard AWS credential response format. + * + * Expected JSON structure: + * ```json + * { + * "AccessKeyId": "ASIA...", + * "SecretAccessKey": "...", + * "Token": "...", + * "Expiration": "2024-01-15T12:00:00Z", + * "RoleArn": "arn:aws:iam::123456789012:role/MyRole" // optional + * } + * ``` + * + * The method performs the following operations: + * 1. Parse the JSON response using SimpleJsonParser + * 2. Extract required fields (AccessKeyId, SecretAccessKey, Token, Expiration) + * 3. Parse the expiration timestamp in RFC3339 format + * 4. Extract the AWS account ID from the role ARN (if present) + * 5. Create AwsSessionCredentials with provider metadata + * + * @param jsonBody JSON response body from the credential endpoint + * @return F[AwsCredentials] Parsed AWS session credentials + * @throws SdkClientException if JSON parsing fails or required fields are missing + */ private def parseCredentialsResponse(jsonBody: String): F[AwsCredentials] = Concurrent[F] .fromEither( @@ -178,6 +338,27 @@ final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( new SdkClientException(s"Failed to parse container credentials response: ${ ex.getMessage }") } + /** + * Extracts the AWS account ID from an IAM role ARN. + * + * This method parses the role ARN returned by container credential endpoints + * to extract the AWS account ID, which is useful for credential metadata. + * + * ARN format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME + * Example: arn:aws:iam::123456789012:role/EKS-Pod-Identity-Role + * + * The method splits the ARN by colons and extracts the account ID from the 5th position + * (0-indexed position 4). If the ARN format is invalid or the account ID position + * is not available, None is returned. + * + * This information is particularly useful for: + * - Auditing and logging credential usage + * - Validating that credentials are from the expected AWS account + * - Supporting multi-account AWS environments + * + * @param roleArn Optional IAM role ARN from the credential response + * @return Option[String] The AWS account ID if extractable from the ARN, None otherwise + */ private def extractAccountIdFromRoleArn(roleArn: Option[String]): Option[String] = roleArn.flatMap { arn => // ARN format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME @@ -217,7 +398,29 @@ private case class ContainerCredentialsResponse( RoleArn: Option[String] = None ) +/** + * Companion object for parsing JSON responses from container credential endpoints. + */ private object ContainerCredentialsResponse: + + /** + * Parses a JSON object into a ContainerCredentialsResponse. + * + * This method deserializes the standard AWS container credentials JSON response format, + * extracting all required fields and optional metadata. + * + * Required JSON fields: + * - AccessKeyId: The temporary AWS access key ID + * - SecretAccessKey: The temporary AWS secret access key + * - Token: The session token for temporary credentials + * - Expiration: RFC3339 formatted timestamp indicating when credentials expire + * + * Optional JSON fields: + * - RoleArn: The ARN of the IAM role that was assumed to generate these credentials + * + * @param json The parsed JSON object from the credential endpoint response + * @return Either[String, ContainerCredentialsResponse] Success with parsed response or error message + */ def fromJson(json: SimpleJsonParser.JsonObject): Either[String, ContainerCredentialsResponse] = for accessKeyId <- json.require("AccessKeyId") diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala index ac36a9d0a..7800da9bc 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -49,6 +49,32 @@ class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: UUIDGe region: String ) extends AwsCredentialsProvider[F]: + /** + * Lazily initialized list of AWS credential providers in order of precedence. + * + * This method creates the standard AWS SDK credential provider chain that matches + * the behavior of AWS SDK for Java v2. The providers are ordered by precedence, + * with more specific/explicit credential sources taking priority over implicit ones. + * + * Provider chain order: + * 1. SystemPropertyCredentialsProvider - Java system properties (aws.accessKeyId, aws.secretAccessKey) + * 2. EnvironmentVariableCredentialsProvider - Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) + * 3. WebIdentityTokenFileCredentialsProvider - OIDC token authentication (Kubernetes IRSA) + * 4. ProfileCredentialsProvider - AWS credentials file (~/.aws/credentials) + * 5. ContainerCredentialsProvider - ECS task roles and EKS pod identity + * 6. InstanceProfileCredentialsProvider - EC2 instance profile (IMDS) + * + * Each provider is only consulted if the previous providers fail to provide credentials. + * This ordering ensures that explicit credentials (environment variables, system properties) + * take precedence over implicit credentials (instance profiles, container roles). + * + * Performance considerations: + * - ProfileCredentialsProvider and InstanceProfileCredentialsProvider require initialization + * - The list is lazily computed to avoid unnecessary initialization overhead + * - Failed providers are cached to avoid repeated initialization attempts + * + * @return F[List[AwsCredentialsProvider[F]]] The ordered list of credential providers + */ private lazy val providers: F[List[AwsCredentialsProvider[F]]] = for profileProvider <- ProfileCredentialsProvider.default[F]() @@ -62,12 +88,67 @@ class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: UUIDGe instanceProfileCredentialsProvider ) + /** + * Resolves AWS credentials by trying providers in the default chain order. + * + * This method implements the AWS SDK v2 default credential provider chain behavior, + * attempting to resolve credentials from each provider in sequence until one succeeds + * or all providers are exhausted. + * + * Resolution process: + * 1. Initialize the provider list (lazy evaluation) + * 2. Try each provider in order, starting with highest precedence + * 3. Return credentials from the first successful provider + * 4. If all providers fail, throw an exception with aggregated error messages + * + * The first successful provider "wins" and the chain stops there. This behavior + * ensures predictable credential resolution and prevents unexpected credential + * source changes during application runtime. + * + * Common resolution scenarios: + * - Development: Environment variables or system properties + * - CI/CD: Web Identity tokens or environment variables + * - ECS: Container credentials provider + * - EKS: Web Identity tokens (IRSA) or container credentials + * - EC2: Instance profile credentials + * - Local AWS CLI: Profile credentials from ~/.aws/credentials + * + * @return F[AwsCredentials] The resolved AWS credentials from the first successful provider + * @throws SdkClientException if no provider in the chain can provide valid credentials + */ override def resolveCredentials(): F[AwsCredentials] = for providerList <- providers credentials <- tryProvidersInOrder(providerList, Nil) yield credentials + /** + * Attempts to resolve credentials from providers in order, handling failures gracefully. + * + * This method implements the recursive credential resolution logic for the provider chain. + * It tries each provider sequentially and accumulates error messages for debugging purposes. + * + * Error handling strategy: + * - Each provider failure is caught and logged for debugging + * - Failures are expected and normal behavior (e.g., no ~/.aws/credentials file) + * - Only when all providers fail is an exception raised + * - Error messages from all providers are included in the final exception + * + * The method maintains a list of error messages to provide comprehensive debugging + * information when credential resolution completely fails. This helps developers + * understand why credential resolution failed across all providers. + * + * Example error scenarios: + * - SystemPropertyCredentialsProvider: "aws.accessKeyId system property not set" + * - EnvironmentVariableCredentialsProvider: "AWS_ACCESS_KEY_ID environment variable not set" + * - ProfileCredentialsProvider: "Unable to load credentials from profile 'default'" + * - InstanceProfileCredentialsProvider: "Unable to retrieve credentials from IMDS" + * + * @param providers Remaining providers to try in the chain + * @param exceptions Accumulated error messages from failed providers + * @return F[AwsCredentials] The credentials from the first successful provider + * @throws SdkClientException if all providers in the list fail to provide credentials + */ private def tryProvidersInOrder( providers: List[AwsCredentialsProvider[F]], exceptions: List[String] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala index d05ca9f0a..bf0e282e4 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala @@ -15,13 +15,36 @@ import ldbc.amazon.useragent.BusinessMetricFeatureId import ldbc.amazon.util.SdkSystemSetting /** - * [[AwsCredentialsProvider]] implementation that loads credentials from the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and - * AWS_SESSION_TOKEN environment variables. + * AWS credentials provider that loads credentials from environment variables. + * + * This provider looks for AWS credentials in the following environment variables: + * - AWS_ACCESS_KEY_ID: The AWS access key ID + * - AWS_SECRET_ACCESS_KEY: The AWS secret access key + * - AWS_SESSION_TOKEN: Optional session token for temporary credentials + * + * This provider only checks environment variables and does not fallback to system properties. + * Customers can use this provider when they want to explicitly source credentials from + * environment variables only. + * + * @tparam F The effect type that supports environment variable access and error handling */ final class EnvironmentVariableCredentialsProvider[F[_]: Env: MonadThrow] extends SystemSettingsCredentialsProvider[F]: - // Customers should be able to specify a credentials provider that only looks at the environment variables, - // but not the system properties. For that reason, we're only checking the environment variables here. + /** + * Loads a setting value from environment variables only. + * + * This implementation specifically looks at environment variables and does not check + * system properties, allowing customers to specify a credentials provider that only + * uses environment variables. + * + * @param setting The system setting to load + * @return An effect containing the optional setting value from environment variables + */ override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = Env[F].get(setting.toString) + /** + * Returns the provider identifier for business metrics tracking. + * + * @return The business metric feature ID for environment variable credentials + */ override def provider: String = BusinessMetricFeatureId.CREDENTIALS_ENV_VARS.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala index 9c362e867..45278e7f8 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -60,6 +60,17 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( private val METADATA_TOKEN_TTL_SECONDS = 21600 // 6 hours private val CREDENTIAL_REFRESH_BUFFER = 4.minutes + /** + * Resolves AWS credentials from EC2 instance metadata service. + * + * This method performs the following steps: + * 1. Checks if the metadata service is disabled via environment variable + * 2. Retrieves cached credentials if they are still valid + * 3. Refreshes credentials from IMDS if needed + * + * @return AWS credentials from the instance profile + * @throws SdkClientException if metadata service is disabled or unreachable + */ override def resolveCredentials(): F[AwsCredentials] = for disabled <- checkIfDisabled() @@ -75,12 +86,32 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( } yield credentials + /** + * Checks if the EC2 metadata service is disabled via environment variable. + * + * The AWS_EC2_METADATA_DISABLED environment variable can be set to "true" + * to disable this credentials provider. + * + * @return true if the metadata service is disabled, false otherwise + */ private def checkIfDisabled(): F[Boolean] = Env[F].get("AWS_EC2_METADATA_DISABLED").map { case Some(value) => value.toLowerCase == "true" case None => false } + /** + * Refreshes credentials from the EC2 instance metadata service. + * + * This method performs the complete IMDS authentication flow: + * 1. Gets the IMDS endpoint (default or custom) + * 2. Attempts to acquire an IMDSv2 token (falls back to IMDSv1 if it fails) + * 3. Retrieves the available IAM role name + * 4. Fetches credentials for the role + * 5. Caches the credentials with timestamp + * + * @return Fresh AWS credentials from IMDS + */ private def refreshCredentials(): F[AwsCredentials] = for endpoint <- getImdsEndpoint() @@ -91,12 +122,31 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( _ <- credentialsRef.set(Some(cached)) yield credentials + /** + * Gets the IMDS endpoint URL from environment variable or default. + * + * The AWS_EC2_METADATA_SERVICE_ENDPOINT environment variable can be used + * to override the default IMDS endpoint (useful for testing or proxies). + * + * @return The IMDS endpoint URL without trailing slash + */ private def getImdsEndpoint(): F[String] = Env[F].get("AWS_EC2_METADATA_SERVICE_ENDPOINT").map { case Some(endpoint) => endpoint.stripSuffix("/") case None => DEFAULT_IMD_SEND_POINT } + /** + * Acquires an IMDSv2 session token for enhanced security. + * + * IMDSv2 requires a session token to access metadata, which helps prevent + * Server-Side Request Forgery (SSRF) attacks. The token is obtained via + * a PUT request with a TTL header. + * + * @param endpoint The IMDS endpoint URL + * @return A session token valid for the configured TTL period + * @throws SdkClientException if token acquisition fails + */ private def acquireMetadataToken(endpoint: String): F[String] = val tokenUrl = s"$endpoint/latest/api/token" val headers = Map( @@ -108,6 +158,18 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( _ <- validateHttpResponse(response, "Failed to acquire metadata token") yield response.body.trim + /** + * Retrieves the IAM role name associated with the instance. + * + * This method lists all available IAM roles from the metadata service + * and selects the first one. In practice, EC2 instances typically have + * only one attached IAM role. + * + * @param endpoint The IMDS endpoint URL + * @param token Optional IMDSv2 session token + * @return The name of the IAM role attached to the instance + * @throws SdkClientException if no roles are available or request fails + */ private def getRoleName(endpoint: String, token: Option[String]): F[String] = val roleUrl = s"$endpoint/latest/meta-data/iam/security-credentials/" val headers = buildRequestHeaders(token) @@ -118,6 +180,19 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( roleName <- parseRoleListResponse(response.body) yield roleName + /** + * Retrieves AWS credentials for the specified IAM role. + * + * This method fetches the temporary credentials (access key, secret key, + * and session token) associated with the IAM role. The credentials are + * returned in JSON format by the metadata service. + * + * @param endpoint The IMDS endpoint URL + * @param token Optional IMDSv2 session token + * @param roleName The name of the IAM role + * @return AWS credentials for the specified role + * @throws SdkClientException if credentials cannot be retrieved or parsed + */ private def getCredentialsForRole( endpoint: String, token: Option[String], @@ -132,6 +207,15 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( credentials <- parseCredentialsResponse(response.body, roleName) yield credentials + /** + * Builds HTTP request headers for IMDS requests. + * + * Creates appropriate headers for both IMDSv1 and IMDSv2 requests. + * When an IMDSv2 token is available, it's included in the headers. + * + * @param token Optional IMDSv2 session token + * @return Map of HTTP headers for the request + */ private def buildRequestHeaders(token: Option[String]): Map[String, String] = val baseHeaders = Map( "Accept" -> "application/json", @@ -142,6 +226,17 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( case None => baseHeaders } + /** + * Validates HTTP response status codes and provides appropriate error messages. + * + * This method interprets common HTTP status codes in the context of IMDS operations + * and provides meaningful error messages for debugging authentication issues. + * + * @param response The HTTP response to validate + * @param context A descriptive context for error messages + * @return Unit if the response is successful + * @throws SdkClientException with context-specific error message for failed responses + */ private def validateHttpResponse(response: HttpResponse, context: String): F[Unit] = response.statusCode match { case code if code >= 200 && code < 300 => Concurrent[F].unit diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala index 9a0fb1cc9..cfae8cbe1 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala @@ -21,6 +21,33 @@ import ldbc.amazon.identity.* import ProfileCredentialsProvider.* +/** + * AWS credentials provider that loads credentials from the AWS credentials file. + * + * This provider reads credentials from the standard AWS credentials file located at + * `~/.aws/credentials` and supports profile-based credential management. The credentials + * file uses INI format with profile sections containing access keys and other configuration. + * + * The provider implements caching and file change detection to avoid unnecessary file I/O + * and provides thread-safe access to credentials through semaphore-based synchronization. + * + * File format example: + * ``` + * [default] + * aws_access_key_id = AKIAIOSFODNN7EXAMPLE + * aws_secret_access_key = wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + * + * [production] + * aws_access_key_id = AKIAIOSFODNN7EXAMPLE + * aws_secret_access_key = wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + * aws_session_token = IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMQ== + * ``` + * + * @param profileName The name of the profile to load credentials from + * @param cacheRef Mutable reference for caching parsed credentials and file metadata + * @param semaphore Semaphore for controlling concurrent access to file operations + * @tparam F The effect type that supports file operations, system properties, and concurrency + */ final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent]( profileName: String, cacheRef: Ref[F, Option[(ProfileFile, AwsCredentials)]], @@ -28,6 +55,17 @@ final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent )(using ev: MonadThrow[F]) extends AwsCredentialsProvider[F]: + /** + * Resolves AWS credentials from the specified profile in the credentials file. + * + * This method implements intelligent caching by checking if the credentials file + * has been modified since the last read. If the file is unchanged, cached credentials + * are returned. Otherwise, the file is re-parsed and credentials are updated. + * + * @return AWS credentials from the specified profile + * @throws SdkClientException if the credentials file is not found, the profile + * doesn't exist, or required fields are missing + */ override def resolveCredentials(): F[AwsCredentials] = for currentFile <- loadFile @@ -38,6 +76,17 @@ final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent case _ => updateCredentials(currentFile) yield credentials + /** + * Loads and parses the AWS credentials file from the user's home directory. + * + * This method locates the credentials file at `~/.aws/credentials`, reads its contents, + * and parses the INI-format configuration into profile structures. File metadata + * including last modified time is tracked for cache invalidation. + * + * @return Parsed credentials file with profiles and metadata + * @throws SdkClientException if the home directory cannot be determined, the file + * doesn't exist, or the file format is invalid + */ private def loadFile: F[ProfileFile] = for homeOpt <- SystemProperties[F].get("user.home") @@ -50,6 +99,16 @@ final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent profiles <- ev.fromEither(parseProfiles(content)) yield ProfileFile(profiles, Instant.ofEpochMilli(lastMod.toMillis)) + /** + * Parses the INI-format credentials file content into profile structures. + * + * This method processes the credentials file line by line, extracting profile + * sections and their properties. It supports both `[profile name]` and `[name]` + * section headers and handles various property formats. + * + * @param content The raw content of the credentials file + * @return Either an error if parsing fails, or a map of profile names to profile data + */ private def parseProfiles(content: String): Either[Throwable, Map[String, Profile]] = val profilePattern = "\\[(?:profile\\s+)?(.+)\\]".r val propertyPattern = "^\\s*([^=]+)\\s*=\\s*(.+)\\s*$".r diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala index e691df18a..24d680b5e 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala @@ -15,15 +15,38 @@ import ldbc.amazon.useragent.BusinessMetricFeatureId import ldbc.amazon.util.SdkSystemSetting /** - * [[AwsCredentialsProvider]] implementation that loads credentials from the aws.accessKeyId, aws.secretAccessKey and - * aws.sessionToken system properties. + * AWS credentials provider that loads credentials from JVM system properties. + * + * This provider looks for AWS credentials in the following system properties: + * - aws.accessKeyId: The AWS access key ID + * - aws.secretAccessKey: The AWS secret access key + * - aws.sessionToken: Optional session token for temporary credentials + * + * This provider only checks system properties and does not fallback to environment variables. + * Customers can use this provider when they want to explicitly source credentials from + * JVM system properties only. + * + * @tparam F The effect type that supports system property access and error handling */ final class SystemPropertyCredentialsProvider[F[_]: SystemProperties: MonadThrow] extends SystemSettingsCredentialsProvider[F]: - // Customers should be able to specify a credentials provider that only looks at the system properties, - // but not the environment variables. For that reason, we're only checking the system properties here. + /** + * Loads a setting value from JVM system properties only. + * + * This implementation specifically looks at system properties and does not check + * environment variables, allowing customers to specify a credentials provider that only + * uses system properties. + * + * @param setting The system setting to load + * @return An effect containing the optional setting value from system properties + */ override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = SystemProperties[F].get(setting.systemProperty) + /** + * Returns the provider identifier for business metrics tracking. + * + * @return The business metric feature ID for JVM system property credentials + */ override def provider: String = BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala index 4f1f22788..9b7d3a017 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -58,6 +58,28 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: webIdentityUtils: WebIdentityCredentialsUtils[F] ) extends AwsCredentialsProvider[F]: + /** + * Resolves AWS credentials using Web Identity Token authentication. + * + * This method implements the OIDC (OpenID Connect) authentication flow used in + * containerized environments such as Kubernetes with IRSA (IAM Roles for Service Accounts) + * or other OIDC-enabled platforms. + * + * The resolution process: + * 1. Load Web Identity configuration from environment variables/system properties + * 2. Validate that required configuration (token file path and role ARN) is present + * 3. Use STS AssumeRoleWithWebIdentity to exchange the JWT token for AWS credentials + * 4. Return the temporary AWS credentials with session token + * + * This provider is commonly used in: + * - Amazon EKS with IAM Roles for Service Accounts (IRSA) + * - Self-managed Kubernetes clusters with OIDC providers + * - CI/CD environments with OIDC authentication + * - Serverless platforms with Web Identity token support + * + * @return F[AwsCredentials] The resolved AWS credentials with temporary access key, secret key, and session token + * @throws SdkClientException if Web Identity configuration is missing or credential exchange fails + */ override def resolveCredentials(): F[AwsCredentials] = for config <- loadWebIdentityConfig() @@ -75,6 +97,25 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: } yield credentials + /** + * Loads Web Identity Token credential configuration from environment variables and system properties. + * + * This method attempts to load the required configuration for Web Identity Token authentication + * by checking environment variables first, then falling back to system properties. + * + * Required configuration: + * - Token file path: Path to the JWT token file (typically mounted by Kubernetes) + * - Role ARN: The ARN of the IAM role to assume using the Web Identity token + * + * Optional configuration: + * - Role session name: Custom session name for the assumed role session + * + * Configuration sources (in priority order): + * 1. Environment variables: AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN, AWS_ROLE_SESSION_NAME + * 2. System properties: aws.webIdentityTokenFile, aws.roleArn, aws.roleSessionName + * + * @return F[Option[WebIdentityTokenCredentialProperties]] Web Identity configuration if both token file and role ARN are available, None otherwise + */ private def loadWebIdentityConfig(): F[Option[WebIdentityTokenCredentialProperties]] = for tokenFilePath <- loadTokenFilePath() @@ -92,18 +133,95 @@ final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: case _ => None } + /** + * Loads the Web Identity Token file path from configuration sources. + * + * This method retrieves the file system path to the JWT token file used for + * Web Identity authentication. The token file is typically mounted by container + * orchestration systems and contains a JWT token signed by an OIDC provider. + * + * Configuration sources (in priority order): + * 1. Environment variable: AWS_WEB_IDENTITY_TOKEN_FILE + * 2. System property: aws.webIdentityTokenFile + * + * Common token file paths in Kubernetes environments: + * - /var/run/secrets/eks.amazonaws.com/serviceaccount/token (EKS IRSA) + * - /var/run/secrets/kubernetes.io/serviceaccount/token (Standard Kubernetes SA token) + * + * @return F[Option[String]] The token file path if configured, None otherwise + */ private def loadTokenFilePath(): F[Option[String]] = for envValue <- Env[F].get("AWS_WEB_IDENTITY_TOKEN_FILE") sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + /** + * Loads the IAM role ARN for Web Identity Token authentication. + * + * This method retrieves the Amazon Resource Name (ARN) of the IAM role that should + * be assumed using the Web Identity token. The role must be configured with a trust + * policy that allows the OIDC provider to assume it. + * + * Configuration sources (in priority order): + * 1. Environment variable: AWS_ROLE_ARN + * 2. System property: aws.roleArn + * + * Example role ARN format: + * arn:aws:iam::123456789012:role/EKS-Pod-Identity-Role + * + * The IAM role must have a trust policy similar to: + * ```json + * { + * "Version": "2012-10-17", + * "Statement": [ + * { + * "Effect": "Allow", + * "Principal": { + * "Federated": "arn:aws:iam::ACCOUNT_ID:oidc-provider/OIDC_PROVIDER_URL" + * }, + * "Action": "sts:AssumeRoleWithWebIdentity" + * } + * ] + * } + * ``` + * + * @return F[Option[String]] The IAM role ARN if configured, None otherwise + */ private def loadRoleArn(): F[Option[String]] = for envValue <- Env[F].get("AWS_ROLE_ARN") sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + /** + * Loads the optional role session name for Web Identity Token authentication. + * + * This method retrieves an optional session name that will be used to identify + * the assumed role session in AWS CloudTrail logs and other AWS services. + * If not provided, a default session name will be generated automatically. + * + * Configuration sources (in priority order): + * 1. Environment variable: AWS_ROLE_SESSION_NAME + * 2. System property: aws.roleSessionName + * + * Session name requirements: + * - Must be 2-64 characters long + * - Can contain letters, numbers, and the characters +=,.@- + * - Cannot contain spaces + * + * The session name appears in: + * - AWS CloudTrail logs as the "roleSessionName" field + * - AWS STS GetCallerIdentity response + * - IAM policy evaluation context for condition keys + * + * Example session names: + * - "kubernetes-pod-web-app" + * - "ci-cd-pipeline-123" + * - "user-workload-session" + * + * @return F[Option[String]] The role session name if configured, None otherwise + */ private def loadRoleSessionName(): F[Option[String]] = for envValue <- Env[F].get("AWS_ROLE_SESSION_NAME") diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala index 3b599d7ae..b50795b6b 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala @@ -14,8 +14,77 @@ import ldbc.amazon.exception.SdkClientException import ldbc.amazon.identity.* import ldbc.amazon.util.SdkSystemSetting +/** + * Base trait for AWS credential providers that load credentials from system settings. + * + * This trait provides a common foundation for credential providers that read AWS credentials + * from system-level configuration sources such as environment variables and Java system properties. + * It implements the standard AWS credential loading pattern with support for both basic + * credentials (access key + secret key) and session credentials (+ session token). + * + * The trait uses the AWS SDK system setting definitions to ensure consistency with + * AWS SDK naming conventions and behavior. It supports the following credential types: + * + * Basic credentials: + * - Access Key ID (AWS_ACCESS_KEY_ID / aws.accessKeyId) + * - Secret Access Key (AWS_SECRET_ACCESS_KEY / aws.secretAccessKey) + * + * Session credentials (temporary credentials): + * - Access Key ID (AWS_ACCESS_KEY_ID / aws.accessKeyId) + * - Secret Access Key (AWS_SECRET_ACCESS_KEY / aws.secretAccessKey) + * - Session Token (AWS_SESSION_TOKEN / aws.sessionToken) + * + * Optional metadata: + * - Account ID (AWS_ACCOUNT_ID / aws.accountId) + * + * Implementing classes must provide: + * 1. loadSetting method to read from their specific configuration source + * 2. provider property to identify the credential source for metrics + * + * Error handling: + * - Missing access key or secret key results in SdkClientException + * - Missing session token is acceptable (falls back to basic credentials) + * - Account ID is optional for all credential types + * + * @tparam F The effect type that supports error handling via MonadThrow + */ trait SystemSettingsCredentialsProvider[F[_]](using ev: MonadThrow[F]) extends AwsCredentialsProvider[F]: + /** + * Resolves AWS credentials from system settings (environment variables or system properties). + * + * This method implements the standard AWS credential loading logic used by both + * environment variable and system property credential providers. It attempts to + * load all required and optional credential components from the configured source. + * + * Credential loading process: + * 1. Load access key ID (required) + * 2. Load secret access key (required) + * 3. Load session token (optional) + * 4. Load account ID (optional) + * 5. Validate that required credentials are present + * 6. Determine credential type based on session token presence + * 7. Create appropriate credential instance (Basic or Session) + * + * Credential validation: + * - Access key ID and secret access key are required + * - Values are trimmed and must be non-empty after trimming + * - Missing required credentials result in SdkClientException + * - Session token and account ID are optional + * + * Credential types returned: + * - AwsBasicCredentials: When only access key and secret key are provided + * - AwsSessionCredentials: When session token is also provided + * + * Both credential types include provider metadata for AWS usage metrics and debugging: + * - providerName: Identifies the specific credential source + * - accountId: AWS account ID if available + * - validateCredentials: Set to false to skip validation + * - expirationTime: None for system setting credentials (no expiration) + * + * @return F[AwsCredentials] The resolved AWS credentials (Basic or Session type) + * @throws SdkClientException if required credentials (access key or secret key) are missing + */ override def resolveCredentials(): F[AwsCredentials] = for accessKeyOpt <- loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map(_.map(_.trim)) @@ -56,6 +125,50 @@ trait SystemSettingsCredentialsProvider[F[_]](using ev: MonadThrow[F]) extends A ) } + /** + * Loads a system setting value from the concrete configuration source. + * + * This abstract method must be implemented by concrete credential providers + * to specify how to read configuration values from their specific source + * (environment variables, system properties, etc.). + * + * The implementation should: + * 1. Read the value from the appropriate configuration source + * 2. Return Some(value) if the setting exists and has a value + * 3. Return None if the setting does not exist or has no value + * 4. Handle any source-specific errors appropriately + * + * Example implementations: + * - Environment variables: Read from System.getenv() or effect-based environment access + * - System properties: Read from System.getProperty() or effect-based property access + * - Configuration files: Parse and read from file-based configuration + * + * The SdkSystemSetting parameter provides both the environment variable name + * and system property name for the setting, allowing implementations to choose + * the appropriate source or implement fallback logic. + * + * @param setting The AWS SDK system setting definition containing environment variable and system property names + * @return F[Option[String]] The setting value if available, None otherwise + */ def loadSetting(setting: SdkSystemSetting): F[Option[String]] + /** + * Identifier string for this credential provider used in AWS usage metrics and logging. + * + * This string is included in AWS service requests to track credential provider usage + * and appears in AWS CloudTrail logs for debugging purposes. It should be a short, + * descriptive identifier that uniquely identifies the credential source. + * + * Common provider identifiers: + * - "Environment": For environment variable credentials + * - "SystemProperty": For Java system property credentials + * - "Profile": For AWS credentials file profiles + * - "Container": For ECS/EKS container credentials + * - "InstanceProfile": For EC2 instance profile credentials + * + * This value is used for AWS business metrics and helps AWS understand + * how different credential providers are being used across AWS SDKs. + * + * @return String identifier for this credential provider + */ def provider: String diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala index c7025722d..28621a53d 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -44,6 +44,18 @@ trait WebIdentityCredentialsUtils[F[_]]: object WebIdentityCredentialsUtils: + /** + * Private implementation of WebIdentityCredentialsUtils. + * + * This implementation handles the complete Web Identity Token flow: + * 1. Reading JWT token from file system + * 2. Validating the token format + * 3. Calling STS AssumeRoleWithWebIdentity + * 4. Converting the STS response to AWS credentials + * + * @param stsClient The STS client to use for AssumeRoleWithWebIdentity operations + * @tparam F The effect type that supports file operations and concurrency + */ private case class Impl[F[_]: Files: Concurrent]( stsClient: StsClient[F] ) extends WebIdentityCredentialsUtils[F]: diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala index 9141d9d72..829db0620 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala @@ -8,10 +8,112 @@ package ldbc.amazon.client import java.net.URI +/** + * Abstract HTTP client interface for making HTTP requests in the AWS authentication plugin. + * + * This trait defines a generic HTTP client abstraction that can be implemented for different + * effect types (F[_]) and HTTP libraries. It provides the core HTTP operations needed for + * AWS API communication including GET, PUT, and POST requests. + * + * The client is designed to work with functional effect systems like cats-effect, allowing + * for composable and type-safe HTTP operations. Each method returns the HTTP response + * wrapped in the effect type F[_]. + * + * Implementations of this trait should handle: + * - HTTP connection management and pooling + * - Request/response serialization + * - Error handling and retries + * - SSL/TLS configuration for HTTPS + * - Timeout configuration + * + * @tparam F the effect type that wraps the HTTP response (e.g., IO, Future, Task) + * + * @example {{{ + * import cats.effect.IO + * import java.net.URI + * + * // Usage with cats-effect IO + * def makeRequest(client: HttpClient[IO]): IO[HttpResponse] = { + * val uri = URI.create("https://api.amazonaws.com/endpoint") + * val headers = Map("Authorization" -> "Bearer token") + * + * for { + * response <- client.get(uri, headers) + * } yield response + * } + * }}} + * + * @see [[HttpResponse]] for the response type returned by HTTP operations + * @see [[ldbc.amazon.useragent.BusinessMetricFeatureId]] for user agent metrics tracking + */ trait HttpClient[F[_]]: + /** + * Performs an HTTP GET request to the specified URI. + * + * This method executes a GET request with the provided headers and returns + * the HTTP response wrapped in the effect type F[_]. GET requests should be + * idempotent and safe, typically used for retrieving data without side effects. + * + * @param uri the target URI for the GET request, must be a valid HTTP/HTTPS URI + * @param headers a map of HTTP headers to include in the request + * @return an effect containing the HTTP response + * + * @example {{{ + * val uri = URI.create("https://sts.amazonaws.com/") + * val headers = Map( + * "Accept" -> "application/json", + * "User-Agent" -> "ldbc-aws-plugin/1.0" + * ) + * client.get(uri, headers) + * }}} + */ def get(uri: URI, headers: Map[String, String]): F[HttpResponse] + /** + * Performs an HTTP PUT request to the specified URI with a request body. + * + * This method executes a PUT request with the provided headers and body content. + * PUT requests are typically used for creating or updating resources and should + * be idempotent. + * + * @param uri the target URI for the PUT request, must be a valid HTTP/HTTPS URI + * @param headers a map of HTTP headers to include in the request + * @param body the request body content as a string + * @return an effect containing the HTTP response + * + * @example {{{ + * val uri = URI.create("https://api.amazonaws.com/resource/123") + * val headers = Map( + * "Content-Type" -> "application/json", + * "Authorization" -> "AWS4-HMAC-SHA256 ..." + * ) + * val body = """{"key": "value"}""" + * client.put(uri, headers, body) + * }}} + */ def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] + /** + * Performs an HTTP POST request to the specified URI with a request body. + * + * This method executes a POST request with the provided headers and body content. + * POST requests are typically used for creating new resources or submitting data + * and may have side effects. + * + * @param uri the target URI for the POST request, must be a valid HTTP/HTTPS URI + * @param headers a map of HTTP headers to include in the request + * @param body the request body content as a string + * @return an effect containing the HTTP response + * + * @example {{{ + * val uri = URI.create("https://sts.amazonaws.com/") + * val headers = Map( + * "Content-Type" -> "application/x-amz-json-1.1", + * "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRole" + * ) + * val body = """{"RoleArn": "arn:aws:iam::123456789012:role/MyRole"}""" + * client.post(uri, headers, body) + * }}} + */ def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala index 20693ca32..adc79de96 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala @@ -6,6 +6,52 @@ package ldbc.amazon.client +/** + * Represents an HTTP response from AWS API calls or other HTTP operations. + * + * This case class encapsulates the essential components of an HTTP response including + * the status code, response headers, and response body. It is designed to be immutable + * and type-safe, providing a clean abstraction over HTTP responses in the AWS + * authentication plugin. + * + * The response is typically created by HTTP client implementations and consumed by + * AWS service clients for processing API responses, credential retrieval, and + * authentication operations. + * + * @param statusCode the HTTP status code indicating the result of the request (e.g., 200, 404, 500) + * @param headers a map of HTTP response headers with case-insensitive keys + * @param body the response body content as a string, may be empty for certain responses + * + * @example {{{ + * // Creating a successful response + * val response = HttpResponse( + * statusCode = 200, + * headers = Map( + * "Content-Type" -> "application/json", + * "Content-Length" -> "123" + * ), + * body = """{"access_token": "abc123", "expires_in": 3600}""" + * ) + * + * // Checking response status + * if (response.statusCode >= 200 && response.statusCode < 300) { + * // Process successful response + * println(s"Success: ${response.body}") + * } + * + * // Accessing specific headers + * response.headers.get("Content-Type") match { + * case Some("application/json") => // Parse JSON response + * case _ => // Handle other content types + * } + * }}} + * + * @note HTTP header names should be treated as case-insensitive according to RFC 7230. + * Implementations should normalize header names appropriately. + * + * @see [[HttpClient]] for the interface that produces HttpResponse instances + * @see [[ldbc.amazon.useragent.BusinessMetricFeatureId]] for tracking HTTP client metrics + */ final case class HttpResponse( statusCode: Int, headers: Map[String, String], diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 0ae3e890d..b1deb01c5 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -71,6 +71,18 @@ object StsClient: assumedRoleArn: String ) + /** + * Private implementation of StsClient. + * + * This implementation handles: + * 1. Request validation and parameter setup + * 2. HTTP request formatting and execution + * 3. Response parsing and error handling + * + * @param stsEndpoint The STS service endpoint URL + * @param httpClient The HTTP client to use for making requests + * @tparam F The effect type that supports UUID generation and concurrency + */ private case class Impl[F[_]: UUIDGen: Concurrent]( stsEndpoint: String, httpClient: HttpClient[F] @@ -107,6 +119,15 @@ object StsClient: stsResponse <- parseAssumeRoleResponse(response.body) yield stsResponse + /** + * Validates that the provided IAM role ARN has the correct format. + * + * Role ARNs must match the pattern: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME + * where ACCOUNT_ID is a 12-digit number and ROLE_NAME contains valid characters. + * + * @param roleArn The IAM role ARN to validate + * @return Unit if valid, raises an error if invalid + */ private def validateRoleArn(roleArn: String): F[Unit] = val roleArnPattern = """^arn:aws:iam::\d{12}:role/[\w+=,.@-]+$""".r roleArnPattern.findFirstIn(roleArn) match { @@ -132,6 +153,13 @@ object StsClient: /** * Builds the STS request body in AWS Query format. + * + * Creates a URL-encoded form body with the required STS parameters for the + * AssumeRoleWithWebIdentity operation. All parameter names and values are + * properly URL-encoded to ensure safe transmission. + * + * @param request The STS request containing the parameters + * @return URL-encoded form data string for the request body */ private def buildRequestBody(request: AssumeRoleWithWebIdentityRequest): String = val params = Map( @@ -152,6 +180,12 @@ object StsClient: /** * Gets current timestamp in AWS format. + * + * Generates a timestamp string in the ISO 8601 format required by AWS: + * yyyyMMddTHHmmssZ (e.g., "20231201T120000Z"). The timestamp is always + * generated in UTC timezone. + * + * @return Either an error if timestamp generation fails, or the formatted timestamp string */ private def getCurrentTimestamp(): Either[Throwable, String] = try { @@ -167,6 +201,13 @@ object StsClient: /** * Validates HTTP response status. + * + * Checks if the HTTP response has a successful status code (2xx range). + * If the status code indicates an error, raises an StsException with + * the status code and response body for debugging. + * + * @param response The HTTP response to validate + * @return Unit if status is successful, raises StsException otherwise */ private def validateHttpResponse[F[_]: MonadThrow](response: HttpResponse): F[Unit] = if response.statusCode >= 200 && response.statusCode < 300 then { diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala index 1601deebb..eb6036181 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala @@ -6,5 +6,56 @@ package ldbc.amazon.exception +/** + * Exception thrown when AWS credentials cannot be fetched from any available credential provider. + * + * This exception represents a high-level failure in the AWS credentials resolution process and + * is typically thrown when: + * - All configured credential providers have failed to provide valid credentials + * - The credential provider chain has been exhausted without success + * - Network connectivity issues prevent access to credential sources + * - Authentication tokens have expired and cannot be refreshed + * - Required environment variables or configuration files are missing + * + * Common credential sources that may fail: + * - **Environment variables**: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN` + * - **AWS credentials file**: `~/.aws/credentials` + * - **IAM instance profiles**: EC2 metadata service + * - **Web Identity Tokens**: EKS service account tokens, OIDC providers + * - **AWS STS**: AssumeRole, AssumeRoleWithWebIdentity operations + * - **Container credentials**: ECS task metadata endpoints + * + * This exception typically indicates that the application cannot authenticate with AWS services + * and will likely be unable to access any AWS resources until the credential issue is resolved. + * + * Troubleshooting steps: + * 1. Verify AWS configuration and credentials + * 2. Check network connectivity to AWS endpoints + * 3. Validate IAM permissions and trust relationships + * 4. Ensure credential files and environment variables are correctly set + * 5. Check for expired tokens or certificates + * + * Example scenarios: + * - EKS pod without proper IRSA configuration + * - EC2 instance without IAM instance profile + * - Local development environment with missing AWS configuration + * - Network policies blocking access to AWS credential endpoints + * + * @param message The detailed error message describing why credential fetching failed, + * including information about which credential sources were attempted + * and what specific failures occurred + */ class CredentialsFetchError(message: String) extends Exception: + + /** + * Returns the error message for this exception. + * + * The message typically includes details about: + * - Which credential providers were attempted + * - The specific failure reason for each provider + * - Suggested remediation steps + * - Environment or configuration context + * + * @return The detailed error message describing the credential fetch failure + */ override def getMessage: String = message diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala index a5e175d1c..63e5e3a52 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala @@ -7,21 +7,34 @@ package ldbc.amazon.exception /** - * Thrown when the Web Identity Token is invalid or malformed. + * Exception thrown when a Web Identity Token is invalid, malformed, or cannot be processed. * - * This exception is typically thrown when: + * This exception is typically thrown during token validation when: * - The JWT token does not have the correct format (header.payload.signature) * - The token file is empty or contains only whitespace - * - The token contains invalid characters or encoding - * - The JWT structure is corrupted + * - The token contains invalid characters or encoding issues + * - The JWT structure is corrupted or missing required components + * - Base64 decoding of JWT segments fails + * - JSON parsing of JWT header or payload fails * - * Valid JWT token format: + * Valid JWT token format (3 base64-encoded segments separated by dots): * ``` * eyJhbGciOiJSUzI1NiIsImtpZCI6IjEyMyJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMS5hbWF6b25hd3MuY29tL2lkLzEyMyIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0Om15LWFwcCJ9.signature * ``` * - * @param message The detailed error message - * @param cause The underlying cause of the exception (optional) + * Common scenarios that trigger this exception: + * - Token file contains non-JWT content (e.g., HTML error page, plain text) + * - Network issues resulted in partial token download + * - Token rotation occurred mid-process leaving stale content + * - File system corruption affecting the token file + * + * This exception extends [[WebIdentityTokenException]] and inherits [[NoStackTrace]] behavior + * for performance optimization during token validation workflows. + * + * @param message The detailed error message describing the specific validation failure, + * including information about which part of the token validation failed + * @param cause The underlying cause of the exception (optional). Common causes include + * JSON parsing exceptions, Base64 decoding errors, or I/O exceptions */ class InvalidTokenException( message: String, @@ -29,7 +42,14 @@ class InvalidTokenException( ) extends WebIdentityTokenException(message, cause): /** - * Constructor with cause + * Alternative constructor that accepts a required cause parameter. + * + * This is useful when the underlying parsing or validation error should always be + * preserved for debugging token format issues. + * + * @param message The detailed error message describing the token validation failure + * @param cause The underlying cause of the validation failure (e.g., JSON parsing exception, + * Base64 decoding error, or character encoding exception) */ def this(message: String, cause: Throwable) = this(message, Some(cause)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala index 5892e222a..ceafdf7f5 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala @@ -6,6 +6,22 @@ package ldbc.amazon.exception +/** + * Base exception for AWS SDK client-side errors. + * + * This exception is thrown when errors occur on the client side, such as: + * - Invalid configuration + * - Network connectivity issues + * - Authentication failures + * - Missing required resources + * + * @param message A descriptive error message explaining the cause of the exception + */ class SdkClientException(message: String) extends RuntimeException: + /** + * Returns the error message for this exception. + * + * @return The descriptive error message + */ override def getMessage: String = message diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala index 162024147..bde634167 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala @@ -28,8 +28,13 @@ import scala.util.control.NoStackTrace * https://sts.us-east-1.amazonaws.com/ * ``` * - * @param message The detailed error message including STS response details - * @param cause The underlying cause of the exception (optional) + * This exception extends [[SdkClientException]] and implements [[NoStackTrace]] for improved performance + * when exception handling is frequent, particularly in AWS authentication scenarios where retries are common. + * + * @param message The detailed error message including STS response details, HTTP status codes, + * and any relevant context from the failed STS operation + * @param cause The underlying cause of the exception (optional). Typically contains the original + * HTTP exception, JSON parsing error, or network connectivity issue */ class StsException( message: String, @@ -41,7 +46,13 @@ class StsException( cause.foreach(initCause) /** - * Constructor with cause + * Alternative constructor that accepts a required cause parameter. + * + * This constructor is useful when the underlying cause is always available and should be + * explicitly tracked for debugging purposes. + * + * @param message The detailed error message including STS response details + * @param cause The underlying cause of the exception that triggered this STS failure */ def this(message: String, cause: Throwable) = this(message, Some(cause)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala index a40802e20..a0180c0af 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala @@ -7,20 +7,54 @@ package ldbc.amazon.exception /** - * Thrown when the Web Identity Token file cannot be accessed due to permissions or I/O issues. + * Exception thrown when the Web Identity Token file exists but cannot be accessed due to + * permissions, I/O errors, or other file system issues. * - * This exception is typically thrown when: - * - The file exists but cannot be read due to insufficient permissions - * - I/O errors occur while reading the token file - * - File system issues prevent access to the token file + * This exception is typically thrown in the following scenarios: + * - The file exists but cannot be read due to insufficient file system permissions + * - I/O errors occur while reading the token file (disk errors, network mount issues) + * - File system issues prevent access (read-only file systems, corrupted file systems) + * - File is locked by another process or application + * - Security policy restrictions prevent file access (SELinux, AppArmor, etc.) + * - File descriptor limits are exceeded * - * Common solutions: - * - Verify file permissions (typically 600 for token files) - * - Check if the process has read access to the file - * - Ensure the file system is mounted and accessible + * Common environments and solutions: * - * @param message The detailed error message - * @param cause The underlying cause of the exception (optional) + * **Container environments (Docker, Kubernetes):** + * - Ensure the container user has read access to mounted token files + * - Verify that volume mounts are configured correctly + * - Check that security contexts allow file access + * + * **File permissions:** + * - Token files typically need 600 permissions (read for owner only) + * - Verify the process owner matches the file owner or has appropriate group access + * - Check parent directory permissions (must be executable to access files within) + * + * **File system issues:** + * - Ensure file systems are mounted correctly (especially network mounts) + * - Verify disk space and inode availability + * - Check for file system corruption or read-only states + * + * Example debugging steps: + * ```bash + * # Check file permissions + * ls -la /var/run/secrets/eks.amazonaws.com/serviceaccount/token + * + * # Test file access as the application user + * sudo -u app-user cat /var/run/secrets/eks.amazonaws.com/serviceaccount/token + * + * # Check file system mount status + * mount | grep /var/run/secrets + * ``` + * + * This exception extends [[WebIdentityTokenException]] and inherits [[NoStackTrace]] behavior + * to avoid performance overhead during frequent access attempts. + * + * @param message The detailed error message describing the specific access failure, + * including file path and permission details when available + * @param cause The underlying cause of the exception (optional). Common causes include + * `AccessDeniedException`, `IOException`, `SecurityException`, or other + * file system related exceptions */ class TokenFileAccessException( message: String, @@ -28,7 +62,15 @@ class TokenFileAccessException( ) extends WebIdentityTokenException(message, cause): /** - * Constructor with cause + * Alternative constructor that accepts a required cause parameter. + * + * This constructor is preferred when the underlying file system exception provides + * valuable diagnostic information that should be preserved for troubleshooting + * access issues. + * + * @param message The detailed error message describing the access failure + * @param cause The underlying file system exception that caused this access failure + * (e.g., `AccessDeniedException`, `IOException`, `SecurityException`) */ def this(message: String, cause: Throwable) = this(message, Some(cause)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala index dc4ae761f..132e197b9 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala @@ -7,21 +7,34 @@ package ldbc.amazon.exception /** - * Thrown when the Web Identity Token file cannot be found. + * Exception thrown when the Web Identity Token file cannot be found at the specified location. * - * This exception is typically thrown when: - * - The file path specified in AWS_WEB_IDENTITY_TOKEN_FILE does not exist - * - The token file has been moved or deleted - * - The file path is incorrectly configured + * This exception is typically thrown in the following scenarios: + * - The file path specified in `AWS_WEB_IDENTITY_TOKEN_FILE` environment variable does not exist + * - The token file has been moved, deleted, or renamed after the application started + * - The file path is incorrectly configured in the environment + * - Directory structure changes that affect the token file location + * - Container restart scenarios where mounted volumes are not yet available + * + * Common environments where this occurs: + * - **EKS with IRSA (IAM Roles for Service Accounts)**: Token files are mounted by the EKS service + * - **Fargate**: Token files provided via task metadata endpoint + * - **Local development**: When mimicking AWS environments with local token files * * Example usage in EKS/IRSA: * ``` * AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token + * AWS_ROLE_ARN=arn:aws:iam::123456789012:role/my-service-role * ``` * - * @param message The detailed error message - * @param tokenFilePath The path to the missing token file (optional) - * @param cause The underlying cause of the exception (optional) + * This exception extends [[WebIdentityTokenException]] and provides enhanced error messages + * that include the attempted file path when available, making debugging easier. + * + * @param message The detailed error message describing the file lookup failure + * @param tokenFilePath The path to the missing token file (optional). When provided, this path + * will be included in the error message returned by [[getMessage]] + * @param cause The underlying cause of the exception (optional). Typically a file system + * related exception such as `FileNotFoundException` or access-related errors */ class TokenFileNotFoundException( message: String, @@ -30,23 +43,51 @@ class TokenFileNotFoundException( ) extends WebIdentityTokenException(message, cause): /** - * Constructor with cause only + * Constructor with cause only. + * + * Use this constructor when the underlying file system exception is available but the + * specific file path should be extracted from the exception or is not relevant. + * + * @param message The detailed error message describing the file lookup failure + * @param cause The underlying file system exception that caused this failure */ def this(message: String, cause: Throwable) = this(message, None, Some(cause)) /** - * Constructor with token file path only + * Constructor with token file path only. + * + * Use this constructor when you have the specific file path that failed but no underlying + * exception (e.g., when programmatically checking file existence). + * + * @param message The detailed error message describing the file lookup failure + * @param tokenFilePath The full path to the token file that was not found */ def this(message: String, tokenFilePath: String) = this(message, Some(tokenFilePath), None) /** - * Constructor with both token file path and cause + * Constructor with both token file path and cause. + * + * This is the most comprehensive constructor, providing both the file path context + * and the underlying exception for complete error traceability. + * + * @param message The detailed error message describing the file lookup failure + * @param tokenFilePath The full path to the token file that was not found + * @param cause The underlying file system exception that caused this failure */ def this(message: String, tokenFilePath: String, cause: Throwable) = this(message, Some(tokenFilePath), Some(cause)) + /** + * Returns the error message for this exception, including the token file path when available. + * + * This method enhances the base error message by appending the token file path when it + * was provided during exception construction, making it easier to identify which specific + * file could not be found. + * + * @return The enhanced error message that includes the file path context when available + */ override def getMessage: String = tokenFilePath match case Some(path) => s"$message (Token file path: $path)" diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala index a8e0826c4..859065d94 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -11,6 +11,23 @@ import java.util.Objects import ldbc.amazon.identity.AwsCredentialsIdentity +/** + * Default implementation of AWS credentials identity. + * + * This implementation provides AWS credentials with complete identity information + * including access keys, account ID, expiration time, and provider name. It serves + * as the concrete implementation used throughout the AWS authentication plugin. + * + * The class implements proper equals/hashCode semantics for credential comparison + * and provides a secure toString representation that does not expose sensitive + * information like secret access keys. + * + * @param accessKeyId The AWS access key ID used for authentication + * @param secretAccessKey The AWS secret access key used for signing requests + * @param accountId Optional AWS account ID associated with these credentials + * @param expirationTime Optional expiration time for temporary credentials + * @param providerName Optional name of the credentials provider that created these credentials + */ final case class DefaultAwsCredentialsIdentity( accessKeyId: String, secretAccessKey: String, @@ -19,6 +36,14 @@ final case class DefaultAwsCredentialsIdentity( providerName: Option[String] ) extends AwsCredentialsIdentity: + /** + * Returns a string representation of these credentials without exposing sensitive information. + * + * The secret access key is intentionally omitted from the string representation for security. + * Only the access key ID, provider name, and account ID are included. + * + * @return A secure string representation of the credentials identity + */ override def toString: String = val builder = new StringBuilder() builder.append("AwsCredentialsIdentity(") @@ -29,6 +54,16 @@ final case class DefaultAwsCredentialsIdentity( builder.result() + /** + * Compares this credentials identity with another object for equality. + * + * Two credentials identities are considered equal if they have the same + * access key ID, secret access key, and account ID. The expiration time + * and provider name are not considered for equality comparison. + * + * @param obj The object to compare with this credentials identity + * @return true if the objects are equal, false otherwise + */ override def equals(obj: Any): Boolean = obj match case that: DefaultAwsCredentialsIdentity => @@ -38,6 +73,15 @@ final case class DefaultAwsCredentialsIdentity( Objects.equals(accountId, that.accountId)) case _ => false + /** + * Returns a hash code value for this credentials identity. + * + * The hash code is computed based on the access key ID, secret access key, + * and account ID to ensure consistency with the equals method. The expiration + * time and provider name are not included in the hash code calculation. + * + * @return A hash code value for this credentials identity + */ override def hashCode(): Int = var hashCode = 1 hashCode = 31 * hashCode + Objects.hashCode(accessKeyId) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala index ab52cc57a..9a8d1e6c3 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -7,58 +7,188 @@ package ldbc.amazon.useragent /** - * An enum class representing a short form of identity providers to record in the UA string. - * - * Unimplemented metrics: I,K - * Unsupported metrics (these will never be added): A,H + * Enumeration of business metric feature identifiers for AWS SDK user agent tracking. + * + * This enum represents a comprehensive set of short-form codes used to identify specific + * features, capabilities, and configurations in the AWS SDK user agent string. These + * metrics help AWS understand which features are being used and how the SDK is configured, + * enabling better service optimization and support. + * + * The feature identifiers cover several categories: + * - SDK features (waiters, paginators, transfer utilities) + * - Retry modes (legacy, standard, adaptive) + * - Compression and optimization features + * - Authentication and credential providers + * - Endpoint and account ID handling modes + * - Checksum validation methods + * - Protocol-specific features + * + * Each feature is assigned a unique single-character or short code that is embedded + * in the user agent string sent with AWS API requests. + * + * @param code the short identifier code used in the user agent string + * + * @example {{{ + * // Getting the code for a specific feature + * val retryCode = BusinessMetricFeatureId.RETRY_MODE_STANDARD.code // Returns "E" + * + * // Using in user agent construction + * val features = List( + * BusinessMetricFeatureId.PAGINATOR, + * BusinessMetricFeatureId.GZIP_REQUEST_COMPRESSION + * ) + * val codes = features.map(_.code).mkString(",") // "C,L" + * }}} + * + * @note Unimplemented metrics: I, K + * @note Unsupported metrics (will never be added): A, H + * + * @see [[ldbc.amazon.client.HttpClient]] for HTTP client implementations that may use these metrics */ enum BusinessMetricFeatureId(val code: String): - case WAITER extends BusinessMetricFeatureId("B") - case PAGINATOR extends BusinessMetricFeatureId("C") - case RETRY_MODE_LEGACY extends BusinessMetricFeatureId("D") - case RETRY_MODE_STANDARD extends BusinessMetricFeatureId("E") - case RETRY_MODE_ADAPTIVE extends BusinessMetricFeatureId("F") - case S3_TRANSFER extends BusinessMetricFeatureId("G") - case GZIP_REQUEST_COMPRESSION extends BusinessMetricFeatureId("L") - case PROTOCOL_RPC_V2_CBOR extends BusinessMetricFeatureId("M") - case ENDPOINT_OVERRIDE extends BusinessMetricFeatureId("N") - case S3_EXPRESS_BUCKET extends BusinessMetricFeatureId("J") - case ACCOUNT_ID_MODE_PREFERRED extends BusinessMetricFeatureId("P") - case ACCOUNT_ID_MODE_DISABLED extends BusinessMetricFeatureId("Q") - case ACCOUNT_ID_MODE_REQUIRED extends BusinessMetricFeatureId("R") - case SIGV4A_SIGNING extends BusinessMetricFeatureId("S") - case RESOLVED_ACCOUNT_ID extends BusinessMetricFeatureId("T") - case FLEXIBLE_CHECKSUMS_REQ_CRC32 extends BusinessMetricFeatureId("U") - case FLEXIBLE_CHECKSUMS_REQ_CRC32C extends BusinessMetricFeatureId("V") - case FLEXIBLE_CHECKSUMS_REQ_CRC64 extends BusinessMetricFeatureId("W") - case FLEXIBLE_CHECKSUMS_REQ_SHA1 extends BusinessMetricFeatureId("X") - case FLEXIBLE_CHECKSUMS_REQ_SHA256 extends BusinessMetricFeatureId("Y") + /** Indicates usage of SDK waiter functionality for polling operations. */ + case WAITER extends BusinessMetricFeatureId("B") + + /** Indicates usage of SDK paginator functionality for handling paginated API responses. */ + case PAGINATOR extends BusinessMetricFeatureId("C") + + /** Indicates usage of legacy retry mode with basic exponential backoff. */ + case RETRY_MODE_LEGACY extends BusinessMetricFeatureId("D") + + /** Indicates usage of standard retry mode with improved backoff strategies. */ + case RETRY_MODE_STANDARD extends BusinessMetricFeatureId("E") + + /** Indicates usage of adaptive retry mode with dynamic rate adjustment. */ + case RETRY_MODE_ADAPTIVE extends BusinessMetricFeatureId("F") + + /** Indicates usage of S3 transfer manager for optimized file uploads/downloads. */ + case S3_TRANSFER extends BusinessMetricFeatureId("G") + + /** Indicates usage of GZIP compression for request bodies. */ + case GZIP_REQUEST_COMPRESSION extends BusinessMetricFeatureId("L") + + /** Indicates usage of RPC v2 protocol with CBOR encoding. */ + case PROTOCOL_RPC_V2_CBOR extends BusinessMetricFeatureId("M") + + /** Indicates usage of custom endpoint override configuration. */ + case ENDPOINT_OVERRIDE extends BusinessMetricFeatureId("N") + + /** Indicates usage of S3 Express One Zone bucket features. */ + case S3_EXPRESS_BUCKET extends BusinessMetricFeatureId("J") + + /** Indicates account ID endpoint mode is set to preferred. */ + case ACCOUNT_ID_MODE_PREFERRED extends BusinessMetricFeatureId("P") + + /** Indicates account ID endpoint mode is disabled. */ + case ACCOUNT_ID_MODE_DISABLED extends BusinessMetricFeatureId("Q") + + /** Indicates account ID endpoint mode is required. */ + case ACCOUNT_ID_MODE_REQUIRED extends BusinessMetricFeatureId("R") + + /** Indicates usage of Signature Version 4A (SigV4A) signing for multi-region requests. */ + case SIGV4A_SIGNING extends BusinessMetricFeatureId("S") + + /** Indicates that account ID has been resolved for the request. */ + case RESOLVED_ACCOUNT_ID extends BusinessMetricFeatureId("T") + + /** Indicates usage of CRC32 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_CRC32 extends BusinessMetricFeatureId("U") + + /** Indicates usage of CRC32C checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_CRC32C extends BusinessMetricFeatureId("V") + + /** Indicates usage of CRC64 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_CRC64 extends BusinessMetricFeatureId("W") + + /** Indicates usage of SHA1 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_SHA1 extends BusinessMetricFeatureId("X") + + /** Indicates usage of SHA256 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_SHA256 extends BusinessMetricFeatureId("Y") + + /** Indicates flexible request checksums are calculated when supported. */ case FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED extends BusinessMetricFeatureId("Z") - case FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED extends BusinessMetricFeatureId("a") + + /** Indicates flexible request checksums are calculated when required. */ + case FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED extends BusinessMetricFeatureId("a") + + /** Indicates flexible response checksums are validated when supported. */ case FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED extends BusinessMetricFeatureId("b") - case FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED extends BusinessMetricFeatureId("c") - case DDB_MAPPER extends BusinessMetricFeatureId("d") - case BEARER_SERVICE_ENV_VARS extends BusinessMetricFeatureId("3") - case CREDENTIALS_CODE extends BusinessMetricFeatureId("e") - case CREDENTIALS_JVM_SYSTEM_PROPERTIES extends BusinessMetricFeatureId("f") - case CREDENTIALS_ENV_VARS extends BusinessMetricFeatureId("g") + + /** Indicates flexible response checksums are validated when required. */ + case FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED extends BusinessMetricFeatureId("c") + + /** Indicates usage of DynamoDB Object Mapper functionality. */ + case DDB_MAPPER extends BusinessMetricFeatureId("d") + + /** Indicates bearer token credentials loaded from environment variables. */ + case BEARER_SERVICE_ENV_VARS extends BusinessMetricFeatureId("3") + + /** Indicates credentials provided directly in application code. */ + case CREDENTIALS_CODE extends BusinessMetricFeatureId("e") + + /** Indicates credentials loaded from JVM system properties. */ + case CREDENTIALS_JVM_SYSTEM_PROPERTIES extends BusinessMetricFeatureId("f") + + /** Indicates credentials loaded from environment variables. */ + case CREDENTIALS_ENV_VARS extends BusinessMetricFeatureId("g") + + /** Indicates credentials loaded from environment variables with STS web identity token. */ case CREDENTIALS_ENV_VARS_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("h") - case CREDENTIALS_STS_ASSUME_ROLE extends BusinessMetricFeatureId("i") - case CREDENTIALS_STS_ASSUME_ROLE_SAML extends BusinessMetricFeatureId("j") - case CREDENTIALS_STS_ASSUME_ROLE_WEB_ID extends BusinessMetricFeatureId("k") - case CREDENTIALS_STS_FEDERATION_TOKEN extends BusinessMetricFeatureId("l") - case CREDENTIALS_STS_SESSION_TOKEN extends BusinessMetricFeatureId("m") - case CREDENTIALS_PROFILE extends BusinessMetricFeatureId("n") - case CREDENTIALS_PROFILE_SOURCE_PROFILE extends BusinessMetricFeatureId("o") - case CREDENTIALS_PROFILE_NAMED_PROVIDER extends BusinessMetricFeatureId("p") - case CREDENTIALS_PROFILE_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("q") - case CREDENTIALS_PROFILE_SSO extends BusinessMetricFeatureId("r") - case CREDENTIALS_SSO extends BusinessMetricFeatureId("s") - case CREDENTIALS_PROFILE_SSO_LEGACY extends BusinessMetricFeatureId("t") - case CREDENTIALS_SSO_LEGACY extends BusinessMetricFeatureId("u") - case CREDENTIALS_PROFILE_PROCESS extends BusinessMetricFeatureId("v") - case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") - case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") - case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") - case CREDENTIALS_CONTAINER extends BusinessMetricFeatureId("1") - case UNKNOWN extends BusinessMetricFeatureId("Unknown") + + /** Indicates credentials obtained via STS AssumeRole operation. */ + case CREDENTIALS_STS_ASSUME_ROLE extends BusinessMetricFeatureId("i") + + /** Indicates credentials obtained via STS AssumeRoleWithSAML operation. */ + case CREDENTIALS_STS_ASSUME_ROLE_SAML extends BusinessMetricFeatureId("j") + + /** Indicates credentials obtained via STS AssumeRoleWithWebIdentity operation. */ + case CREDENTIALS_STS_ASSUME_ROLE_WEB_ID extends BusinessMetricFeatureId("k") + + /** Indicates credentials obtained via STS GetFederationToken operation. */ + case CREDENTIALS_STS_FEDERATION_TOKEN extends BusinessMetricFeatureId("l") + + /** Indicates credentials obtained via STS GetSessionToken operation. */ + case CREDENTIALS_STS_SESSION_TOKEN extends BusinessMetricFeatureId("m") + + /** Indicates credentials loaded from AWS profile configuration. */ + case CREDENTIALS_PROFILE extends BusinessMetricFeatureId("n") + + /** Indicates credentials loaded from AWS profile with source profile configuration. */ + case CREDENTIALS_PROFILE_SOURCE_PROFILE extends BusinessMetricFeatureId("o") + + /** Indicates credentials loaded from AWS profile with named credential provider. */ + case CREDENTIALS_PROFILE_NAMED_PROVIDER extends BusinessMetricFeatureId("p") + + /** Indicates credentials loaded from AWS profile with STS web identity token. */ + case CREDENTIALS_PROFILE_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("q") + + /** Indicates credentials loaded from AWS profile with SSO configuration. */ + case CREDENTIALS_PROFILE_SSO extends BusinessMetricFeatureId("r") + + /** Indicates credentials obtained via AWS SSO. */ + case CREDENTIALS_SSO extends BusinessMetricFeatureId("s") + + /** Indicates credentials loaded from AWS profile with legacy SSO configuration. */ + case CREDENTIALS_PROFILE_SSO_LEGACY extends BusinessMetricFeatureId("t") + + /** Indicates credentials obtained via legacy AWS SSO. */ + case CREDENTIALS_SSO_LEGACY extends BusinessMetricFeatureId("u") + + /** Indicates credentials loaded from AWS profile with credential process. */ + case CREDENTIALS_PROFILE_PROCESS extends BusinessMetricFeatureId("v") + + /** Indicates credentials obtained via credential process. */ + case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") + + /** Indicates credentials obtained via HTTP credential provider. */ + case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") + + /** Indicates credentials obtained from EC2 Instance Metadata Service (IMDS). */ + case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") + + /** Indicates credentials obtained from container metadata service. */ + case CREDENTIALS_CONTAINER extends BusinessMetricFeatureId("1") + + /** Indicates an unknown or unrecognized business metric feature. */ + case UNKNOWN extends BusinessMetricFeatureId("Unknown") diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala index 51285af7e..7c1119681 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala @@ -8,15 +8,51 @@ package ldbc.amazon.util /** * Simple JSON parser for AWS credential responses. - * Avoids external dependencies for Scala.js and Scala Native compatibility. + * + * This parser is designed specifically for parsing AWS service responses that contain + * credentials and other simple data structures. It avoids external dependencies to + * maintain compatibility with Scala.js and Scala Native environments. + * + * The parser handles flat JSON objects with string, number, boolean, and null values + * but does not support nested objects or arrays. */ object SimpleJsonParser: + /** + * Represents a parsed JSON object with string-valued fields. + * + * This class provides convenient methods for accessing JSON fields with + * proper null handling and error reporting for missing required fields. + * + * @param fields Map of field names to optional string values (None represents JSON null) + */ case class JsonObject(fields: Map[String, Option[String]]): + + /** + * Gets an optional field value. + * + * @param key The field name to retrieve + * @return Some(value) if the field exists and is not null, None otherwise + */ def get(key: String): Option[String] = fields.get(key).flatten + /** + * Gets a field value with empty string as fallback. + * + * @param key The field name to retrieve + * @return The field value or empty string if the field is missing or null + */ def getOrEmpty(key: String): String = fields.get(key).flatten.getOrElse("") + /** + * Requires a field to exist with a non-null value. + * + * This method is useful for extracting required fields from JSON responses + * with proper error messaging. + * + * @param key The field name to retrieve + * @return Right(value) if the field exists and is not null, Left(error message) otherwise + */ def require(key: String): Either[String, String] = fields.get(key) match case None => Left(s"Required field '$key' not found") diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala index b0bc7a50e..71604e36c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala @@ -6,7 +6,29 @@ package ldbc.amazon.util +/** + * Simple XML parser utility for extracting content from XML documents. + * + * This parser provides basic XML parsing functionality without requiring a full XML library. + * It's designed specifically for parsing AWS service responses and handles common XML patterns. + * The parser performs basic entity decoding and content extraction but does not validate + * XML structure or handle complex XML features like namespaces, CDATA, or processing instructions. + */ object SimpleXmlParser: + + /** + * Decodes standard XML entities to their corresponding characters. + * + * This method replaces common XML entities with their actual character representations: + * - & → & + * - < → < + * - > → > + * - " → " + * - ' → ' + * + * @param s The string containing XML entities to decode + * @return The string with XML entities decoded to their character equivalents + */ def decodeXmlEntities(s: String): String = s.replace("&", "&") .replace("<", "<") @@ -14,6 +36,17 @@ object SimpleXmlParser: .replace(""", "\"") .replace("'", "'") + /** + * Extracts the content between XML tags. + * + * Finds the first occurrence of the specified XML tag and extracts its content. + * The content is trimmed and XML entities are decoded. This method only handles + * simple tags without attributes and does not support nested tags with the same name. + * + * @param tagName The name of the XML tag (without angle brackets) + * @param xml The XML string to search in + * @return Some(content) if the tag is found with content, None otherwise + */ def extractTagContent(tagName: String, xml: String): Option[String] = { val startTag = s"<$tagName>" val endTag = s"" @@ -28,6 +61,17 @@ object SimpleXmlParser: } } + /** + * Extracts a complete XML section including the opening and closing tags. + * + * Finds the first occurrence of the specified XML tag and extracts the entire + * section including the tags themselves. This is useful for extracting nested + * XML structures that need further processing. + * + * @param tagName The name of the XML tag (without angle brackets) + * @param xml The XML string to search in + * @return Some(section) if the tag section is found, None otherwise + */ def extractSection(tagName: String, xml: String): Option[String] = { val startTag = s"<$tagName>" val endTag = s"" @@ -41,6 +85,19 @@ object SimpleXmlParser: } } + /** + * Extracts tag content and throws an exception if the tag is missing or empty. + * + * This is a strict version of extractTagContent that requires the tag to exist + * and have non-empty content. If the tag is missing, empty, or contains only + * whitespace, an IllegalArgumentException is thrown with the provided error message. + * + * @param tagName The name of the XML tag (without angle brackets) + * @param xml The XML string to search in + * @param errorMsg The error message to use if the tag is missing or empty + * @return The tag content if found and non-empty + * @throws IllegalArgumentException if the tag is missing, empty, or contains only whitespace + */ def requireTag(tagName: String, xml: String, errorMsg: String): String = extractTagContent(tagName, xml) .filter(_.nonEmpty) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala index 58e35d88d..913abb747 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala @@ -6,8 +6,47 @@ package ldbc.amazon.util +/** + * Base trait for system setting definitions. This trait serves as a marker interface + * for all system setting implementations in the AWS authentication plugin. + * + * System settings provide a way to configure AWS SDK behavior through system properties + * and environment variables. Implementations of this trait define the available + * configuration options and their default values. + * + * @see [[SdkSystemSetting]] for the concrete implementation with AWS SDK settings + */ trait SystemSetting +/** + * Enumeration of AWS SDK system settings that can be configured through system properties. + * + * This enum defines all the AWS SDK configuration options available through system properties, + * providing a type-safe way to access configuration values. Each setting includes both the + * system property name and an optional default value. + * + * The settings cover various aspects of AWS SDK configuration including: + * - Credential configuration (access keys, session tokens, roles) + * - Regional settings and endpoint configuration + * - Metadata service settings + * - HTTP client configuration + * - Optimization and feature toggles + * - Authentication and signing preferences + * + * @param systemProperty the name of the system property used to configure this setting + * @param defaultValue the default value to use if the system property is not set, None if no default + * + * @example {{{ + * // Access the system property name + * val propertyName = SdkSystemSetting.AWS_REGION.systemProperty + * + * // Get the default value + * val default = SdkSystemSetting.AWS_EC2_METADATA_DISABLED.defaultValue + * }}} + * + * @see [[SystemSetting]] + * @see [[ldbc.amazon.client.HttpClient]] for HTTP client configuration + */ enum SdkSystemSetting(val systemProperty: String, val defaultValue: Option[String]): /** * Configure the AWS access key ID. From c2ba0c5f0eacc5858c0042c015a865d0b94c972b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 20:32:10 +0900 Subject: [PATCH 154/215] Delete unused --- .../exception/InvalidTokenException.scala | 18 +------- .../ldbc/amazon/exception/StsException.scala | 4 +- .../exception/TokenFileAccessException.scala | 20 +-------- .../TokenFileNotFoundException.scala | 42 +------------------ .../exception/WebIdentityTokenException.scala | 14 ------- 5 files changed, 5 insertions(+), 93 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala index 63e5e3a52..8064af341 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala @@ -33,23 +33,7 @@ package ldbc.amazon.exception * * @param message The detailed error message describing the specific validation failure, * including information about which part of the token validation failed - * @param cause The underlying cause of the exception (optional). Common causes include - * JSON parsing exceptions, Base64 decoding errors, or I/O exceptions */ class InvalidTokenException( message: String, - cause: Option[Throwable] = None -) extends WebIdentityTokenException(message, cause): - - /** - * Alternative constructor that accepts a required cause parameter. - * - * This is useful when the underlying parsing or validation error should always be - * preserved for debugging token format issues. - * - * @param message The detailed error message describing the token validation failure - * @param cause The underlying cause of the validation failure (e.g., JSON parsing exception, - * Base64 decoding error, or character encoding exception) - */ - def this(message: String, cause: Throwable) = - this(message, Some(cause)) +) extends WebIdentityTokenException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala index bde634167..48a33ebed 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala @@ -47,10 +47,10 @@ class StsException( /** * Alternative constructor that accepts a required cause parameter. - * + * * This constructor is useful when the underlying cause is always available and should be * explicitly tracked for debugging purposes. - * + * * @param message The detailed error message including STS response details * @param cause The underlying cause of the exception that triggered this STS failure */ diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala index a0180c0af..da99e5b49 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala @@ -52,25 +52,7 @@ package ldbc.amazon.exception * * @param message The detailed error message describing the specific access failure, * including file path and permission details when available - * @param cause The underlying cause of the exception (optional). Common causes include - * `AccessDeniedException`, `IOException`, `SecurityException`, or other - * file system related exceptions */ class TokenFileAccessException( message: String, - cause: Option[Throwable] = None -) extends WebIdentityTokenException(message, cause): - - /** - * Alternative constructor that accepts a required cause parameter. - * - * This constructor is preferred when the underlying file system exception provides - * valuable diagnostic information that should be preserved for troubleshooting - * access issues. - * - * @param message The detailed error message describing the access failure - * @param cause The underlying file system exception that caused this access failure - * (e.g., `AccessDeniedException`, `IOException`, `SecurityException`) - */ - def this(message: String, cause: Throwable) = - this(message, Some(cause)) +) extends WebIdentityTokenException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala index 132e197b9..5f141a53a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala @@ -33,51 +33,11 @@ package ldbc.amazon.exception * @param message The detailed error message describing the file lookup failure * @param tokenFilePath The path to the missing token file (optional). When provided, this path * will be included in the error message returned by [[getMessage]] - * @param cause The underlying cause of the exception (optional). Typically a file system - * related exception such as `FileNotFoundException` or access-related errors */ class TokenFileNotFoundException( message: String, tokenFilePath: Option[String] = None, - cause: Option[Throwable] = None -) extends WebIdentityTokenException(message, cause): - - /** - * Constructor with cause only. - * - * Use this constructor when the underlying file system exception is available but the - * specific file path should be extracted from the exception or is not relevant. - * - * @param message The detailed error message describing the file lookup failure - * @param cause The underlying file system exception that caused this failure - */ - def this(message: String, cause: Throwable) = - this(message, None, Some(cause)) - - /** - * Constructor with token file path only. - * - * Use this constructor when you have the specific file path that failed but no underlying - * exception (e.g., when programmatically checking file existence). - * - * @param message The detailed error message describing the file lookup failure - * @param tokenFilePath The full path to the token file that was not found - */ - def this(message: String, tokenFilePath: String) = - this(message, Some(tokenFilePath), None) - - /** - * Constructor with both token file path and cause. - * - * This is the most comprehensive constructor, providing both the file path context - * and the underlying exception for complete error traceability. - * - * @param message The detailed error message describing the file lookup failure - * @param tokenFilePath The full path to the token file that was not found - * @param cause The underlying file system exception that caused this failure - */ - def this(message: String, tokenFilePath: String, cause: Throwable) = - this(message, Some(tokenFilePath), Some(cause)) +) extends WebIdentityTokenException(message): /** * Returns the error message for this exception, including the token file path when available. diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala index 99787ea2f..8614c32a1 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala @@ -6,25 +6,11 @@ package ldbc.amazon.exception -import scala.util.control.NoStackTrace - /** * Base exception for Web Identity Token operations. * * @param message The error message - * @param cause The underlying cause (optional) */ abstract class WebIdentityTokenException( message: String, - cause: Option[Throwable] = None ) extends SdkClientException(message) - with NoStackTrace: - - // Set the cause if provided - cause.foreach(initCause) - - /** - * Constructor with cause - */ - def this(message: String, cause: Throwable) = - this(message, Some(cause)) From 43d4c4cde3a8286e299359d8f539174223c3758b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 20:32:37 +0900 Subject: [PATCH 155/215] Delete unused --- .../ldbc/amazon/exception/StsException.scala | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala index 48a33ebed..0024cd57a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala @@ -6,8 +6,6 @@ package ldbc.amazon.exception -import scala.util.control.NoStackTrace - /** * Exception thrown when AWS STS (Security Token Service) operations fail. * @@ -33,26 +31,7 @@ import scala.util.control.NoStackTrace * * @param message The detailed error message including STS response details, HTTP status codes, * and any relevant context from the failed STS operation - * @param cause The underlying cause of the exception (optional). Typically contains the original - * HTTP exception, JSON parsing error, or network connectivity issue */ class StsException( message: String, - cause: Option[Throwable] = None ) extends SdkClientException(message) - with NoStackTrace: - - // Set the cause if provided - cause.foreach(initCause) - - /** - * Alternative constructor that accepts a required cause parameter. - * - * This constructor is useful when the underlying cause is always available and should be - * explicitly tracked for debugging purposes. - * - * @param message The detailed error message including STS response details - * @param cause The underlying cause of the exception that triggered this STS failure - */ - def this(message: String, cause: Throwable) = - this(message, Some(cause)) From edd4bc78696ce2fae805b27497ace143da01bd4a Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 20:32:47 +0900 Subject: [PATCH 156/215] Action sbt scalafmtAll --- .../scala/ldbc/amazon/exception/InvalidTokenException.scala | 2 +- .../src/main/scala/ldbc/amazon/exception/StsException.scala | 2 +- .../scala/ldbc/amazon/exception/TokenFileAccessException.scala | 2 +- .../ldbc/amazon/exception/TokenFileNotFoundException.scala | 2 +- .../scala/ldbc/amazon/exception/WebIdentityTokenException.scala | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala index 8064af341..9ac7dd46c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala @@ -35,5 +35,5 @@ package ldbc.amazon.exception * including information about which part of the token validation failed */ class InvalidTokenException( - message: String, + message: String ) extends WebIdentityTokenException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala index 0024cd57a..fbff2f01a 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala @@ -33,5 +33,5 @@ package ldbc.amazon.exception * and any relevant context from the failed STS operation */ class StsException( - message: String, + message: String ) extends SdkClientException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala index da99e5b49..837bf20ea 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala @@ -54,5 +54,5 @@ package ldbc.amazon.exception * including file path and permission details when available */ class TokenFileAccessException( - message: String, + message: String ) extends WebIdentityTokenException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala index 5f141a53a..54b5c3108 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala @@ -36,7 +36,7 @@ package ldbc.amazon.exception */ class TokenFileNotFoundException( message: String, - tokenFilePath: Option[String] = None, + tokenFilePath: Option[String] = None ) extends WebIdentityTokenException(message): /** diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala index 8614c32a1..e12b42ac1 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala @@ -12,5 +12,5 @@ package ldbc.amazon.exception * @param message The error message */ abstract class WebIdentityTokenException( - message: String, + message: String ) extends SdkClientException(message) From 1f81474b6b0659782bba27522fa812b50cc6e4e7 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 20:52:55 +0900 Subject: [PATCH 157/215] Delete unused --- .../ProfileCredentialsProviderTest.scala | 75 ++----------------- .../WebIdentityCredentialsUtilsTest.scala | 13 ---- 2 files changed, 8 insertions(+), 80 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala index 146e625f3..856ba5aac 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala @@ -6,6 +6,8 @@ package ldbc.amazon.auth.credentials +import java.time.Instant + import cats.effect.std.SystemProperties import cats.effect.IO @@ -15,6 +17,8 @@ import munit.CatsEffectSuite import ldbc.amazon.exception.SdkClientException +import ProfileCredentialsProvider.* + class ProfileCredentialsProviderTest extends CatsEffectSuite: // Test fixtures @@ -32,29 +36,8 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: override def set(key: String, value: String): IO[Option[String]] = IO.raiseError(new UnsupportedOperationException("set not supported in mock")) - test("ProfileCredentialsProvider creation succeeds with default parameters") { - given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO - - for provider <- ProfileCredentialsProvider.default[IO]() - yield - // Basic test - provider was created successfully - assert(provider != null) - } - - test("ProfileCredentialsProvider creation succeeds with named profile") { - given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO - - for provider <- ProfileCredentialsProvider.default[IO]("dev") - yield - // Basic test - provider was created successfully - assert(provider != null) - } - test("ProfileCredentialsProvider fails when user.home is missing") { given SystemProperties[IO] = mockSystemProperties(homeDir = None) - given Files[IO] = Files.forIO for provider <- ProfileCredentialsProvider.default[IO]() @@ -73,8 +56,6 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: } test("ProfileFile case class creation works") { - import java.time.Instant - import ProfileCredentialsProvider.* val profiles = Map("default" -> Profile("default", Map("aws_access_key_id" -> "test"))) val instant = Instant.now() @@ -85,7 +66,6 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: } test("Profile case class creation works") { - import ProfileCredentialsProvider.* val properties = Map( "aws_access_key_id" -> "AKIAIOSFODNN7EXAMPLE", @@ -97,52 +77,13 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: assertEquals(profile.properties, properties) } - // Test basic functionality that doesn't require file system access - test("ProfileCredentialsProvider implements AwsCredentialsProvider trait") { - given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO - - for - provider <- ProfileCredentialsProvider.default[IO]() - providerTrait: ldbc.amazon.identity.AwsCredentialsProvider[IO] = provider - yield - // Type check passes - provider implements the trait correctly - assert(providerTrait != null) - } - - test("ProfileCredentialsProvider default factory creates provider with correct profile name") { - given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO - - for - defaultProvider <- ProfileCredentialsProvider.default[IO]() - namedProvider <- ProfileCredentialsProvider.default[IO]("custom") - yield - // Both providers should be created successfully - assert(defaultProvider != null) - assert(namedProvider != null) - } - - // Test thread safety of provider creation - test("ProfileCredentialsProvider factory is thread-safe") { - given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO - - for providers <- IO.parSequenceN(10)((1 to 10).map(_ => ProfileCredentialsProvider.default[IO]()).toList) - yield - // All providers should be created successfully - providers.foreach(provider => assert(provider != null)) - } - test("ProfileCredentialsProvider handles various profile names correctly") { given SystemProperties[IO] = mockSystemProperties() - given Files[IO] = Files.forIO val profileNames = List("default", "dev", "staging", "production", "test-profile", "profile_with_underscores") - for providers <- IO.traverse(profileNames)(ProfileCredentialsProvider.default[IO](_)) - yield - // All providers should be created successfully regardless of profile name format - providers.foreach(provider => assert(provider != null)) - assertEquals(providers.length, profileNames.length) + assertIO( + IO.traverse(profileNames)(ProfileCredentialsProvider.default[IO](_)).map(_.length), + profileNames.length + ) } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala index f62c02595..e72ec7ecf 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala @@ -297,19 +297,6 @@ class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: } } - test("default factory creates WebIdentityCredentialsUtils with proper STS client") { - val mockHttpClient: HttpClient[IO] = new HttpClient[IO]: - override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = - IO.raiseError(new UnsupportedOperationException("HTTP requests not supported in test")) - override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = - IO.raiseError(new UnsupportedOperationException("HTTP requests not supported in test")) - override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = - IO.raiseError(new UnsupportedOperationException("HTTP requests not supported in test")) - - val webIdentityUtils = WebIdentityCredentialsUtils.default[IO]("us-east-1", mockHttpClient) - assert(webIdentityUtils != null) - } - test("reads token from file with correct trimming") { val tokenWithWhitespace = s" $validJwtToken \n\t" From 4d4b697542d05a9cd830a930974467a8fd16dc53 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 7 Dec 2025 20:53:19 +0900 Subject: [PATCH 158/215] Added sbt project description --- build.sbt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build.sbt b/build.sbt index 20f9c0d08..ae7ece4b1 100644 --- a/build.sbt +++ b/build.sbt @@ -145,7 +145,7 @@ lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) - .module("aws-authentication-plugin", "") + .module("aws-authentication-plugin", "Project for the plugin used with Aurora IAM authentication") .settings( libraryDependencies ++= Seq( "co.fs2" %%% "fs2-core" % "3.12.2", @@ -220,7 +220,7 @@ lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) ) .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) .nativeSettings(Test / nativeBrewFormulas += "s2n") - .dependsOn(connector, queryBuilder, schema, awsAuthenticationPlugin) + .dependsOn(connector, queryBuilder, schema) .enablePlugins(NoPublishPlugin) lazy val benchmark = (project in file("benchmark")) From 0c19d577bdf9410085ac990f027740ee51e5e671 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 19:53:22 +0900 Subject: [PATCH 159/215] Delete unused --- .../internal/WebIdentityCredentialsUtilsTest.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala index e72ec7ecf..3d47e0226 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala @@ -6,18 +6,16 @@ package ldbc.amazon.auth.credentials.internal -import java.net.URI import java.time.Instant import cats.effect.{ IO, Ref } -import cats.effect.std.UUIDGen import fs2.io.file.Files import munit.CatsEffectSuite import ldbc.amazon.auth.credentials.{ AwsSessionCredentials, WebIdentityTokenCredentialProperties } -import ldbc.amazon.client.{ HttpClient, HttpResponse, StsClient } +import ldbc.amazon.client.StsClient import ldbc.amazon.exception.{ InvalidTokenException, StsException } class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: From 1620f2b09198514c1c1f7659508738865089c41a Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 21:30:17 +0900 Subject: [PATCH 160/215] Change not use null --- .../ldbc/amazon/client/SimpleHttpClient.scala | 14 +++++++------- .../ldbc/amazon/client/SimpleHttpClient.scala | 14 +++++++------- .../ldbc/amazon/client/SimpleHttpClient.scala | 14 +++++++------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index c72ca51be..a40507b8c 100644 --- a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -43,7 +43,7 @@ class SimpleHttpClient[F[_]: Network: Async]( extends HttpClient[F]: private def isHttps(uri: URI): Boolean = - uri.getScheme != null && uri.getScheme.toLowerCase == "https" + Option(uri.getScheme).exists(_.toLowerCase == "https") private def getDefaultPort(uri: URI): Int = if uri.getPort > 0 then uri.getPort @@ -51,20 +51,20 @@ class SimpleHttpClient[F[_]: Network: Async]( else 80 private def validateScheme(uri: URI): F[Unit] = - uri.getScheme match - case null => ev.raiseError(new SdkClientException("URI scheme is required")) - case scheme if scheme.toLowerCase == "http" => + Option(uri.getScheme) match + case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints ev.unit - case scheme if scheme.toLowerCase == "https" => ev.unit - case unsupported => + case Some(scheme) if scheme.toLowerCase == "https" => ev.unit + case Some(unsupported) => ev.raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) private def validateSecurityRequirements(uri: URI): F[Unit] = // AWS endpoints should always use HTTPS - if uri.getHost != null && uri.getHost.contains(".amazonaws.com") && !isHttps(uri) then + if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then ev.raiseError( new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") ) diff --git a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index d2eba08ae..8b6943f63 100644 --- a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -45,7 +45,7 @@ class SimpleHttpClient[F[_]: Network: Async]( extends HttpClient[F]: private def isHttps(uri: URI): Boolean = - uri.getScheme != null && uri.getScheme.toLowerCase == "https" + Option(uri.getScheme).exists(_.toLowerCase == "https") private def getDefaultPort(uri: URI): Int = if uri.getPort > 0 then uri.getPort @@ -53,20 +53,20 @@ class SimpleHttpClient[F[_]: Network: Async]( else 80 private def validateScheme(uri: URI): F[Unit] = - uri.getScheme match - case null => ev.raiseError(new SdkClientException("URI scheme is required")) - case scheme if scheme.toLowerCase == "http" => + Option(uri.getScheme) match + case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints ev.unit - case scheme if scheme.toLowerCase == "https" => ev.unit - case unsupported => + case Some(scheme) if scheme.toLowerCase == "https" => ev.unit + case Some(unsupported) => ev.raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) private def validateSecurityRequirements(uri: URI): F[Unit] = // AWS endpoints should always use HTTPS - if uri.getHost != null && uri.getHost.contains(".amazonaws.com") && !isHttps(uri) then + if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then ev.raiseError( new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") ) diff --git a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 151cf203f..ed79fa579 100644 --- a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -43,7 +43,7 @@ class SimpleHttpClient[F[_]: Network: Async]( extends HttpClient[F]: private def isHttps(uri: URI): Boolean = - uri.getScheme != null && uri.getScheme.toLowerCase == "https" + Option(uri.getScheme).exists(_.toLowerCase == "https") private def getDefaultPort(uri: URI): Int = if uri.getPort > 0 then uri.getPort @@ -51,20 +51,20 @@ class SimpleHttpClient[F[_]: Network: Async]( else 80 private def validateScheme(uri: URI): F[Unit] = - uri.getScheme match - case null => ev.raiseError(new SdkClientException("URI scheme is required")) - case scheme if scheme.toLowerCase == "http" => + Option(uri.getScheme) match + case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints ev.unit - case scheme if scheme.toLowerCase == "https" => ev.unit - case unsupported => + case Some(scheme) if scheme.toLowerCase == "https" => ev.unit + case Some(unsupported) => ev.raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) private def validateSecurityRequirements(uri: URI): F[Unit] = // AWS endpoints should always use HTTPS - if uri.getHost != null && uri.getHost.contains(".amazonaws.com") && !isHttps(uri) then + if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then ev.raiseError( new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") ) From 8251e2fb7410e2b9bb8abdc9670843bfdaf019b9 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 21:31:32 +0900 Subject: [PATCH 161/215] Action sbt scalafmtAll --- .../src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala | 4 ++-- .../src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala | 4 ++-- .../src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index a40507b8c..53f0a9504 100644 --- a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -52,12 +52,12 @@ class SimpleHttpClient[F[_]: Network: Async]( private def validateScheme(uri: URI): F[Unit] = Option(uri.getScheme) match - case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case None => ev.raiseError(new SdkClientException("URI scheme is required")) case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints ev.unit case Some(scheme) if scheme.toLowerCase == "https" => ev.unit - case Some(unsupported) => + case Some(unsupported) => ev.raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) diff --git a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 8b6943f63..ca34fb050 100644 --- a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -54,12 +54,12 @@ class SimpleHttpClient[F[_]: Network: Async]( private def validateScheme(uri: URI): F[Unit] = Option(uri.getScheme) match - case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case None => ev.raiseError(new SdkClientException("URI scheme is required")) case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints ev.unit case Some(scheme) if scheme.toLowerCase == "https" => ev.unit - case Some(unsupported) => + case Some(unsupported) => ev.raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) diff --git a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index ed79fa579..11cc6d9ac 100644 --- a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -52,12 +52,12 @@ class SimpleHttpClient[F[_]: Network: Async]( private def validateScheme(uri: URI): F[Unit] = Option(uri.getScheme) match - case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case None => ev.raiseError(new SdkClientException("URI scheme is required")) case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints ev.unit case Some(scheme) if scheme.toLowerCase == "https" => ev.unit - case Some(unsupported) => + case Some(unsupported) => ev.raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) From 2dce21a4ceeffa67ef6dc6601f4a8125ec3d6001 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 21:41:52 +0900 Subject: [PATCH 162/215] Change use MonadThrow -> Async --- .../ldbc/amazon/client/SimpleHttpClient.scala | 30 +++++++++---------- .../ldbc/amazon/client/SimpleHttpClient.scala | 30 +++++++++---------- .../ldbc/amazon/client/SimpleHttpClient.scala | 30 +++++++++---------- 3 files changed, 42 insertions(+), 48 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 53f0a9504..31f5fa55d 100644 --- a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -13,7 +13,6 @@ import scala.concurrent.duration.* import com.comcast.ip4s.* import cats.syntax.all.* -import cats.MonadThrow import cats.effect.* import cats.effect.syntax.all.* @@ -39,8 +38,7 @@ import ldbc.amazon.exception.* class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, readTimeout: Duration -)(using ev: MonadThrow[F]) - extends HttpClient[F]: +) extends HttpClient[F]: private def isHttps(uri: URI): Boolean = Option(uri.getScheme).exists(_.toLowerCase == "https") @@ -52,23 +50,23 @@ class SimpleHttpClient[F[_]: Network: Async]( private def validateScheme(uri: URI): F[Unit] = Option(uri.getScheme) match - case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints - ev.unit - case Some(scheme) if scheme.toLowerCase == "https" => ev.unit + Async[F].unit + case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit case Some(unsupported) => - ev.raiseError( + Async[F].raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) private def validateSecurityRequirements(uri: URI): F[Unit] = // AWS endpoints should always use HTTPS if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then - ev.raiseError( + Async[F].raiseError( new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") ) - else ev.unit + else Async[F].unit private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = if isSecure then @@ -123,8 +121,8 @@ class SimpleHttpClient[F[_]: Network: Async]( private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = for - h <- ev.fromOption(Host.fromString(host), new SdkClientException("Invalid host")) - p <- ev.fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) yield SocketAddress(h, p) private def sendRequest( @@ -162,9 +160,9 @@ class SimpleHttpClient[F[_]: Network: Async]( line.split(" ").toList match case _ :: code :: _ => code.toIntOption match - case Some(c) => ev.pure(c) - case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) - case _ => ev.raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + case Some(c) => Async[F].pure(c) + case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) private def parseHeaderLine(line: String): Option[(String, String)] = line.split(": ", 2).toList match @@ -181,9 +179,9 @@ class SimpleHttpClient[F[_]: Network: Async]( val headers = headerLines.flatMap(parseHeaderLine).toMap val body = bodyLines.drop(1).mkString("\r\n") // drop empty line - ev.pure(HttpResponse(statusCode, headers, body)) + Async[F].pure(HttpResponse(statusCode, headers, body)) case _ => - ev.raiseError(CredentialsFetchError("Empty response")) + Async[F].raiseError(CredentialsFetchError("Empty response")) private def receiveResponse(socket: Socket[F]): F[HttpResponse] = socket.reads diff --git a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index ca34fb050..1c52632a7 100644 --- a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -15,7 +15,6 @@ import scala.concurrent.duration.* import com.comcast.ip4s.* import cats.syntax.all.* -import cats.MonadThrow import cats.effect.* import cats.effect.syntax.all.* @@ -41,8 +40,7 @@ import ldbc.amazon.exception.* class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, readTimeout: Duration -)(using ev: MonadThrow[F]) - extends HttpClient[F]: +) extends HttpClient[F]: private def isHttps(uri: URI): Boolean = Option(uri.getScheme).exists(_.toLowerCase == "https") @@ -54,23 +52,23 @@ class SimpleHttpClient[F[_]: Network: Async]( private def validateScheme(uri: URI): F[Unit] = Option(uri.getScheme) match - case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints - ev.unit - case Some(scheme) if scheme.toLowerCase == "https" => ev.unit + Async[F].unit + case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit case Some(unsupported) => - ev.raiseError( + Async[F].raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) private def validateSecurityRequirements(uri: URI): F[Unit] = // AWS endpoints should always use HTTPS if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then - ev.raiseError( + Async[F].raiseError( new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") ) - else ev.unit + else Async[F].unit private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = if isSecure then @@ -125,8 +123,8 @@ class SimpleHttpClient[F[_]: Network: Async]( private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = for - h <- ev.fromOption(Host.fromString(host), new SdkClientException("Invalid host")) - p <- ev.fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) yield SocketAddress(h, p) private def sendRequest( @@ -164,9 +162,9 @@ class SimpleHttpClient[F[_]: Network: Async]( line.split(" ").toList match case _ :: code :: _ => code.toIntOption match - case Some(c) => ev.pure(c) - case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) - case _ => ev.raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + case Some(c) => Async[F].pure(c) + case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) private def parseHeaderLine(line: String): Option[(String, String)] = line.split(": ", 2).toList match @@ -183,9 +181,9 @@ class SimpleHttpClient[F[_]: Network: Async]( val headers = headerLines.flatMap(parseHeaderLine).toMap val body = bodyLines.drop(1).mkString("\r\n") // drop empty line - ev.pure(HttpResponse(statusCode, headers, body)) + Async[F].pure(HttpResponse(statusCode, headers, body)) case _ => - ev.raiseError(CredentialsFetchError("Empty response")) + Async[F].raiseError(CredentialsFetchError("Empty response")) private def receiveResponse(socket: Socket[F]): F[HttpResponse] = socket.reads diff --git a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 11cc6d9ac..3a3792486 100644 --- a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -13,7 +13,6 @@ import scala.concurrent.duration.* import com.comcast.ip4s.* import cats.syntax.all.* -import cats.MonadThrow import cats.effect.* import cats.effect.syntax.all.* @@ -39,8 +38,7 @@ import ldbc.amazon.exception.* class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, readTimeout: Duration -)(using ev: MonadThrow[F]) - extends HttpClient[F]: +) extends HttpClient[F]: private def isHttps(uri: URI): Boolean = Option(uri.getScheme).exists(_.toLowerCase == "https") @@ -52,23 +50,23 @@ class SimpleHttpClient[F[_]: Network: Async]( private def validateScheme(uri: URI): F[Unit] = Option(uri.getScheme) match - case None => ev.raiseError(new SdkClientException("URI scheme is required")) + case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) case Some(scheme) if scheme.toLowerCase == "http" => // Log warning for HTTP usage, but allow it for non-sensitive endpoints - ev.unit - case Some(scheme) if scheme.toLowerCase == "https" => ev.unit + Async[F].unit + case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit case Some(unsupported) => - ev.raiseError( + Async[F].raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) private def validateSecurityRequirements(uri: URI): F[Unit] = // AWS endpoints should always use HTTPS if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then - ev.raiseError( + Async[F].raiseError( new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") ) - else ev.unit + else Async[F].unit private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = if isSecure then @@ -123,8 +121,8 @@ class SimpleHttpClient[F[_]: Network: Async]( private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = for - h <- ev.fromOption(Host.fromString(host), new SdkClientException("Invalid host")) - p <- ev.fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) yield SocketAddress(h, p) private def sendRequest( @@ -162,9 +160,9 @@ class SimpleHttpClient[F[_]: Network: Async]( line.split(" ").toList match case _ :: code :: _ => code.toIntOption match - case Some(c) => ev.pure(c) - case None => ev.raiseError(new CredentialsFetchError(s"Invalid status code: $code")) - case _ => ev.raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + case Some(c) => Async[F].pure(c) + case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) private def parseHeaderLine(line: String): Option[(String, String)] = line.split(": ", 2).toList match @@ -181,9 +179,9 @@ class SimpleHttpClient[F[_]: Network: Async]( val headers = headerLines.flatMap(parseHeaderLine).toMap val body = bodyLines.drop(1).mkString("\r\n") // drop empty line - ev.pure(HttpResponse(statusCode, headers, body)) + Async[F].pure(HttpResponse(statusCode, headers, body)) case _ => - ev.raiseError(CredentialsFetchError("Empty response")) + Async[F].raiseError(CredentialsFetchError("Empty response")) private def receiveResponse(socket: Socket[F]): F[HttpResponse] = socket.reads From fdff2d02ffa61863b1982d79f10f973cea9edbec Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 22:49:13 +0900 Subject: [PATCH 163/215] Create BasedHttpClient --- .../ldbc/amazon/client/SimpleHttpClient.scala | 171 +--------------- .../ldbc/amazon/client/SimpleHttpClient.scala | 171 +--------------- .../ldbc/amazon/client/SimpleHttpClient.scala | 171 +--------------- .../DefaultCredentialsProviderChain.scala | 2 +- .../InstanceProfileCredentialsProvider.scala | 4 +- .../ldbc/amazon/client/BasedHttpClient.scala | 188 ++++++++++++++++++ .../ldbc/amazon/client/StsClientTest.scala | 2 +- 7 files changed, 201 insertions(+), 508 deletions(-) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 31f5fa55d..df1c98246 100644 --- a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -6,23 +6,15 @@ package ldbc.amazon.client -import java.net.URI - import scala.concurrent.duration.* import com.comcast.ip4s.* -import cats.syntax.all.* - import cats.effect.* -import cats.effect.syntax.all.* -import fs2.* import fs2.io.net.* import fs2.io.net.tls.* -import ldbc.amazon.exception.* - /** * Secure HTTP client that supports both HTTP and HTTPS protocols. * @@ -35,40 +27,12 @@ import ldbc.amazon.exception.* * This addresses the security vulnerability where AWS credentials * could be sent over unencrypted HTTP connections. */ -class SimpleHttpClient[F[_]: Network: Async]( +final case class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, readTimeout: Duration -) extends HttpClient[F]: - - private def isHttps(uri: URI): Boolean = - Option(uri.getScheme).exists(_.toLowerCase == "https") - - private def getDefaultPort(uri: URI): Int = - if uri.getPort > 0 then uri.getPort - else if isHttps(uri) then 443 - else 80 - - private def validateScheme(uri: URI): F[Unit] = - Option(uri.getScheme) match - case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) - case Some(scheme) if scheme.toLowerCase == "http" => - // Log warning for HTTP usage, but allow it for non-sensitive endpoints - Async[F].unit - case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit - case Some(unsupported) => - Async[F].raiseError( - new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") - ) +) extends BasedHttpClient[F]: - private def validateSecurityRequirements(uri: URI): F[Unit] = - // AWS endpoints should always use HTTPS - if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then - Async[F].raiseError( - new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") - ) - else Async[F].unit - - private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + override def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = if isSecure then for socket <- Network[F].client(address) @@ -79,132 +43,3 @@ class SimpleHttpClient[F[_]: Network: Async]( .build yield tlsSocket else Network[F].client(address) - - override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) - yield response - - override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) - yield response - - override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) - yield response - - private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = - for - h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) - p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) - yield SocketAddress(h, p) - - private def sendRequest( - socket: Socket[F], - method: String, - host: String, - port: Int, - isSecure: Boolean, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[Unit] = - val defaultPort = if isSecure then 443 else 80 - val hostHeader = if port == defaultPort then host else s"$host:$port" - val contentHeaders = body match { - case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) - case None => Map.empty - } - val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") - - val requestLine = s"$method $path HTTP/1.1\r\n" - val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString - val requestWithHeaders = requestLine + headerLines + "\r\n" - val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) - - Stream - .emit(fullRequest) - .through(text.utf8.encode) - .through(socket.writes) - .compile - .drain - - private def parseStatusLine(line: String): F[Int] = - // "HTTP/1.1 200 OK" -> 200 - line.split(" ").toList match - case _ :: code :: _ => - code.toIntOption match - case Some(c) => Async[F].pure(c) - case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) - case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) - - private def parseHeaderLine(line: String): Option[(String, String)] = - line.split(": ", 2).toList match - case key :: value :: Nil => Some(key -> value) - case _ => None - - private def parseHttpResponse(raw: String): F[HttpResponse] = - val lines = raw.split("\r\n").toList - - lines match - case statusLine :: rest => - parseStatusLine(statusLine).flatMap: statusCode => - val (headerLines, bodyLines) = rest.span(_.nonEmpty) - val headers = headerLines.flatMap(parseHeaderLine).toMap - val body = bodyLines.drop(1).mkString("\r\n") // drop empty line - - Async[F].pure(HttpResponse(statusCode, headers, body)) - case _ => - Async[F].raiseError(CredentialsFetchError("Empty response")) - - private def receiveResponse(socket: Socket[F]): F[HttpResponse] = - socket.reads - .through(text.utf8.decode) - .compile - .string - .flatMap(parseHttpResponse) - - private def makeRequest( - address: SocketAddress[Host], - host: String, - port: Int, - isSecure: Boolean, - method: String, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[HttpResponse] = - createSocket(address, isSecure, host) - .use { socket => - for - _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) - response <- receiveResponse(socket) - yield response - } - .timeout(connectTimeout + readTimeout) diff --git a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 1c52632a7..0ffe3efd7 100644 --- a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -6,25 +6,17 @@ package ldbc.amazon.client -import java.net.URI - import javax.net.ssl.SNIHostName import scala.concurrent.duration.* import com.comcast.ip4s.* -import cats.syntax.all.* - import cats.effect.* -import cats.effect.syntax.all.* -import fs2.* import fs2.io.net.* import fs2.io.net.tls.* -import ldbc.amazon.exception.* - /** * Secure HTTP client that supports both HTTP and HTTPS protocols. * @@ -37,40 +29,12 @@ import ldbc.amazon.exception.* * This addresses the security vulnerability where AWS credentials * could be sent over unencrypted HTTP connections. */ -class SimpleHttpClient[F[_]: Network: Async]( +final case class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, readTimeout: Duration -) extends HttpClient[F]: - - private def isHttps(uri: URI): Boolean = - Option(uri.getScheme).exists(_.toLowerCase == "https") - - private def getDefaultPort(uri: URI): Int = - if uri.getPort > 0 then uri.getPort - else if isHttps(uri) then 443 - else 80 - - private def validateScheme(uri: URI): F[Unit] = - Option(uri.getScheme) match - case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) - case Some(scheme) if scheme.toLowerCase == "http" => - // Log warning for HTTP usage, but allow it for non-sensitive endpoints - Async[F].unit - case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit - case Some(unsupported) => - Async[F].raiseError( - new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") - ) +) extends BasedHttpClient[F]: - private def validateSecurityRequirements(uri: URI): F[Unit] = - // AWS endpoints should always use HTTPS - if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then - Async[F].raiseError( - new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") - ) - else Async[F].unit - - private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + override def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = if isSecure then for socket <- Network[F].client(address) @@ -81,132 +45,3 @@ class SimpleHttpClient[F[_]: Network: Async]( .build yield tlsSocket else Network[F].client(address) - - override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) - yield response - - override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) - yield response - - override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) - yield response - - private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = - for - h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) - p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) - yield SocketAddress(h, p) - - private def sendRequest( - socket: Socket[F], - method: String, - host: String, - port: Int, - isSecure: Boolean, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[Unit] = - val defaultPort = if isSecure then 443 else 80 - val hostHeader = if port == defaultPort then host else s"$host:$port" - val contentHeaders = body match { - case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) - case None => Map.empty - } - val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") - - val requestLine = s"$method $path HTTP/1.1\r\n" - val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString - val requestWithHeaders = requestLine + headerLines + "\r\n" - val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) - - Stream - .emit(fullRequest) - .through(text.utf8.encode) - .through(socket.writes) - .compile - .drain - - private def parseStatusLine(line: String): F[Int] = - // "HTTP/1.1 200 OK" -> 200 - line.split(" ").toList match - case _ :: code :: _ => - code.toIntOption match - case Some(c) => Async[F].pure(c) - case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) - case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) - - private def parseHeaderLine(line: String): Option[(String, String)] = - line.split(": ", 2).toList match - case key :: value :: Nil => Some(key -> value) - case _ => None - - private def parseHttpResponse(raw: String): F[HttpResponse] = - val lines = raw.split("\r\n").toList - - lines match - case statusLine :: rest => - parseStatusLine(statusLine).flatMap: statusCode => - val (headerLines, bodyLines) = rest.span(_.nonEmpty) - val headers = headerLines.flatMap(parseHeaderLine).toMap - val body = bodyLines.drop(1).mkString("\r\n") // drop empty line - - Async[F].pure(HttpResponse(statusCode, headers, body)) - case _ => - Async[F].raiseError(CredentialsFetchError("Empty response")) - - private def receiveResponse(socket: Socket[F]): F[HttpResponse] = - socket.reads - .through(text.utf8.decode) - .compile - .string - .flatMap(parseHttpResponse) - - private def makeRequest( - address: SocketAddress[Host], - host: String, - port: Int, - isSecure: Boolean, - method: String, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[HttpResponse] = - createSocket(address, isSecure, host) - .use { socket => - for - _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) - response <- receiveResponse(socket) - yield response - } - .timeout(connectTimeout + readTimeout) diff --git a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala index 3a3792486..476c732d3 100644 --- a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -6,23 +6,15 @@ package ldbc.amazon.client -import java.net.URI - import scala.concurrent.duration.* import com.comcast.ip4s.* -import cats.syntax.all.* - import cats.effect.* -import cats.effect.syntax.all.* -import fs2.* import fs2.io.net.* import fs2.io.net.tls.* -import ldbc.amazon.exception.* - /** * Secure HTTP client that supports both HTTP and HTTPS protocols. * @@ -35,40 +27,12 @@ import ldbc.amazon.exception.* * This addresses the security vulnerability where AWS credentials * could be sent over unencrypted HTTP connections. */ -class SimpleHttpClient[F[_]: Network: Async]( +final case class SimpleHttpClient[F[_]: Network: Async]( connectTimeout: Duration, readTimeout: Duration -) extends HttpClient[F]: - - private def isHttps(uri: URI): Boolean = - Option(uri.getScheme).exists(_.toLowerCase == "https") - - private def getDefaultPort(uri: URI): Int = - if uri.getPort > 0 then uri.getPort - else if isHttps(uri) then 443 - else 80 - - private def validateScheme(uri: URI): F[Unit] = - Option(uri.getScheme) match - case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) - case Some(scheme) if scheme.toLowerCase == "http" => - // Log warning for HTTP usage, but allow it for non-sensitive endpoints - Async[F].unit - case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit - case Some(unsupported) => - Async[F].raiseError( - new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") - ) +) extends BasedHttpClient[F]: - private def validateSecurityRequirements(uri: URI): F[Unit] = - // AWS endpoints should always use HTTPS - if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then - Async[F].raiseError( - new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") - ) - else Async[F].unit - - private def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + override def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = if isSecure then for socket <- Network[F].client(address) @@ -79,132 +43,3 @@ class SimpleHttpClient[F[_]: Network: Async]( .build yield tlsSocket else Network[F].client(address) - - override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) - yield response - - override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) - yield response - - override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = - for - _ <- validateScheme(uri) - _ <- validateSecurityRequirements(uri) - host = uri.getHost - port = getDefaultPort(uri) - isSecure = isHttps(uri) - path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") - address <- resolveAddress(host, port) - response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) - yield response - - private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = - for - h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) - p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) - yield SocketAddress(h, p) - - private def sendRequest( - socket: Socket[F], - method: String, - host: String, - port: Int, - isSecure: Boolean, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[Unit] = - val defaultPort = if isSecure then 443 else 80 - val hostHeader = if port == defaultPort then host else s"$host:$port" - val contentHeaders = body match { - case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) - case None => Map.empty - } - val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") - - val requestLine = s"$method $path HTTP/1.1\r\n" - val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString - val requestWithHeaders = requestLine + headerLines + "\r\n" - val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) - - Stream - .emit(fullRequest) - .through(text.utf8.encode) - .through(socket.writes) - .compile - .drain - - private def parseStatusLine(line: String): F[Int] = - // "HTTP/1.1 200 OK" -> 200 - line.split(" ").toList match - case _ :: code :: _ => - code.toIntOption match - case Some(c) => Async[F].pure(c) - case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) - case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) - - private def parseHeaderLine(line: String): Option[(String, String)] = - line.split(": ", 2).toList match - case key :: value :: Nil => Some(key -> value) - case _ => None - - private def parseHttpResponse(raw: String): F[HttpResponse] = - val lines = raw.split("\r\n").toList - - lines match - case statusLine :: rest => - parseStatusLine(statusLine).flatMap: statusCode => - val (headerLines, bodyLines) = rest.span(_.nonEmpty) - val headers = headerLines.flatMap(parseHeaderLine).toMap - val body = bodyLines.drop(1).mkString("\r\n") // drop empty line - - Async[F].pure(HttpResponse(statusCode, headers, body)) - case _ => - Async[F].raiseError(CredentialsFetchError("Empty response")) - - private def receiveResponse(socket: Socket[F]): F[HttpResponse] = - socket.reads - .through(text.utf8.decode) - .compile - .string - .flatMap(parseHttpResponse) - - private def makeRequest( - address: SocketAddress[Host], - host: String, - port: Int, - isSecure: Boolean, - method: String, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[HttpResponse] = - createSocket(address, isSecure, host) - .use { socket => - for - _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) - response <- receiveResponse(socket) - yield response - } - .timeout(connectTimeout + readTimeout) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala index 7800da9bc..baa345757 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -178,7 +178,7 @@ object DefaultCredentialsProviderChain: def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( region: String ): DefaultCredentialsProviderChain[F] = - val httpClient = new SimpleHttpClient[F]( + val httpClient = SimpleHttpClient[F]( connectTimeout = 1.second, readTimeout = 2.seconds ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala index 45278e7f8..d32cbaea4 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -391,7 +391,7 @@ object InstanceProfileCredentialsProvider: */ private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = Async[F].pure( - new SimpleHttpClient[F]( + SimpleHttpClient[F]( connectTimeout = 2.seconds, readTimeout = 5.seconds ) @@ -418,7 +418,7 @@ object InstanceProfileCredentialsProvider: yield available private def checkMetadataServiceAvailability[F[_]: Network: Async](): F[Boolean] = - val httpClient = new SimpleHttpClient[F]( + val httpClient = SimpleHttpClient[F]( connectTimeout = 1.second, readTimeout = 2.seconds ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala new file mode 100644 index 000000000..89a25c326 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala @@ -0,0 +1,188 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.syntax.all.* + +import cats.effect.* +import cats.effect.syntax.all.* + +import fs2.* +import fs2.io.net.* + +import ldbc.amazon.exception.* + +private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: + + def connectTimeout: Duration + def readTimeout: Duration + + private def isHttps(uri: URI): Boolean = + Option(uri.getScheme).exists(_.toLowerCase == "https") + + private def getDefaultPort(uri: URI): Int = + if uri.getPort > 0 then uri.getPort + else if isHttps(uri) then 443 + else 80 + + private def validateScheme(uri: URI): F[Unit] = + Option(uri.getScheme) match + case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) + case Some(scheme) if scheme.toLowerCase == "http" => + // Log warning for HTTP usage, but allow it for non-sensitive endpoints + Async[F].unit + case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit + case Some(unsupported) => + Async[F].raiseError( + new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") + ) + + private def validateSecurityRequirements(uri: URI): F[Unit] = + // AWS endpoints should always use HTTPS + if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then + Async[F].raiseError( + new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${uri.getScheme}://${uri.getHost}") + ) + else Async[F].unit + + def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] + + + override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) + yield response + + override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + yield response + + override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) + yield response + + private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = + for + h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + yield SocketAddress(h, p) + + private def sendRequest( + socket: Socket[F], + method: String, + host: String, + port: Int, + isSecure: Boolean, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[Unit] = + val defaultPort = if isSecure then 443 else 80 + val hostHeader = if port == defaultPort then host else s"$host:$port" + val contentHeaders = body match { + case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) + case None => Map.empty + } + val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") + + val requestLine = s"$method $path HTTP/1.1\r\n" + val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString + val requestWithHeaders = requestLine + headerLines + "\r\n" + val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) + + Stream + .emit(fullRequest) + .through(text.utf8.encode) + .through(socket.writes) + .compile + .drain + + private def parseStatusLine(line: String): F[Int] = + // "HTTP/1.1 200 OK" -> 200 + line.split(" ").toList match + case _ :: code :: _ => + code.toIntOption match + case Some(c) => Async[F].pure(c) + case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + + private def parseHeaderLine(line: String): Option[(String, String)] = + line.split(": ", 2).toList match + case key :: value :: Nil => Some(key -> value) + case _ => None + + private def parseHttpResponse(raw: String): F[HttpResponse] = + val lines = raw.split("\r\n").toList + + lines match + case statusLine :: rest => + parseStatusLine(statusLine).flatMap: statusCode => + val (headerLines, bodyLines) = rest.span(_.nonEmpty) + val headers = headerLines.flatMap(parseHeaderLine).toMap + val body = bodyLines.drop(1).mkString("\r\n") // drop empty line + + Async[F].pure(HttpResponse(statusCode, headers, body)) + case _ => + Async[F].raiseError(CredentialsFetchError("Empty response")) + + private def receiveResponse(socket: Socket[F]): F[HttpResponse] = + socket.reads + .through(text.utf8.decode) + .compile + .string + .flatMap(parseHttpResponse) + + private def makeRequest( + address: SocketAddress[Host], + host: String, + port: Int, + isSecure: Boolean, + method: String, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[HttpResponse] = + createSocket(address, isSecure, host) + .use { socket => + for + _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) + response <- receiveResponse(socket) + yield response + } + .timeout(connectTimeout + readTimeout) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala index 4490aeb52..e427e2029 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala @@ -32,7 +32,7 @@ class StsClientTest extends CatsEffectSuite: // HTTP client for testing private def httpClient: SimpleHttpClient[IO] = - new SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) + SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) // STS client configured for LocalStack private def localStackStsClient: StsClient[IO] = StsClient.build(localStackEndpoint, httpClient) From e4361b785c2a9ccf6769da38cf9c425f6052b52d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 22:49:29 +0900 Subject: [PATCH 164/215] Action sbt scalafmtAll --- .../ldbc/amazon/client/BasedHttpClient.scala | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala index 89a25c326..56c89fea3 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala @@ -25,7 +25,7 @@ import ldbc.amazon.exception.* private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: def connectTimeout: Duration - def readTimeout: Duration + def readTimeout: Duration private def isHttps(uri: URI): Boolean = Option(uri.getScheme).exists(_.toLowerCase == "https") @@ -42,7 +42,7 @@ private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: // Log warning for HTTP usage, but allow it for non-sensitive endpoints Async[F].unit case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit - case Some(unsupported) => + case Some(unsupported) => Async[F].raiseError( new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") ) @@ -51,13 +51,12 @@ private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: // AWS endpoints should always use HTTPS if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then Async[F].raiseError( - new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${uri.getScheme}://${uri.getHost}") + new SdkClientException(s"AWS endpoints require HTTPS. Attempted to use: ${ uri.getScheme }://${ uri.getHost }") ) else Async[F].unit def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] - override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = for _ <- validateScheme(uri) @@ -66,7 +65,7 @@ private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: port = getDefaultPort(uri) isSecure = isHttps(uri) path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) yield response @@ -79,7 +78,7 @@ private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: port = getDefaultPort(uri) isSecure = isHttps(uri) path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) yield response @@ -92,7 +91,7 @@ private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: port = getDefaultPort(uri) isSecure = isHttps(uri) path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + - Option(uri.getQuery).map("?" + _).getOrElse("") + Option(uri.getQuery).map("?" + _).getOrElse("") address <- resolveAddress(host, port) response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) yield response @@ -104,15 +103,15 @@ private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: yield SocketAddress(h, p) private def sendRequest( - socket: Socket[F], - method: String, - host: String, - port: Int, - isSecure: Boolean, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[Unit] = + socket: Socket[F], + method: String, + host: String, + port: Int, + isSecure: Boolean, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[Unit] = val defaultPort = if isSecure then 443 else 80 val hostHeader = if port == defaultPort then host else s"$host:$port" val contentHeaders = body match { @@ -169,15 +168,15 @@ private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: .flatMap(parseHttpResponse) private def makeRequest( - address: SocketAddress[Host], - host: String, - port: Int, - isSecure: Boolean, - method: String, - path: String, - headers: Map[String, String], - body: Option[String] - ): F[HttpResponse] = + address: SocketAddress[Host], + host: String, + port: Int, + isSecure: Boolean, + method: String, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[HttpResponse] = createSocket(address, isSecure, host) .use { socket => for From 0f21e9de9d2467aaab0ec59eb1f0696a1fed27bf Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 23:02:21 +0900 Subject: [PATCH 165/215] Remove meaningless try-catch blocks --- .../scala/ldbc/amazon/client/StsClient.scala | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index b1deb01c5..1874ea6b0 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -93,7 +93,6 @@ object StsClient: ): F[AssumeRoleWithWebIdentityResponse] = for _ <- validateRoleArn(request.roleArn) - timestamp <- Concurrent[F].fromEither(getCurrentTimestamp()) sessionName <- request.roleSessionName.fold( UUIDGen[F].randomUUID.map(uuid => s"ldbc-session-$uuid") )(Concurrent[F].pure) @@ -111,7 +110,7 @@ object StsClient: headers = Map( "Content-Type" -> "application/x-www-form-urlencoded", "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRoleWithWebIdentity", - "X-Amz-Date" -> timestamp + "X-Amz-Date" -> getCurrentTimestamp ) response <- httpClient.post(URI.create(stsEndpoint), headers, requestBody) @@ -185,19 +184,13 @@ object StsClient: * yyyyMMddTHHmmssZ (e.g., "20231201T120000Z"). The timestamp is always * generated in UTC timezone. * - * @return Either an error if timestamp generation fails, or the formatted timestamp string + * @return formatted timestamp string */ - private def getCurrentTimestamp(): Either[Throwable, String] = - try { - Right( - DateTimeFormatter - .ofPattern("yyyyMMdd'T'HHmmss'Z'") - .withZone(ZoneOffset.UTC) - .format(Instant.now()) - ) - } catch { - case ex: Exception => Left(new SdkClientException("Failed to generate timestamp")) - } + private def getCurrentTimestamp: String = + DateTimeFormatter + .ofPattern("yyyyMMdd'T'HHmmss'Z'") + .withZone(ZoneOffset.UTC) + .format(Instant.now()) /** * Validates HTTP response status. @@ -210,15 +203,14 @@ object StsClient: * @return Unit if status is successful, raises StsException otherwise */ private def validateHttpResponse[F[_]: MonadThrow](response: HttpResponse): F[Unit] = - if response.statusCode >= 200 && response.statusCode < 300 then { + if response.statusCode >= 200 && response.statusCode < 300 then MonadThrow[F].unit - } else { + else MonadThrow[F].raiseError( new StsException( s"STS request failed with status ${ response.statusCode }: ${ response.body }" ) ) - } /** * Parses STS XML response to extract credentials. From d3290e6b958c39aa0573f2c41a8aa2b89c4dec55 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 23:08:40 +0900 Subject: [PATCH 166/215] Update ProfileCredentialsProviderTest --- .../scala/ldbc/amazon/client/StsClient.scala | 8 ++++---- .../ProfileCredentialsProviderTest.scala | 16 ++++++---------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 1874ea6b0..7601a864c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -16,7 +16,7 @@ import cats.MonadThrow import cats.effect.std.UUIDGen import cats.effect.Concurrent -import ldbc.amazon.exception.{ SdkClientException, StsException } +import ldbc.amazon.exception.StsException import ldbc.amazon.util.SimpleXmlParser /** @@ -179,14 +179,14 @@ object StsClient: /** * Gets current timestamp in AWS format. - * + * * Generates a timestamp string in the ISO 8601 format required by AWS: * yyyyMMddTHHmmssZ (e.g., "20231201T120000Z"). The timestamp is always * generated in UTC timezone. - * + * * @return formatted timestamp string */ - private def getCurrentTimestamp: String = + private def getCurrentTimestamp: String = DateTimeFormatter .ofPattern("yyyyMMdd'T'HHmmss'Z'") .withZone(ZoneOffset.UTC) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala index 856ba5aac..11a672595 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala @@ -15,8 +15,6 @@ import fs2.io.file.Files import munit.CatsEffectSuite -import ldbc.amazon.exception.SdkClientException - import ProfileCredentialsProvider.* class ProfileCredentialsProviderTest extends CatsEffectSuite: @@ -39,14 +37,12 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: test("ProfileCredentialsProvider fails when user.home is missing") { given SystemProperties[IO] = mockSystemProperties(homeDir = None) - for - provider <- ProfileCredentialsProvider.default[IO]() - result <- provider.resolveCredentials().attempt - yield result match - case Left(exception: SdkClientException) => - // Should fail because user.home is not available - assert(true) // Test passed - case _ => fail("Expected SdkClientException") + assertIOBoolean( + for + provider <- ProfileCredentialsProvider.default[IO]() + result <- provider.resolveCredentials().attempt + yield result.isLeft + ) } test("ProfileCredentialsProvider companion object methods work") { From 45978ba672c543cb08ced602c5366c3a70bcf36e Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 14 Dec 2025 23:09:00 +0900 Subject: [PATCH 167/215] Action sbt scalafmtAll --- .../shared/src/main/scala/ldbc/amazon/client/StsClient.scala | 3 +-- .../auth/credentials/ProfileCredentialsProviderTest.scala | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala index 7601a864c..80af00aee 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -203,8 +203,7 @@ object StsClient: * @return Unit if status is successful, raises StsException otherwise */ private def validateHttpResponse[F[_]: MonadThrow](response: HttpResponse): F[Unit] = - if response.statusCode >= 200 && response.statusCode < 300 then - MonadThrow[F].unit + if response.statusCode >= 200 && response.statusCode < 300 then MonadThrow[F].unit else MonadThrow[F].raiseError( new StsException( diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala index 11a672595..9346e8cf5 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala @@ -40,7 +40,7 @@ class ProfileCredentialsProviderTest extends CatsEffectSuite: assertIOBoolean( for provider <- ProfileCredentialsProvider.default[IO]() - result <- provider.resolveCredentials().attempt + result <- provider.resolveCredentials().attempt yield result.isLeft ) } From 50f054d2addecd851c8ae7e0786025be2f88cabf Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 22:52:13 +0900 Subject: [PATCH 168/215] Create authentication plugin project --- build.sbt | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 3819c64de..936a61c04 100644 --- a/build.sbt +++ b/build.sbt @@ -120,6 +120,21 @@ lazy val jdbcConnector = crossProject(JVMPlatform) .defaultSettings .dependsOn(core) +lazy val authenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlatform) + .crossType(CrossType.Full) + .module("authentication-plugin", "MySQL authentication plugin written in pure Scala3") + .settings( + libraryDependencies ++= Seq( + "org.typelevel" %%% "cats-core" % "2.10.0", + "org.scodec" %%% "scodec-bits" % "1.1.38", + ) + ) + .jsSettings( + Test / scalaJSLinkerConfig ~= (_.withModuleKind(ModuleKind.CommonJSModule)) + ) + .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) + .nativeSettings(Test / nativeBrewFormulas += "s2n") + lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) .module("connector", "MySQL connector written in pure Scala3") @@ -141,7 +156,7 @@ lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) ) .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) .nativeSettings(Test / nativeBrewFormulas += "s2n") - .dependsOn(core) + .dependsOn(core, authenticationPlugin) lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) @@ -159,6 +174,7 @@ lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativeP ) .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) .nativeSettings(Test / nativeBrewFormulas += "s2n") + .dependsOn(authenticationPlugin) lazy val plugin = LepusSbtPluginProject("ldbc-plugin", "plugin") .settings(description := "Projects that provide sbt plug-ins") @@ -418,6 +434,7 @@ lazy val ldbc = tlCrossRootProject schema, codegen, zioInterop, + authenticationPlugin, awsAuthenticationPlugin, plugin, tests, From 4d11608291000c00226ae5daed2991b9b03913bd Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 22:53:55 +0900 Subject: [PATCH 169/215] Create AuthenticationPlugin --- .../plugin/AuthenticationPlugin.scala | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/AuthenticationPlugin.scala diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/AuthenticationPlugin.scala b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/AuthenticationPlugin.scala new file mode 100644 index 000000000..5e6233930 --- /dev/null +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/AuthenticationPlugin.scala @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import scodec.bits.ByteVector + +/** + * A trait representing a MySQL authentication plugin for database connections. + * + * This trait defines the contract for various authentication mechanisms supported by MySQL, + * including traditional password-based authentication (mysql_native_password) and + * modern authentication methods like mysql_clear_password for IAM authentication. + * + * Authentication plugins are used during the MySQL handshake process to validate + * client credentials and establish secure database connections. + * + * @tparam F The effect type that wraps the authentication operations + */ +trait AuthenticationPlugin[F[_]]: + + /** + * The name of the authentication plugin as recognized by the MySQL server. + * + * Common plugin names include: + * - "mysql_native_password" for traditional SHA1-based password authentication + * - "mysql_clear_password" for plaintext password transmission over SSL + * - "caching_sha2_password" for SHA256-based password authentication + * - "mysql_old_password" for legacy MySQL authentication (deprecated) + * + * @return The plugin name string that identifies this authentication method + */ + def name: PluginName + + /** + * Indicates whether this authentication plugin requires a secure (encrypted) connection. + * + * Some authentication plugins, particularly those that transmit passwords in cleartext + * (like mysql_clear_password), require SSL/TLS encryption to ensure data security. + * Traditional hashing-based plugins may optionally use encryption but don't strictly require it. + * + * @return true if SSL/TLS connection is mandatory for this plugin, false otherwise + */ + def requiresConfidentiality: Boolean + + /** + * Processes the password according to the authentication plugin's requirements. + * + * Different authentication plugins handle passwords differently: + * - mysql_native_password: Performs SHA1-based hashing with the server's scramble + * - mysql_clear_password: Returns the password as plaintext bytes (requires SSL) + * - caching_sha2_password: Performs SHA256-based hashing with salt + * + * @param password The user's password in plaintext + * @param scramble The random challenge bytes sent by the MySQL server during handshake. + * Used as salt/seed for cryptographic hashing in most authentication methods. + * May be ignored by plugins that don't use server-side challenges. + * @return The processed password data wrapped in the effect type F, ready for transmission + * to the MySQL server during authentication + */ + def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] From 404e7a1c69a97ac479e01696bd6f8d64bcef2e1e Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 22:54:07 +0900 Subject: [PATCH 170/215] Create PluginName --- .../ldbc/authentication/plugin/pacakge.scala | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala new file mode 100644 index 000000000..329665b29 --- /dev/null +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala @@ -0,0 +1,45 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication + +package object plugin: + + opaque type PluginName = String + + // Standard MySQL Authentication Plugins + val MYSQL_CLEAR_PASSWORD: PluginName = "mysql_clear_password" + val MYSQL_NATIVE_PASSWORD: PluginName = "mysql_native_password" + val SHA256_PASSWORD: PluginName = "sha256_password" + val CACHING_SHA2_PASSWORD: PluginName = "caching_sha2_password" + + // Legacy Authentication Plugin (deprecated) + val MYSQL_OLD_PASSWORD: PluginName = "mysql_old_password" + + // Authentication plugins for external authentication + val AUTHENTICATION_WINDOWS: PluginName = "authentication_windows" + val AUTHENTICATION_PAM: PluginName = "authentication_pam" + val AUTHENTICATION_LDAP_SIMPLE: PluginName = "authentication_ldap_simple" + val AUTHENTICATION_LDAP_SASL: PluginName = "authentication_ldap_sasl" + + // Kerberos authentication + val AUTHENTICATION_KERBEROS: PluginName = "authentication_kerberos" + + // FIDO authentication (MySQL 8.0.27+) + val AUTHENTICATION_FIDO: PluginName = "authentication_fido" + + // Multi-factor authentication (MySQL 8.0.27+) + val AUTHENTICATION_WEBAUTHN: PluginName = "authentication_webauthn" + + // No login authentication plugin + val MYSQL_NO_LOGIN: PluginName = "mysql_no_login" + + // Test plugins (for testing purposes) + val TEST_PLUGIN_SERVER: PluginName = "test_plugin_server" + val DAEMON_EXAMPLE: PluginName = "daemon_example" + + // Socket peer-credential authentication (Unix socket) + val AUTH_SOCKET: PluginName = "auth_socket" From 05195808d7e8dd5fbf920c1dbb1045a36ffb7a47 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 22:54:19 +0900 Subject: [PATCH 171/215] Create MysqlClearPasswordPlugin --- .../plugin/MysqlClearPasswordPlugin.scala | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala new file mode 100644 index 000000000..5ee0f8e54 --- /dev/null +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala @@ -0,0 +1,28 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import cats.Applicative + +import scodec.bits.ByteVector + +import java.nio.charset.StandardCharsets + +trait MysqlClearPasswordPlugin[F[_]] extends AuthenticationPlugin[F]: + + override final def name: PluginName = MYSQL_CLEAR_PASSWORD + override final def requiresConfidentiality: Boolean = true + +object MysqlClearPasswordPlugin: + + private case class Impl[F[_]: Applicative]() extends MysqlClearPasswordPlugin[F]: + override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = + val result = if password.isEmpty then ByteVector.empty + else ByteVector(password.getBytes(StandardCharsets.UTF_8)) + Applicative[F].pure(result) + + def apply[F[_]: Applicative](): MysqlClearPasswordPlugin[F] = Impl[F]() From 604c33f4fd3b3b69b58833e9b289345a0ac44aa4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 22:54:39 +0900 Subject: [PATCH 172/215] Create EncryptPasswordPlugin --- .../plugin/EncryptPasswordPlugin.scala | 39 ++++++++++ .../plugin/EncryptPasswordPlugin.scala | 45 +++++++++++ .../plugin/EncryptPasswordPlugin.scala | 74 +++++++++++++++++++ .../ldbc/authentication/plugin/Openssl.scala | 64 ++++++++++++++++ 4 files changed, 222 insertions(+) create mode 100644 module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala create mode 100644 module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala create mode 100644 module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala create mode 100644 module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala diff --git a/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala b/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala new file mode 100644 index 000000000..03ef8ac7e --- /dev/null +++ b/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -0,0 +1,39 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import java.nio.charset.StandardCharsets + +import scala.scalajs.js +import scala.scalajs.js.typedarray.Uint8Array + +import scodec.bits.ByteVector + +trait EncryptPasswordPlugin: + + private val crypto = js.Dynamic.global.require("crypto") + + def transformation: String + + private def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = + val scrambleLength = scramble.length + (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray + + def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = + val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) + val mysqlScrambleBuff = xorString(input, scramble, input.length) + encryptWithRSAPublicKey( + mysqlScrambleBuff, + publicKeyString + ) + + private def encryptWithRSAPublicKey(input: Array[Byte], publicKey: String): Array[Byte] = + val encrypted = crypto.publicEncrypt( + publicKey, + ByteVector(input).toUint8Array + ) + ByteVector.view(encrypted.asInstanceOf[Uint8Array]).toArray diff --git a/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala b/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala new file mode 100644 index 000000000..f6a500695 --- /dev/null +++ b/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -0,0 +1,45 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import java.nio.charset.StandardCharsets +import java.security.interfaces.RSAPublicKey +import java.security.spec.X509EncodedKeySpec +import java.security.KeyFactory +import java.security.PublicKey +import java.util.Base64 + +import javax.crypto.Cipher + +trait EncryptPasswordPlugin: + + def transformation: String + + def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = + val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) + val mysqlScrambleBuff = xorString(input, scramble, input.length) + encryptWithRSAPublicKey( + mysqlScrambleBuff, + decodeRSAPublicKey(publicKeyString) + ) + + private def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = + val scrambleLength = scramble.length + (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray + + private def encryptWithRSAPublicKey(input: Array[Byte], key: PublicKey): Array[Byte] = + val cipher = Cipher.getInstance(transformation) + cipher.init(Cipher.ENCRYPT_MODE, key) + cipher.doFinal(input) + + private def decodeRSAPublicKey(key: String): RSAPublicKey = + val offset = key.indexOf("\n") + 1 + val len = key.indexOf("-----END PUBLIC KEY-----") - offset + val certificateData = Base64.getMimeDecoder.decode(key.substring(offset, offset + len)) + val spec = new X509EncodedKeySpec(certificateData) + val kf = KeyFactory.getInstance("RSA") + kf.generatePublic(spec).asInstanceOf[RSAPublicKey] diff --git a/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala new file mode 100644 index 000000000..322754420 --- /dev/null +++ b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -0,0 +1,74 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import java.nio.charset.StandardCharsets + +import scala.scalanative.unsafe.* +import scala.scalanative.unsigned.* + +import ldbc.authentication.plugin.Openssl.* + +trait EncryptPasswordPlugin: + + def transformation: String + + def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = + val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) + val mysqlScrambleBuff = xorString(input, scramble, input.length) + encryptWithRSAPublicKey( + mysqlScrambleBuff, + publicKeyString + ) + + private def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = + val scrambleLength = scramble.length + (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray + + private def encryptWithRSAPublicKey(input: Array[Byte], publicKey: String): Array[Byte] = + Zone { implicit zone => + val publicKeyCStr = toCString(publicKey) + val bio = BIO_new_mem_buf(publicKeyCStr, publicKey.length) + if bio == null then throw new RuntimeException("Failed to create a new memory BIO.") + + val evpPkey = PEM_read_bio_PUBKEY(bio, null, null, null) + if evpPkey == null then throw new RuntimeException("Failed to load public key.") + + val ctx = EVP_PKEY_CTX_new(evpPkey, null) + if EVP_PKEY_encrypt_init(ctx) <= 0 then throw new RuntimeException("Failed to initialize encryption context.") + + if EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) != 1 then { + throw new RuntimeException("Failed to set RSA padding.") + } + + val sha1Md = EVP_get_digestbyname(c"sha1") + if EVP_PKEY_CTX_set_rsa_oaep_md(ctx, sha1Md) != 1 then { + throw new RuntimeException("Failed to set OAEP hash function.") + } + + if EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, sha1Md) != 1 then { + throw new RuntimeException("Failed to set MGF1 hash function.") + } + + val inputBuf = alloc[UByte](input.length.toULong) + for i <- input.indices do !(inputBuf + i) = input(i).toUByte + + val outLen = stackalloc[CSize]() + !outLen = 0.toULong + + if EVP_PKEY_encrypt(ctx, null, outLen, inputBuf, input.length.toULong) <= 0 then + throw new RuntimeException("Failed to obtain the output buffer size required for encryption.") + + val encryptedBuf = alloc[UByte](!outLen) + if EVP_PKEY_encrypt(ctx, encryptedBuf, outLen, inputBuf, input.length.toULong) <= 0 then + throw new RuntimeException("Encryption failed.") + + val result = Array.fill[Byte]((!outLen).toInt)(0) + for i <- 0 until (!outLen).toInt do result(i) = (!(encryptedBuf + i)).toInt.toByte + + result + } diff --git a/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala new file mode 100644 index 000000000..1c21dca30 --- /dev/null +++ b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import org.typelevel.scalaccompat.annotation.* + +import scala.scalanative.unsafe.* +import scala.scalanative.unsigned.* + +@nowarn212 +@link("crypto") +@extern +private[ldbc] object Openssl: + + final val RSA_PKCS1_OAEP_PADDING = 4 + + final val EVP_MAX_MD_SIZE = 64 + + type EVP_MD + type EVP_MD_CTX + type ENGINE + type BIO + type EVP_PKEY + type EVP_PKEY_CTX + + type pem_password_cb = CFuncPtr4[Ptr[Byte], CInt, CInt, Ptr[Byte], CInt] + + def EVP_get_digestbyname(name: Ptr[CChar]): Ptr[EVP_MD] = extern + + def EVP_Digest( + data: Ptr[Byte], + count: CSize, + md: Ptr[Byte], + size: Ptr[CUnsignedInt], + `type`: Ptr[EVP_MD], + impl: Ptr[ENGINE] + ): CInt = extern + + def BIO_new_mem_buf(buf: Ptr[Byte], len: CInt): Ptr[BIO] = extern + + def PEM_read_bio_PUBKEY(bp: Ptr[BIO], x: Ptr[Ptr[EVP_PKEY]], cb: pem_password_cb, u: Ptr[Byte]): Ptr[EVP_PKEY] = + extern + + def EVP_PKEY_CTX_new(pkey: Ptr[EVP_PKEY], e: Ptr[ENGINE]): Ptr[EVP_PKEY_CTX] = extern + + def EVP_PKEY_encrypt_init(ctx: Ptr[EVP_PKEY_CTX]): CInt = extern + + def EVP_PKEY_CTX_set_rsa_padding(ctx: Ptr[EVP_PKEY_CTX], padding: CInt): CInt = extern + + def EVP_PKEY_CTX_set_rsa_oaep_md(ctx: Ptr[EVP_PKEY_CTX], md: Ptr[EVP_MD]): CInt = extern + + def EVP_PKEY_CTX_set_rsa_mgf1_md(ctx: Ptr[EVP_PKEY_CTX], md: Ptr[EVP_MD]): CInt = extern + + def EVP_PKEY_encrypt( + ctx: Ptr[EVP_PKEY_CTX], + out: Ptr[UByte], + outlen: Ptr[CSize], + in: Ptr[UByte], + inlen: CSize + ): CInt = extern From df73764b5eb708b7dd7377dcc46160b485c10d0f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 22:55:43 +0900 Subject: [PATCH 173/215] Delete unused Authentication --- .../net/protocol/Authentication.scala | 37 +------------------ 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala index 8df586eb7..104e9dde2 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala @@ -6,14 +6,6 @@ package ldbc.connector.net.protocol -import cats.effect.kernel.Sync - -import fs2.hashing.Hashing - -import ldbc.connector.authenticator.* -import ldbc.connector.exception.* -import ldbc.connector.util.Version - /** * Protocol to handle the Authentication Phase * @@ -28,34 +20,7 @@ import ldbc.connector.util.Version * @tparam F * The effect type */ -trait Authentication[F[_]: Hashing: Sync]: - - /** - * Determine the authentication plugin. - * - * @param pluginName - * Plugin name - * @param version - * MySQL Server version - */ - protected def determinatePlugin(pluginName: String, version: Version): Either[SQLException, AuthenticationPlugin[F]] = - pluginName match - case "mysql_clear_password" => Right(MysqlClearPasswordPlugin[F]()) - case "mysql_native_password" => Right(MysqlNativePasswordPlugin[F]()) - case "sha256_password" => Right(Sha256PasswordPlugin[F]()) - case "caching_sha2_password" => Right(CachingSha2PasswordPlugin[F](version)) - case unknown => - Left( - new SQLInvalidAuthorizationSpecException( - s"Unknown authentication plugin: $pluginName", - detail = Some( - "This error may be due to lack of support on the ldbc side or a newly added plugin on the MySQL side." - ), - hint = Some( - "Report Issues here: https://github.com/takapi327/ldbc/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=" - ) - ) - ) +trait Authentication[F[_]]: /** * Start the authentication process. From f998baf7d4b1d766002e9a6113f7c0ba01ffa1e0 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 23:38:03 +0900 Subject: [PATCH 174/215] Change use AuthenticationPlugin, EncryptPasswordPlugin for Sha256PasswordPlugin --- .../connector/authenticator/Sha256PasswordPlugin.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala index 7c77bf4bb..cb8b836fd 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala @@ -18,12 +18,11 @@ import cats.effect.kernel.Sync import fs2.hashing.HashAlgorithm import fs2.hashing.Hashing import fs2.Chunk -trait Sha256PasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F] with Sha256PasswordPluginPlatform[F]: - protected def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = - val scrambleLength = scramble.length - (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray - override def name: String = "sha256_password" +import ldbc.authentication.plugin.* + +trait Sha256PasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F], EncryptPasswordPlugin: + override def name: PluginName = SHA256_PASSWORD override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) From e675c3b11a41b2dea550de19763eaaf73b079a5f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 23:38:47 +0900 Subject: [PATCH 175/215] Change use PluginName --- .../connector/authenticator/Sha256PasswordPluginPlatform.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala b/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala index cc9df619d..bc61e94f4 100644 --- a/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala +++ b/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala @@ -4,6 +4,7 @@ * For more information see LICENSE or https://opensource.org/licenses/MIT */ +/* package ldbc.connector.authenticator import java.nio.charset.StandardCharsets @@ -37,3 +38,5 @@ trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => val kf = KeyFactory.getInstance("RSA") kf.generatePublic(spec).asInstanceOf[RSAPublicKey] } + + */ From 8e39800836bc22d1d31b80ff5bf51fc6c611eb5c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 23:39:21 +0900 Subject: [PATCH 176/215] Change use PluginName --- .../connector/authenticator/CachingSha2PasswordPlugin.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala index 84b405da8..dbe97fdbc 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala @@ -10,11 +10,13 @@ import cats.effect.kernel.Sync import fs2.hashing.Hashing +import ldbc.authentication.plugin.* + import ldbc.connector.util.Version trait CachingSha2PasswordPlugin[F[_]] extends Sha256PasswordPlugin[F]: - override def name: String = "caching_sha2_password" + override def name: PluginName = CACHING_SHA2_PASSWORD object CachingSha2PasswordPlugin: def apply[F[_]: Hashing: Sync](version: Version): CachingSha2PasswordPlugin[F] = From 1097123546b2ba9eaa1db1529f0afea842746243 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 23:39:43 +0900 Subject: [PATCH 177/215] Change use PluginName for MysqlNativePasswordPlugin --- .../connector/authenticator/MysqlNativePasswordPlugin.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala index d9c7dfb92..a4ba81fcc 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala @@ -19,9 +19,11 @@ import fs2.hashing.HashAlgorithm import fs2.hashing.Hashing import fs2.Chunk +import ldbc.authentication.plugin.* + class MysqlNativePasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F]: - override def name: String = "mysql_native_password" + override def name: PluginName = MYSQL_NATIVE_PASSWORD override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) From 065d7038cc45ee47ec2eced27ecded59c270105d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 23:40:27 +0900 Subject: [PATCH 178/215] Change to deprecated AuthenticationPlugin, MysqlClearPasswordPlugin --- .../ldbc/connector/authenticator/AuthenticationPlugin.scala | 1 + .../ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala index 459b7b830..a5d7a3470 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala @@ -20,6 +20,7 @@ import scodec.bits.ByteVector * * @tparam F The effect type that wraps the authentication operations */ +@deprecated("This plugin is deprecated. Please use ldbc.authentication.plugin.AuthenticationPlugin instead.", "0.5.0") trait AuthenticationPlugin[F[_]]: /** diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala index 8735ab59a..a7f2af42d 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala @@ -12,6 +12,7 @@ import scodec.bits.ByteVector import cats.effect.kernel.Sync +@deprecated("This plugin is deprecated. Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", "0.5.0") class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: override def name: String = "mysql_clear_password" @@ -21,4 +22,5 @@ class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: else Sync[F].delay(ByteVector(password.getBytes(StandardCharsets.UTF_8))) object MysqlClearPasswordPlugin: + @deprecated("This plugin is deprecated. Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", "0.5.0") def apply[F[_]: Sync](): MysqlClearPasswordPlugin[F] = new MysqlClearPasswordPlugin[F]() From 8f5f61972aa19034843c1af58ea2e1eb5f86f42f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Wed, 17 Dec 2025 23:41:26 +0900 Subject: [PATCH 179/215] Added plugins property --- .../scala/ldbc/connector/net/Protocol.scala | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index fe5a2103b..fea0885d1 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -25,7 +25,9 @@ import fs2.io.net.Socket import org.typelevel.otel4s.trace.{ Span, Tracer } import org.typelevel.otel4s.Attribute -import ldbc.connector.authenticator.* +import ldbc.authentication.plugin.* + +import ldbc.connector.authenticator.{ MysqlNativePasswordPlugin, Sha256PasswordPlugin, CachingSha2PasswordPlugin } import ldbc.connector.data.* import ldbc.connector.exception.* import ldbc.connector.net.packet.* @@ -138,7 +140,7 @@ object Protocol: private val SELECT_SERVER_VARIABLES_QUERY = "SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout" - private[ldbc] case class Impl[F[_]: Async: Tracer: Hashing]( + private[ldbc] case class Impl[F[_]: Async: Tracer]( initialPacket: InitialPacket, hostInfo: HostInfo, socket: PacketSocket[F], @@ -146,7 +148,8 @@ object Protocol: allowPublicKeyRetrieval: Boolean = false, capabilityFlags: Set[CapabilitiesFlags], sequenceIdRef: Ref[F, Byte], - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: Map[String, AuthenticationPlugin[F]] )(using ev: MonadError[F, Throwable], ex: Exchange[F]) extends Protocol[F]: @@ -360,7 +363,7 @@ object Protocol: * Authentication method Switch Request Packet */ private def changeAuthenticationMethod(switchRequestPacket: AuthSwitchRequestPacket, password: String): F[Unit] = - determinatePlugin(switchRequestPacket.pluginName, initialPacket.serverVersion) match + determinatePlugin(switchRequestPacket.pluginName) match case Left(error) => ev.raiseError(error) *> socket.send(ComQuitPacket()) case Right(plugin: CachingSha2PasswordPlugin[F]) => for @@ -422,7 +425,7 @@ object Protocol: * Scramble buffer for authentication payload */ private def allowPublicKeyRetrievalRequest( - plugin: Sha256PasswordPlugin[F], + plugin: EncryptPasswordPlugin, password: String, scrambleBuff: Array[Byte] ): F[Unit] = @@ -489,7 +492,7 @@ object Protocol: capabilityFlags, username, hashedPassword.length.toByte +: hashedPassword.toArray, - plugin.name, + plugin.name.toString, initialPacket.characterSet, hostInfo.database ) @@ -511,7 +514,7 @@ object Protocol: password ) case None => - determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match + determinatePlugin(initialPacket.authPlugin) match case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) case Right(plugin) => checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk( @@ -524,7 +527,7 @@ object Protocol: override def changeUser(user: String, password: String): F[Unit] = exchange[F, Unit](TelemetrySpanName.CHANGE_USER) { (span: Span[F]) => span.addAttributes(attributes*) *> ( - determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match + determinatePlugin(initialPacket.authPlugin) match case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) case Right(plugin) => for @@ -557,6 +560,19 @@ object Protocol: span.recordException(error) *> ev.raiseError(error) else ev.unit + private def determinatePlugin(pluginName: String): Either[SQLException, AuthenticationPlugin[F]] = + plugins.get(pluginName).toRight( + new SQLInvalidAuthorizationSpecException( + s"Unknown authentication plugin: $pluginName", + detail = Some( + "This error may be due to lack of support on the ldbc side or a newly added plugin on the MySQL side." + ), + hint = Some( + "Report Issues here: https://github.com/takapi327/ldbc/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=" + ) + ) + ) + def apply[F[_]: Async: Console: Tracer: Exchange: Hashing]( sockets: Resource[F, Socket[F]], hostInfo: HostInfo, @@ -565,7 +581,8 @@ object Protocol: allowPublicKeyRetrieval: Boolean = false, readTimeout: Duration, capabilitiesFlags: Set[CapabilitiesFlags], - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: Map[String, AuthenticationPlugin[F]] ): Resource[F, Protocol[F]] = for sequenceIdRef <- Resource.eval(Ref[F].of[Byte](0x01)) @@ -581,7 +598,8 @@ object Protocol: capabilitiesFlags, sequenceIdRef, initialPacketRef, - defaultAuthenticationPlugin + defaultAuthenticationPlugin, + plugins ) ) yield protocol @@ -594,7 +612,8 @@ object Protocol: capabilitiesFlags: Set[CapabilitiesFlags], sequenceIdRef: Ref[F, Byte], initialPacketRef: Ref[F, Option[InitialPacket]], - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: Map[String, AuthenticationPlugin[F]] )(using ev: Async[F]): F[Protocol[F]] = initialPacketRef.get.flatMap { case Some(initialPacket) => @@ -607,7 +626,12 @@ object Protocol: allowPublicKeyRetrieval, capabilitiesFlags, sequenceIdRef, - defaultAuthenticationPlugin + defaultAuthenticationPlugin, + Map( + MYSQL_NATIVE_PASSWORD.toString -> MysqlNativePasswordPlugin[F](), + SHA256_PASSWORD.toString -> Sha256PasswordPlugin[F](), + CACHING_SHA2_PASSWORD.toString -> CachingSha2PasswordPlugin[F](initialPacket.serverVersion) + ) ++ plugins ) ) case None => From 63e8a26173aff1eb7bd3e147787f2a1bef158354 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 18 Dec 2025 00:02:16 +0900 Subject: [PATCH 180/215] Added plugins property --- .../scala/ldbc/connector/Connection.scala | 19 +++++++--- .../ldbc/connector/MySQLDataSource.scala | 26 +++++++++++--- .../connector/pool/PooledDataSource.scala | 35 +++++++++++++++++-- 3 files changed, 69 insertions(+), 11 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index 440420e2c..dfd4d617f 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -24,7 +24,7 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.connector.authenticator.AuthenticationPlugin +import ldbc.authentication.plugin.* import ldbc.connector.data.* import ldbc.connector.exception.* import ldbc.connector.net.* @@ -77,7 +77,8 @@ object Connection: useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default[F, Unit]( host, port, @@ -93,6 +94,7 @@ object Connection: useServerPrepStmts, databaseTerm, defaultAuthenticationPlugin, + plugins, unitBefore, unitAfter ) @@ -113,7 +115,8 @@ object Connection: useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default( host, port, @@ -129,6 +132,7 @@ object Connection: useServerPrepStmts, databaseTerm, defaultAuthenticationPlugin, + plugins, before, after ) @@ -148,6 +152,7 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Connection[F] => F[A], after: (A, Connection[F]) => F[Unit] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = @@ -172,6 +177,7 @@ object Connection: useServerPrepStmts, databaseTerm, defaultAuthenticationPlugin, + plugins, before, after ) @@ -192,9 +198,11 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], acquire: Connection[F] => F[A], release: (A, Connection[F]) => F[Unit] ): Resource[F, LdbcConnection[F]] = + val pluginMap = plugins.map(plugin => plugin.name.toString -> plugin).toMap val capabilityFlags = defaultCapabilityFlags ++ (if database.isDefined then Set(CapabilitiesFlags.CLIENT_CONNECT_WITH_DB) else Set.empty) ++ (if sslOptions.isDefined then Set(CapabilitiesFlags.CLIENT_SSL) else Set.empty) @@ -210,7 +218,8 @@ object Connection: allowPublicKeyRetrieval, readTimeout, capabilityFlags, - defaultAuthenticationPlugin + defaultAuthenticationPlugin, + pluginMap ) _ <- Resource.eval(protocol.startAuthentication(user, password.getOrElse(""))) serverVariables <- Resource.eval(protocol.serverVariables()) @@ -252,6 +261,7 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], acquire: Connection[F] => F[A], release: (A, Connection[F]) => F[Unit] )(using ev: Async[F]): Resource[F, LdbcConnection[F]] = @@ -281,6 +291,7 @@ object Connection: useServerPrepStmts, databaseTerm, defaultAuthenticationPlugin, + plugins, acquire, release ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index ec3f5cfea..0192364dc 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -19,7 +19,8 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.connector.authenticator.AuthenticationPlugin +import ldbc.authentication.plugin.AuthenticationPlugin + import ldbc.connector.pool.* import ldbc.DataSource @@ -49,6 +50,7 @@ import ldbc.DataSource * @param useCursorFetch whether to use cursor-based fetching for result sets * @param useServerPrepStmts whether to use server-side prepared statements * @param defaultAuthenticationPlugin The authentication plugin used first for communication with the server + * @param plugins Additional authentication plugins used for communication with the server * @param before optional hook to execute before a connection is acquired * @param after optional hook to execute after a connection is used * @@ -84,6 +86,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ) extends DataSource[F]: @@ -117,7 +120,8 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, - defaultAuthenticationPlugin = defaultAuthenticationPlugin + defaultAuthenticationPlugin = defaultAuthenticationPlugin, + plugins = plugins ) case (Some(b), None) => Connection.withBeforeAfter( @@ -136,7 +140,8 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, - defaultAuthenticationPlugin = defaultAuthenticationPlugin + defaultAuthenticationPlugin = defaultAuthenticationPlugin, + plugins = plugins ) case (None, _) => Connection( @@ -153,7 +158,8 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, - defaultAuthenticationPlugin = defaultAuthenticationPlugin + defaultAuthenticationPlugin = defaultAuthenticationPlugin, + plugins = plugins ) /** Sets the hostname or IP address of the MySQL server. @@ -259,6 +265,18 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen def setDefaultAuthenticationPlugin(defaultAuthenticationPlugin: AuthenticationPlugin[F]): MySQLDataSource[F, A] = copy(defaultAuthenticationPlugin = Some(defaultAuthenticationPlugin)) + /** + * Sets whether to authentication plugin to be used for communication with the server. + * + * @param p1 + * The authentication plugin used for communication with the server + * @param pn + * List of authentication plugins used for communication with the server + * @return a new MySQLDataSource with the updated setting + */ + def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): MySQLDataSource[F, A] = + copy(plugins = p1 :: pn.toList) + /** * Adds a before hook that will be executed when a connection is acquired. * diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala index c86b31042..68e65039b 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala @@ -23,6 +23,8 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData +import ldbc.authentication.plugin.AuthenticationPlugin + import ldbc.connector.* import ldbc.connector.exception.SQLException @@ -168,6 +170,17 @@ trait PooledDataSource[F[_]] extends DataSource[F]: */ def validateConnection(conn: Connection[F]): F[Boolean] + /** + * Sets whether to authentication plugin to be used for communication with the server. + * + * @param p1 + * The authentication plugin used for communication with the server + * @param pn + * List of authentication plugins used for communication with the server + * @return a new MySQLDataSource with the updated setting + */ + def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): PooledDataSource[F] + object PooledDataSource: private case class Impl[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( @@ -184,6 +197,7 @@ object PooledDataSource: databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None, minConnections: Int = 5, @@ -598,7 +612,8 @@ object PooledDataSource: allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + plugins = plugins ) case (Some(b), None) => Connection.withBeforeAfter( @@ -616,7 +631,8 @@ object PooledDataSource: allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + plugins = plugins ) case (None, _) => Connection( @@ -632,9 +648,22 @@ object PooledDataSource: allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + plugins = plugins ) + /** + * Sets whether to authentication plugin to be used for communication with the server. + * + * @param p1 + * The authentication plugin used for communication with the server + * @param pn + * List of authentication plugins used for communication with the server + * @return a new MySQLDataSource with the updated setting + */ + override def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): PooledDataSource[F] = + copy(plugins = p1 :: pn.toList) + private[connector] def create[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]], From 015b2ceba20a7f3d368d760af59e72c5732c0e60 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 18 Dec 2025 00:02:39 +0900 Subject: [PATCH 181/215] Fixed test compile --- .../scala/ldbc/connector/ConnectionTest.scala | 3 +- .../CachingSha2PasswordPluginTest.scala | 2 +- .../net/protocol/AuthenticationTest.scala | 51 ------------------- .../pool/KeepaliveExecutorTest.scala | 3 ++ .../pool/PoolStatusReporterTest.scala | 5 ++ 5 files changed, 11 insertions(+), 53 deletions(-) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index fd1d58c89..f33578c87 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -16,7 +16,8 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.connector.authenticator.MysqlClearPasswordPlugin +import ldbc.authentication.plugin.MysqlClearPasswordPlugin + import ldbc.connector.exception.* class ConnectionTest extends FTestPlatform: diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala index 3396fbada..ab455b982 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala @@ -19,7 +19,7 @@ class CachingSha2PasswordPluginTest extends FTestPlatform: test("CachingSha2PasswordPlugin#name should return correct plugin name") { val plugin = CachingSha2PasswordPlugin[IO](Version(8, 0, 10)) - assertEquals(plugin.name, "caching_sha2_password") + assertEquals(plugin.name.toString, "caching_sha2_password") } test("CachingSha2PasswordPlugin for version >= 8.0.5 should use default transformation") { diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala index 114359841..c256cf8d1 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala @@ -8,11 +8,6 @@ package ldbc.connector.net.protocol import cats.effect.{ IO, Ref } -import fs2.hashing.Hashing - -import ldbc.connector.authenticator.* -import ldbc.connector.exception.* -import ldbc.connector.util.Version import ldbc.connector.FTestPlatform class AuthenticationTest extends Authentication[IO], FTestPlatform: @@ -50,52 +45,6 @@ class AuthenticationTest extends Authentication[IO], FTestPlatform: _ <- selectedPluginName.set(None) } yield () - // Test for determinatePlugin - test("determinatePlugin should return the correct plugin for mysql_native_password") { - val version = Version(8, 0, 26) - val result = determinatePlugin("mysql_native_password", version) - assert(result.isRight, "Expected Right but got Left") - assert( - result.exists(_.isInstanceOf[MysqlNativePasswordPlugin[IO]]), - s"Expected MysqlNativePasswordPlugin but got ${ result.map(_.getClass.getSimpleName) }" - ) - } - - test("determinatePlugin should return the correct plugin for sha256_password") { - val version = Version(8, 0, 26) - val result = determinatePlugin("sha256_password", version) - assert(result.isRight, "Expected Right but got Left") - assert( - result.exists(_.isInstanceOf[Sha256PasswordPlugin[IO]]), - s"Expected Sha256PasswordPlugin but got ${ result.map(_.getClass.getSimpleName) }" - ) - } - - test("determinatePlugin should return the correct plugin for caching_sha2_password") { - val version = Version(8, 0, 26) - val result = determinatePlugin("caching_sha2_password", version) - assert(result.isRight, "Expected Right but got Left") - assert( - result.exists(_.isInstanceOf[CachingSha2PasswordPlugin[IO]]), - s"Expected CachingSha2PasswordPlugin but got ${ result.map(_.getClass.getSimpleName) }" - ) - } - - test("determinatePlugin should return Left for unknown plugin") { - val version = Version(8, 0, 26) - val result = determinatePlugin("unknown_plugin", version) - assert(result.isLeft, "Expected Left but got Right") - val exception = result.left.getOrElse(throw new RuntimeException("Expected Left but got Right")) - assert( - exception.isInstanceOf[SQLInvalidAuthorizationSpecException], - s"Expected SQLInvalidAuthorizationSpecException but got ${ exception.getClass.getSimpleName }" - ) - assert( - exception.getMessage.contains("Unknown authentication plugin: unknown_plugin"), - s"Error message did not match expected. Got: ${ exception.getMessage }" - ) - } - test("startAuthentication should set the correct state") { for { _ <- resetState diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala index 2eb78302f..084e548c6 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala @@ -14,6 +14,8 @@ import cats.effect.* import ldbc.sql.* +import ldbc.authentication.plugin.AuthenticationPlugin + import ldbc.connector.* class KeepaliveExecutorTest extends FTestPlatform: @@ -148,6 +150,7 @@ class KeepaliveExecutorTest extends FTestPlatform: override def returnToPool(pooled: PooledConnection[F]) = Temporal[F].unit override def removeConnection(pooled: PooledConnection[F]) = Temporal[F].unit override def validateConnection(conn: Connection[F]) = validationFunc(conn) + override def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): PooledDataSource[F] = ??? test("KeepaliveExecutor should start and stop correctly") { for diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala index 4bc5c4fde..d971f7a09 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala @@ -14,6 +14,8 @@ import cats.effect.* import munit.CatsEffectSuite +import ldbc.authentication.plugin.AuthenticationPlugin + import ldbc.connector.Connection class PoolStatusReporterTest extends CatsEffectSuite: @@ -67,6 +69,7 @@ class PoolStatusReporterTest extends CatsEffectSuite: def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? def validateConnection(conn: Connection[IO]) = ??? + def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => @@ -109,6 +112,7 @@ class PoolStatusReporterTest extends CatsEffectSuite: def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? def validateConnection(conn: Connection[IO]) = ??? + def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => @@ -169,6 +173,7 @@ class PoolStatusReporterTest extends CatsEffectSuite: def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? def validateConnection(conn: Connection[IO]) = ??? + def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => From db1f29b7f68a1cb3ee4ed1f66b708911da4d5981 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 18 Dec 2025 00:03:24 +0900 Subject: [PATCH 182/215] Delete unused --- .../Sha256PasswordPluginPlatform.scala | 31 --------- .../Sha256PasswordPluginPlatform.scala | 42 ------------ .../connector/authenticator/Openssl.scala | 64 ------------------ .../Sha256PasswordPluginPlatform.scala | 67 ------------------- 4 files changed, 204 deletions(-) delete mode 100644 module/ldbc-connector/js/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala delete mode 100644 module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala delete mode 100644 module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Openssl.scala delete mode 100644 module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala diff --git a/module/ldbc-connector/js/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala b/module/ldbc-connector/js/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala deleted file mode 100644 index af1d261e5..000000000 --- a/module/ldbc-connector/js/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.connector.authenticator -import java.nio.charset.StandardCharsets - -import scala.scalajs.js -import scala.scalajs.js.typedarray.Uint8Array - -import scodec.bits.ByteVector -trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => - private val crypto = js.Dynamic.global.require("crypto") - - def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = - val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) - val mysqlScrambleBuff = xorString(input, scramble, input.length) - encryptWithRSAPublicKey( - mysqlScrambleBuff, - publicKeyString - ) - - private def encryptWithRSAPublicKey(input: Array[Byte], publicKey: String): Array[Byte] = - val encrypted = crypto.publicEncrypt( - publicKey, - ByteVector(input).toUint8Array - ) - ByteVector.view(encrypted.asInstanceOf[Uint8Array]).toArray -} diff --git a/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala b/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala deleted file mode 100644 index bc61e94f4..000000000 --- a/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -/* -package ldbc.connector.authenticator - -import java.nio.charset.StandardCharsets -import java.security.interfaces.RSAPublicKey -import java.security.spec.X509EncodedKeySpec -import java.security.KeyFactory -import java.security.PublicKey -import java.util.Base64 - -import javax.crypto.Cipher - -trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => - def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = - val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) - val mysqlScrambleBuff = xorString(input, scramble, input.length) - encryptWithRSAPublicKey( - mysqlScrambleBuff, - decodeRSAPublicKey(publicKeyString) - ) - - private def encryptWithRSAPublicKey(input: Array[Byte], key: PublicKey): Array[Byte] = - val cipher = Cipher.getInstance(transformation) - cipher.init(Cipher.ENCRYPT_MODE, key) - cipher.doFinal(input) - - private def decodeRSAPublicKey(key: String): RSAPublicKey = - val offset = key.indexOf("\n") + 1 - val len = key.indexOf("-----END PUBLIC KEY-----") - offset - val certificateData = Base64.getMimeDecoder.decode(key.substring(offset, offset + len)) - val spec = new X509EncodedKeySpec(certificateData) - val kf = KeyFactory.getInstance("RSA") - kf.generatePublic(spec).asInstanceOf[RSAPublicKey] -} - - */ diff --git a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Openssl.scala b/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Openssl.scala deleted file mode 100644 index 4830853ad..000000000 --- a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Openssl.scala +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.connector.authenticator - -import scala.scalanative.unsafe.* -import scala.scalanative.unsigned.* - -import org.typelevel.scalaccompat.annotation.* - -@nowarn212 -@link("crypto") -@extern -private[ldbc] object Openssl: - - final val RSA_PKCS1_OAEP_PADDING = 4 - - final val EVP_MAX_MD_SIZE = 64 - - type EVP_MD - type EVP_MD_CTX - type ENGINE - type BIO - type EVP_PKEY - type EVP_PKEY_CTX - - type pem_password_cb = CFuncPtr4[Ptr[Byte], CInt, CInt, Ptr[Byte], CInt] - - def EVP_get_digestbyname(name: Ptr[CChar]): Ptr[EVP_MD] = extern - - def EVP_Digest( - data: Ptr[Byte], - count: CSize, - md: Ptr[Byte], - size: Ptr[CUnsignedInt], - `type`: Ptr[EVP_MD], - impl: Ptr[ENGINE] - ): CInt = extern - - def BIO_new_mem_buf(buf: Ptr[Byte], len: CInt): Ptr[BIO] = extern - - def PEM_read_bio_PUBKEY(bp: Ptr[BIO], x: Ptr[Ptr[EVP_PKEY]], cb: pem_password_cb, u: Ptr[Byte]): Ptr[EVP_PKEY] = - extern - - def EVP_PKEY_CTX_new(pkey: Ptr[EVP_PKEY], e: Ptr[ENGINE]): Ptr[EVP_PKEY_CTX] = extern - - def EVP_PKEY_encrypt_init(ctx: Ptr[EVP_PKEY_CTX]): CInt = extern - - def EVP_PKEY_CTX_set_rsa_padding(ctx: Ptr[EVP_PKEY_CTX], padding: CInt): CInt = extern - - def EVP_PKEY_CTX_set_rsa_oaep_md(ctx: Ptr[EVP_PKEY_CTX], md: Ptr[EVP_MD]): CInt = extern - - def EVP_PKEY_CTX_set_rsa_mgf1_md(ctx: Ptr[EVP_PKEY_CTX], md: Ptr[EVP_MD]): CInt = extern - - def EVP_PKEY_encrypt( - ctx: Ptr[EVP_PKEY_CTX], - out: Ptr[UByte], - outlen: Ptr[CSize], - in: Ptr[UByte], - inlen: CSize - ): CInt = extern diff --git a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala b/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala deleted file mode 100644 index 555f8310d..000000000 --- a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.connector.authenticator -import java.nio.charset.StandardCharsets - -import scala.scalanative.unsafe.* -import scala.scalanative.unsigned.* - -import ldbc.connector.authenticator.Openssl.* - -trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => - def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = - val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) - val mysqlScrambleBuff = xorString(input, scramble, input.length) - encryptWithRSAPublicKey( - mysqlScrambleBuff, - publicKeyString - ) - - private def encryptWithRSAPublicKey(input: Array[Byte], publicKey: String): Array[Byte] = - Zone { implicit zone => - val publicKeyCStr = toCString(publicKey) - val bio = BIO_new_mem_buf(publicKeyCStr, publicKey.length) - if bio == null then throw new RuntimeException("Failed to create a new memory BIO.") - - val evpPkey = PEM_read_bio_PUBKEY(bio, null, null, null) - if evpPkey == null then throw new RuntimeException("Failed to load public key.") - - val ctx = EVP_PKEY_CTX_new(evpPkey, null) - if EVP_PKEY_encrypt_init(ctx) <= 0 then throw new RuntimeException("Failed to initialize encryption context.") - - if EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) != 1 then { - throw new RuntimeException("Failed to set RSA padding.") - } - - val sha1Md = EVP_get_digestbyname(c"sha1") - if EVP_PKEY_CTX_set_rsa_oaep_md(ctx, sha1Md) != 1 then { - throw new RuntimeException("Failed to set OAEP hash function.") - } - - if EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, sha1Md) != 1 then { - throw new RuntimeException("Failed to set MGF1 hash function.") - } - - val inputBuf = alloc[UByte](input.length.toULong) - for i <- input.indices do !(inputBuf + i) = input(i).toUByte - - val outLen = stackalloc[CSize]() - !outLen = 0.toULong - - if EVP_PKEY_encrypt(ctx, null, outLen, inputBuf, input.length.toULong) <= 0 then - throw new RuntimeException("Failed to obtain the output buffer size required for encryption.") - - val encryptedBuf = alloc[UByte](!outLen) - if EVP_PKEY_encrypt(ctx, encryptedBuf, outLen, inputBuf, input.length.toULong) <= 0 then - throw new RuntimeException("Encryption failed.") - - val result = Array.fill[Byte]((!outLen).toInt)(0) - for i <- 0 until (!outLen).toInt do result(i) = (!(encryptedBuf + i)).toInt.toByte - - result - } -} From 661983fa2fd8e093c7dc86ca939d504c1ce14eb2 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 18 Dec 2025 00:04:09 +0900 Subject: [PATCH 183/215] Action sbt scalafmtAll --- .../plugin/EncryptPasswordPlugin.scala | 2 +- .../plugin/EncryptPasswordPlugin.scala | 8 ++--- .../plugin/EncryptPasswordPlugin.scala | 2 +- .../ldbc/authentication/plugin/Openssl.scala | 4 +-- .../plugin/MysqlClearPasswordPlugin.scala | 13 +++---- .../ldbc/authentication/plugin/pacakge.scala | 28 +++++++-------- .../scala/ldbc/connector/Connection.scala | 17 +++++----- .../ldbc/connector/MySQLDataSource.scala | 11 +++--- .../CachingSha2PasswordPlugin.scala | 4 +-- .../MysqlClearPasswordPlugin.scala | 10 ++++-- .../MysqlNativePasswordPlugin.scala | 2 +- .../authenticator/Sha256PasswordPlugin.scala | 2 +- .../scala/ldbc/connector/net/Protocol.scala | 34 ++++++++++--------- .../connector/pool/PooledDataSource.scala | 11 +++--- .../scala/ldbc/connector/ConnectionTest.scala | 4 +-- .../pool/KeepaliveExecutorTest.scala | 4 +-- .../pool/PoolStatusReporterTest.scala | 16 ++++----- 17 files changed, 90 insertions(+), 82 deletions(-) diff --git a/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala b/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala index 03ef8ac7e..fdba42a91 100644 --- a/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala +++ b/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -16,7 +16,7 @@ import scodec.bits.ByteVector trait EncryptPasswordPlugin: private val crypto = js.Dynamic.global.require("crypto") - + def transformation: String private def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = diff --git a/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala b/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala index f6a500695..ec352e840 100644 --- a/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala +++ b/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -37,9 +37,9 @@ trait EncryptPasswordPlugin: cipher.doFinal(input) private def decodeRSAPublicKey(key: String): RSAPublicKey = - val offset = key.indexOf("\n") + 1 - val len = key.indexOf("-----END PUBLIC KEY-----") - offset + val offset = key.indexOf("\n") + 1 + val len = key.indexOf("-----END PUBLIC KEY-----") - offset val certificateData = Base64.getMimeDecoder.decode(key.substring(offset, offset + len)) - val spec = new X509EncodedKeySpec(certificateData) - val kf = KeyFactory.getInstance("RSA") + val spec = new X509EncodedKeySpec(certificateData) + val kf = KeyFactory.getInstance("RSA") kf.generatePublic(spec).asInstanceOf[RSAPublicKey] diff --git a/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala index 322754420..65b42b368 100644 --- a/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala +++ b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -32,7 +32,7 @@ trait EncryptPasswordPlugin: private def encryptWithRSAPublicKey(input: Array[Byte], publicKey: String): Array[Byte] = Zone { implicit zone => val publicKeyCStr = toCString(publicKey) - val bio = BIO_new_mem_buf(publicKeyCStr, publicKey.length) + val bio = BIO_new_mem_buf(publicKeyCStr, publicKey.length) if bio == null then throw new RuntimeException("Failed to create a new memory BIO.") val evpPkey = PEM_read_bio_PUBKEY(bio, null, null, null) diff --git a/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala index 1c21dca30..5a9ecd34a 100644 --- a/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala +++ b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala @@ -6,11 +6,11 @@ package ldbc.authentication.plugin -import org.typelevel.scalaccompat.annotation.* - import scala.scalanative.unsafe.* import scala.scalanative.unsigned.* +import org.typelevel.scalaccompat.annotation.* + @nowarn212 @link("crypto") @extern diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala index 5ee0f8e54..fb6d8e03a 100644 --- a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala @@ -6,23 +6,24 @@ package ldbc.authentication.plugin -import cats.Applicative +import java.nio.charset.StandardCharsets import scodec.bits.ByteVector -import java.nio.charset.StandardCharsets +import cats.Applicative trait MysqlClearPasswordPlugin[F[_]] extends AuthenticationPlugin[F]: - override final def name: PluginName = MYSQL_CLEAR_PASSWORD - override final def requiresConfidentiality: Boolean = true + override final def name: PluginName = MYSQL_CLEAR_PASSWORD + override final def requiresConfidentiality: Boolean = true object MysqlClearPasswordPlugin: private case class Impl[F[_]: Applicative]() extends MysqlClearPasswordPlugin[F]: override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = - val result = if password.isEmpty then ByteVector.empty - else ByteVector(password.getBytes(StandardCharsets.UTF_8)) + val result = + if password.isEmpty then ByteVector.empty + else ByteVector(password.getBytes(StandardCharsets.UTF_8)) Applicative[F].pure(result) def apply[F[_]: Applicative](): MysqlClearPasswordPlugin[F] = Impl[F]() diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala index 329665b29..7e7c04248 100644 --- a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/pacakge.scala @@ -11,35 +11,35 @@ package object plugin: opaque type PluginName = String // Standard MySQL Authentication Plugins - val MYSQL_CLEAR_PASSWORD: PluginName = "mysql_clear_password" + val MYSQL_CLEAR_PASSWORD: PluginName = "mysql_clear_password" val MYSQL_NATIVE_PASSWORD: PluginName = "mysql_native_password" - val SHA256_PASSWORD: PluginName = "sha256_password" + val SHA256_PASSWORD: PluginName = "sha256_password" val CACHING_SHA2_PASSWORD: PluginName = "caching_sha2_password" - + // Legacy Authentication Plugin (deprecated) val MYSQL_OLD_PASSWORD: PluginName = "mysql_old_password" - + // Authentication plugins for external authentication - val AUTHENTICATION_WINDOWS: PluginName = "authentication_windows" - val AUTHENTICATION_PAM: PluginName = "authentication_pam" + val AUTHENTICATION_WINDOWS: PluginName = "authentication_windows" + val AUTHENTICATION_PAM: PluginName = "authentication_pam" val AUTHENTICATION_LDAP_SIMPLE: PluginName = "authentication_ldap_simple" - val AUTHENTICATION_LDAP_SASL: PluginName = "authentication_ldap_sasl" - + val AUTHENTICATION_LDAP_SASL: PluginName = "authentication_ldap_sasl" + // Kerberos authentication val AUTHENTICATION_KERBEROS: PluginName = "authentication_kerberos" - + // FIDO authentication (MySQL 8.0.27+) val AUTHENTICATION_FIDO: PluginName = "authentication_fido" - + // Multi-factor authentication (MySQL 8.0.27+) val AUTHENTICATION_WEBAUTHN: PluginName = "authentication_webauthn" - + // No login authentication plugin val MYSQL_NO_LOGIN: PluginName = "mysql_no_login" - + // Test plugins (for testing purposes) val TEST_PLUGIN_SERVER: PluginName = "test_plugin_server" - val DAEMON_EXAMPLE: PluginName = "daemon_example" - + val DAEMON_EXAMPLE: PluginName = "daemon_example" + // Socket peer-credential authentication (Unix socket) val AUTH_SOCKET: PluginName = "auth_socket" diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index dfd4d617f..aca2f00fa 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -24,12 +24,13 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.authentication.plugin.* import ldbc.connector.data.* import ldbc.connector.exception.* import ldbc.connector.net.* import ldbc.connector.net.protocol.* +import ldbc.authentication.plugin.* + type Connection[F[_]] = ldbc.sql.Connection[F] object Connection: @@ -78,7 +79,7 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default[F, Unit]( host, port, @@ -116,7 +117,7 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default( host, port, @@ -152,7 +153,7 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Connection[F] => F[A], after: (A, Connection[F]) => F[Unit] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = @@ -177,7 +178,7 @@ object Connection: useServerPrepStmts, databaseTerm, defaultAuthenticationPlugin, - plugins, + plugins, before, after ) @@ -198,11 +199,11 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], - plugins: List[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], acquire: Connection[F] => F[A], release: (A, Connection[F]) => F[Unit] ): Resource[F, LdbcConnection[F]] = - val pluginMap = plugins.map(plugin => plugin.name.toString -> plugin).toMap + val pluginMap = plugins.map(plugin => plugin.name.toString -> plugin).toMap val capabilityFlags = defaultCapabilityFlags ++ (if database.isDefined then Set(CapabilitiesFlags.CLIENT_CONNECT_WITH_DB) else Set.empty) ++ (if sslOptions.isDefined then Set(CapabilitiesFlags.CLIENT_SSL) else Set.empty) @@ -261,7 +262,7 @@ object Connection: useServerPrepStmts: Boolean = false, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], - plugins: List[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], acquire: Connection[F] => F[A], release: (A, Connection[F]) => F[Unit] )(using ev: Async[F]): Resource[F, LdbcConnection[F]] = diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 0192364dc..eefc2e98a 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -19,10 +19,9 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.authentication.plugin.AuthenticationPlugin - import ldbc.connector.pool.* +import ldbc.authentication.plugin.AuthenticationPlugin import ldbc.DataSource /** @@ -86,7 +85,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ) extends DataSource[F]: @@ -121,7 +120,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, - plugins = plugins + plugins = plugins ) case (Some(b), None) => Connection.withBeforeAfter( @@ -141,7 +140,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, - plugins = plugins + plugins = plugins ) case (None, _) => Connection( @@ -159,7 +158,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, - plugins = plugins + plugins = plugins ) /** Sets the hostname or IP address of the MySQL server. diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala index dbe97fdbc..e81e41c5d 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala @@ -10,10 +10,10 @@ import cats.effect.kernel.Sync import fs2.hashing.Hashing -import ldbc.authentication.plugin.* - import ldbc.connector.util.Version +import ldbc.authentication.plugin.* + trait CachingSha2PasswordPlugin[F[_]] extends Sha256PasswordPlugin[F]: override def name: PluginName = CACHING_SHA2_PASSWORD diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala index a7f2af42d..50531631b 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala @@ -12,7 +12,10 @@ import scodec.bits.ByteVector import cats.effect.kernel.Sync -@deprecated("This plugin is deprecated. Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", "0.5.0") +@deprecated( + "This plugin is deprecated. Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", + "0.5.0" +) class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: override def name: String = "mysql_clear_password" @@ -22,5 +25,8 @@ class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: else Sync[F].delay(ByteVector(password.getBytes(StandardCharsets.UTF_8))) object MysqlClearPasswordPlugin: - @deprecated("This plugin is deprecated. Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", "0.5.0") + @deprecated( + "This plugin is deprecated. Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", + "0.5.0" + ) def apply[F[_]: Sync](): MysqlClearPasswordPlugin[F] = new MysqlClearPasswordPlugin[F]() diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala index a4ba81fcc..767305b22 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala @@ -23,7 +23,7 @@ import ldbc.authentication.plugin.* class MysqlNativePasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F]: - override def name: PluginName = MYSQL_NATIVE_PASSWORD + override def name: PluginName = MYSQL_NATIVE_PASSWORD override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala index cb8b836fd..6f1df89af 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala @@ -22,7 +22,7 @@ import fs2.Chunk import ldbc.authentication.plugin.* trait Sha256PasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F], EncryptPasswordPlugin: - override def name: PluginName = SHA256_PASSWORD + override def name: PluginName = SHA256_PASSWORD override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index fea0885d1..701e8346d 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -25,9 +25,7 @@ import fs2.io.net.Socket import org.typelevel.otel4s.trace.{ Span, Tracer } import org.typelevel.otel4s.Attribute -import ldbc.authentication.plugin.* - -import ldbc.connector.authenticator.{ MysqlNativePasswordPlugin, Sha256PasswordPlugin, CachingSha2PasswordPlugin } +import ldbc.connector.authenticator.{ CachingSha2PasswordPlugin, MysqlNativePasswordPlugin, Sha256PasswordPlugin } import ldbc.connector.data.* import ldbc.connector.exception.* import ldbc.connector.net.packet.* @@ -36,6 +34,8 @@ import ldbc.connector.net.packet.response.* import ldbc.connector.net.protocol.* import ldbc.connector.telemetry.* +import ldbc.authentication.plugin.* + /** * Protocol is a protocol to communicate with MySQL server. * It provides a way to authenticate, reset sequence id, and close the connection. @@ -149,7 +149,7 @@ object Protocol: capabilityFlags: Set[CapabilitiesFlags], sequenceIdRef: Ref[F, Byte], defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], - plugins: Map[String, AuthenticationPlugin[F]] + plugins: Map[String, AuthenticationPlugin[F]] )(using ev: MonadError[F, Throwable], ex: Exchange[F]) extends Protocol[F]: @@ -561,17 +561,19 @@ object Protocol: else ev.unit private def determinatePlugin(pluginName: String): Either[SQLException, AuthenticationPlugin[F]] = - plugins.get(pluginName).toRight( - new SQLInvalidAuthorizationSpecException( - s"Unknown authentication plugin: $pluginName", - detail = Some( - "This error may be due to lack of support on the ldbc side or a newly added plugin on the MySQL side." - ), - hint = Some( - "Report Issues here: https://github.com/takapi327/ldbc/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=" + plugins + .get(pluginName) + .toRight( + new SQLInvalidAuthorizationSpecException( + s"Unknown authentication plugin: $pluginName", + detail = Some( + "This error may be due to lack of support on the ldbc side or a newly added plugin on the MySQL side." + ), + hint = Some( + "Report Issues here: https://github.com/takapi327/ldbc/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=" + ) ) ) - ) def apply[F[_]: Async: Console: Tracer: Exchange: Hashing]( sockets: Resource[F, Socket[F]], @@ -582,7 +584,7 @@ object Protocol: readTimeout: Duration, capabilitiesFlags: Set[CapabilitiesFlags], defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], - plugins: Map[String, AuthenticationPlugin[F]] + plugins: Map[String, AuthenticationPlugin[F]] ): Resource[F, Protocol[F]] = for sequenceIdRef <- Resource.eval(Ref[F].of[Byte](0x01)) @@ -613,7 +615,7 @@ object Protocol: sequenceIdRef: Ref[F, Byte], initialPacketRef: Ref[F, Option[InitialPacket]], defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], - plugins: Map[String, AuthenticationPlugin[F]] + plugins: Map[String, AuthenticationPlugin[F]] )(using ev: Async[F]): F[Protocol[F]] = initialPacketRef.get.flatMap { case Some(initialPacket) => @@ -629,7 +631,7 @@ object Protocol: defaultAuthenticationPlugin, Map( MYSQL_NATIVE_PASSWORD.toString -> MysqlNativePasswordPlugin[F](), - SHA256_PASSWORD.toString -> Sha256PasswordPlugin[F](), + SHA256_PASSWORD.toString -> Sha256PasswordPlugin[F](), CACHING_SHA2_PASSWORD.toString -> CachingSha2PasswordPlugin[F](initialPacket.serverVersion) ) ++ plugins ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala index 68e65039b..e411fce84 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala @@ -23,11 +23,10 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.authentication.plugin.AuthenticationPlugin - import ldbc.connector.* import ldbc.connector.exception.SQLException +import ldbc.authentication.plugin.AuthenticationPlugin import ldbc.DataSource /** @@ -197,7 +196,7 @@ object PooledDataSource: databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None, minConnections: Int = 5, @@ -613,7 +612,7 @@ object PooledDataSource: useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, - plugins = plugins + plugins = plugins ) case (Some(b), None) => Connection.withBeforeAfter( @@ -632,7 +631,7 @@ object PooledDataSource: useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, - plugins = plugins + plugins = plugins ) case (None, _) => Connection( @@ -649,7 +648,7 @@ object PooledDataSource: useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, databaseTerm = databaseTerm, - plugins = plugins + plugins = plugins ) /** diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index f33578c87..36ca6df4f 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -16,10 +16,10 @@ import org.typelevel.otel4s.trace.Tracer import ldbc.sql.DatabaseMetaData -import ldbc.authentication.plugin.MysqlClearPasswordPlugin - import ldbc.connector.exception.* +import ldbc.authentication.plugin.MysqlClearPasswordPlugin + class ConnectionTest extends FTestPlatform: given Tracer[IO] = Tracer.noop[IO] diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala index 084e548c6..9c62b1612 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala @@ -14,10 +14,10 @@ import cats.effect.* import ldbc.sql.* -import ldbc.authentication.plugin.AuthenticationPlugin - import ldbc.connector.* +import ldbc.authentication.plugin.AuthenticationPlugin + class KeepaliveExecutorTest extends FTestPlatform: // Helper to create a mock Connection for testing diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala index d971f7a09..3290457e8 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala @@ -14,10 +14,10 @@ import cats.effect.* import munit.CatsEffectSuite -import ldbc.authentication.plugin.AuthenticationPlugin - import ldbc.connector.Connection +import ldbc.authentication.plugin.AuthenticationPlugin + class PoolStatusReporterTest extends CatsEffectSuite: class TestPoolLogger[F[_]: Applicative](var logCount: Int = 0) extends PoolLogger[F]: @@ -68,8 +68,8 @@ class PoolStatusReporterTest extends CatsEffectSuite: def createNewConnectionForPool() = ??? def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? - def validateConnection(conn: Connection[IO]) = ??? - def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? + def validateConnection(conn: Connection[IO]) = ??? + def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => @@ -111,8 +111,8 @@ class PoolStatusReporterTest extends CatsEffectSuite: def createNewConnectionForPool() = ??? def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? - def validateConnection(conn: Connection[IO]) = ??? - def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? + def validateConnection(conn: Connection[IO]) = ??? + def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => @@ -172,8 +172,8 @@ class PoolStatusReporterTest extends CatsEffectSuite: def createNewConnectionForPool() = ??? def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? - def validateConnection(conn: Connection[IO]) = ??? - def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? + def validateConnection(conn: Connection[IO]) = ??? + def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => From 2777604f174bfc7f870d84b8d877f83eef97cef3 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 18 Dec 2025 00:05:04 +0900 Subject: [PATCH 184/215] Action sbt scalafmtSbt --- build.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 936a61c04..c3ecfdadd 100644 --- a/build.sbt +++ b/build.sbt @@ -126,7 +126,7 @@ lazy val authenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlat .settings( libraryDependencies ++= Seq( "org.typelevel" %%% "cats-core" % "2.10.0", - "org.scodec" %%% "scodec-bits" % "1.1.38", + "org.scodec" %%% "scodec-bits" % "1.1.38" ) ) .jsSettings( From 6fd1d8fc185cd616ee77dbb44f379b99151fe879 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 18 Dec 2025 00:05:25 +0900 Subject: [PATCH 185/215] Action sbt githubWorkflowGenerate --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ecf318320..1aa1b97bf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -155,11 +155,11 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target + run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-authentication-plugin/native/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-authentication-plugin/js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-authentication-plugin/jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target + run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-authentication-plugin/native/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-authentication-plugin/js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-authentication-plugin/jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) From 40f67e659af6a1d18f3a5c9a63a6234fcb4392b6 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 14:56:05 +0900 Subject: [PATCH 186/215] Create AwsIamAuthenticationPlugin --- .../plugin/AwsIamAuthenticationPlugin.scala | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala new file mode 100644 index 000000000..d7a92b56d --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala @@ -0,0 +1,145 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.plugin + +import java.nio.charset.StandardCharsets + +import scodec.bits.ByteVector + +import cats.Monad +import cats.syntax.all.* + +import cats.effect.* +import cats.effect.std.{ Env, SystemProperties, UUIDGen } + +import fs2.io.file.Files +import fs2.io.net.* + +import ldbc.authentication.plugin.MysqlClearPasswordPlugin + +import ldbc.amazon.identity.{AwsCredentials, AwsCredentialsProvider} +import ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain +import ldbc.amazon.auth.token.{AuthTokenGenerator, RdsIamAuthTokenGenerator} + +/** + * AWS IAM authentication plugin for connecting to MySQL databases using IAM credentials. + * + * This plugin enables secure database connections to AWS RDS MySQL instances using AWS IAM + * credentials instead of traditional username/password authentication. It extends the + * `MysqlClearPasswordPlugin` and generates temporary authentication tokens using AWS credentials. + * + * The plugin automatically: + * - Resolves AWS credentials from various sources (environment variables, profiles, IAM roles, etc.) + * - Generates time-limited authentication tokens signed with AWS credentials + * - Handles token refresh and credential rotation transparently + * + * Security benefits: + * - No hardcoded database passwords in application code + * - Leverages existing AWS IAM policies and roles + * - Tokens are automatically time-limited (15 minutes) + * - Supports AWS credential rotation and temporary credentials + * + * Usage requirements: + * - AWS RDS instance must have IAM authentication enabled + * - Database user must be created with IAM authentication + * - Application must have appropriate IAM permissions + * + * @tparam F The effect type that wraps authentication operations + * @param provider The AWS credentials provider for obtaining authentication credentials + * @param generator The token generator for creating RDS IAM authentication tokens + * + * @see [[https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html AWS RDS IAM Database Authentication]] + * @since 1.0.0 + */ +final class AwsIamAuthenticationPlugin[F[_]: Monad]( + provider: AwsCredentialsProvider[F], + generator: AuthTokenGenerator[F] +) extends MysqlClearPasswordPlugin[F]: + + /** + * Generates an AWS IAM authentication token instead of processing a traditional password. + * + * This method overrides the standard password hashing behavior to generate a temporary + * authentication token using AWS IAM credentials. The process involves: + * + * 1. Resolving AWS credentials from the configured provider + * 2. Generating a signed authentication token using the RDS IAM authentication protocol + * 3. Converting the token to bytes for transmission to the MySQL server + * + * The generated token is a pre-signed URL that contains: + * - The RDS endpoint hostname and port + * - The database username + * - AWS signature version 4 (SigV4) authentication parameters + * - A 15-minute expiration time + * + * @param password The original password parameter (ignored in IAM authentication) + * @param scramble The server's challenge bytes (ignored in IAM authentication) + * @return The AWS IAM authentication token as UTF-8 encoded bytes wrapped in the effect type F + * + * @note The password and scramble parameters are inherited from the parent trait but are not + * used in AWS IAM authentication since the token generation process uses AWS credentials + * and cryptographic signing instead of password-based authentication. + */ + override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = + for + credentials <- provider.resolveCredentials() + token <- generator.generateToken(credentials) + yield ByteVector(token.getBytes(StandardCharsets.UTF_8)) + +object AwsIamAuthenticationPlugin: + + /** + * Creates a default AWS IAM authentication plugin with standard credential resolution. + * + * This factory method constructs an `AwsIamAuthenticationPlugin` with the default AWS + * credential provider chain and RDS IAM token generator. The credential provider chain + * attempts to resolve AWS credentials from multiple sources in the following order: + * + * 1. System properties (aws.accessKeyId, aws.secretAccessKey, aws.sessionToken) + * 2. Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN) + * 3. Web identity token file (for EKS service accounts, Lambda, etc.) + * 4. AWS profile configuration files (~/.aws/credentials, ~/.aws/config) + * 5. Amazon EC2 instance profile credentials + * 6. ECS container credentials + * + * The token generator creates pre-signed authentication tokens using AWS Signature Version 4 + * that are valid for 15 minutes and specific to the provided RDS endpoint and database user. + * + * @tparam F The effect type that must support file operations, environment access, system + * properties, network operations, UUID generation, and asynchronous operations + * @param region The AWS region where the RDS instance is located (e.g., "us-east-1") + * @param hostname The RDS instance endpoint hostname (e.g., "mydb.abc123.us-east-1.rds.amazonaws.com") + * @param username The database username configured for IAM authentication + * @param port The database port number (default: 3306 for MySQL) + * @return A configured `AwsIamAuthenticationPlugin` instance ready for database authentication + * + * @example {{{ + * import cats.effect.IO + * import ldbc.amazon.plugin.AwsIamAuthenticationPlugin + * + * val plugin = AwsIamAuthenticationPlugin.default[IO]( + * region = "us-east-1", + * hostname = "mydb.abc123.us-east-1.rds.amazonaws.com", + * username = "myuser", + * port = 3306 + * ) + * }}} + * + * @see [[ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain]] for credential resolution details + * @see [[ldbc.amazon.auth.token.RdsIamAuthTokenGenerator]] for token generation implementation + * @since 1.0.0 + */ + def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( + region: String, + hostname: String, + username: String, + port: Int = 3306 + ): AwsIamAuthenticationPlugin[F] = + new AwsIamAuthenticationPlugin[F]( + DefaultCredentialsProviderChain.default[F](region), + new RdsIamAuthTokenGenerator[F](hostname, port, username, region) + ) From 490446fe3f35e6a7a3ebe36dd39eee7e0f1a0037 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 14:56:27 +0900 Subject: [PATCH 187/215] Action sbt scalafmtAll --- .../plugin/AwsIamAuthenticationPlugin.scala | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala index d7a92b56d..263ec542c 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala @@ -10,8 +10,8 @@ import java.nio.charset.StandardCharsets import scodec.bits.ByteVector -import cats.Monad import cats.syntax.all.* +import cats.Monad import cats.effect.* import cats.effect.std.{ Env, SystemProperties, UUIDGen } @@ -19,11 +19,10 @@ import cats.effect.std.{ Env, SystemProperties, UUIDGen } import fs2.io.file.Files import fs2.io.net.* -import ldbc.authentication.plugin.MysqlClearPasswordPlugin - -import ldbc.amazon.identity.{AwsCredentials, AwsCredentialsProvider} import ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain -import ldbc.amazon.auth.token.{AuthTokenGenerator, RdsIamAuthTokenGenerator} +import ldbc.amazon.auth.token.{ AuthTokenGenerator, RdsIamAuthTokenGenerator } +import ldbc.amazon.identity.{ AwsCredentials, AwsCredentialsProvider } +import ldbc.authentication.plugin.MysqlClearPasswordPlugin /** * AWS IAM authentication plugin for connecting to MySQL databases using IAM credentials. @@ -56,8 +55,8 @@ import ldbc.amazon.auth.token.{AuthTokenGenerator, RdsIamAuthTokenGenerator} * @since 1.0.0 */ final class AwsIamAuthenticationPlugin[F[_]: Monad]( - provider: AwsCredentialsProvider[F], - generator: AuthTokenGenerator[F] + provider: AwsCredentialsProvider[F], + generator: AuthTokenGenerator[F] ) extends MysqlClearPasswordPlugin[F]: /** @@ -87,7 +86,7 @@ final class AwsIamAuthenticationPlugin[F[_]: Monad]( override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = for credentials <- provider.resolveCredentials() - token <- generator.generateToken(credentials) + token <- generator.generateToken(credentials) yield ByteVector(token.getBytes(StandardCharsets.UTF_8)) object AwsIamAuthenticationPlugin: @@ -134,10 +133,10 @@ object AwsIamAuthenticationPlugin: * @since 1.0.0 */ def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( - region: String, + region: String, hostname: String, username: String, - port: Int = 3306 + port: Int = 3306 ): AwsIamAuthenticationPlugin[F] = new AwsIamAuthenticationPlugin[F]( DefaultCredentialsProviderChain.default[F](region), From afbfb2944c99b587f31ab224f4eef00ea5b3f5a8 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 15:17:37 +0900 Subject: [PATCH 188/215] Delete unused --- .../ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala index 263ec542c..5a9333b89 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala @@ -18,10 +18,11 @@ import cats.effect.std.{ Env, SystemProperties, UUIDGen } import fs2.io.file.Files import fs2.io.net.* +import fs2.hashing.Hashing import ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain import ldbc.amazon.auth.token.{ AuthTokenGenerator, RdsIamAuthTokenGenerator } -import ldbc.amazon.identity.{ AwsCredentials, AwsCredentialsProvider } +import ldbc.amazon.identity.AwsCredentialsProvider import ldbc.authentication.plugin.MysqlClearPasswordPlugin /** @@ -52,7 +53,6 @@ import ldbc.authentication.plugin.MysqlClearPasswordPlugin * @param generator The token generator for creating RDS IAM authentication tokens * * @see [[https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html AWS RDS IAM Database Authentication]] - * @since 1.0.0 */ final class AwsIamAuthenticationPlugin[F[_]: Monad]( provider: AwsCredentialsProvider[F], @@ -130,9 +130,8 @@ object AwsIamAuthenticationPlugin: * * @see [[ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain]] for credential resolution details * @see [[ldbc.amazon.auth.token.RdsIamAuthTokenGenerator]] for token generation implementation - * @since 1.0.0 */ - def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( + def default[F[_]: Files: Hashing: Env: SystemProperties: Network: UUIDGen: Async]( region: String, hostname: String, username: String, From dd0e4105fb8e8c41322eced57fe031648bd78198 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 17:36:07 +0900 Subject: [PATCH 189/215] Delete setPlugins method --- .../ldbc/connector/MySQLDataSource.scala | 8 +++-- .../connector/pool/PooledDataSource.scala | 36 +++++-------------- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index eefc2e98a..445236b9c 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -483,8 +483,9 @@ object MySQLDataSource: def pooling[F[_]: Async: Network: Console: Hashing: UUIDGen]( config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, - tracer: Option[Tracer[F]] = None - ): Resource[F, PooledDataSource[F]] = PooledDataSource.fromConfig(config, metricsTracker, tracer) + tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + ): Resource[F, PooledDataSource[F]] = PooledDataSource.fromConfig(config, metricsTracker, tracer, plugins) /** * Creates a pooled DataSource with connection lifecycle hooks. @@ -539,7 +540,8 @@ object MySQLDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ): Resource[F, PooledDataSource[F]] = - PooledDataSource.fromConfigWithBeforeAfter(config, metricsTracker, tracer, before, after) + PooledDataSource.fromConfigWithBeforeAfter(config, metricsTracker, tracer, plugins, before, after) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala index e411fce84..91493df97 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala @@ -169,17 +169,6 @@ trait PooledDataSource[F[_]] extends DataSource[F]: */ def validateConnection(conn: Connection[F]): F[Boolean] - /** - * Sets whether to authentication plugin to be used for communication with the server. - * - * @param p1 - * The authentication plugin used for communication with the server - * @param pn - * List of authentication plugins used for communication with the server - * @return a new MySQLDataSource with the updated setting - */ - def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): PooledDataSource[F] - object PooledDataSource: private case class Impl[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( @@ -651,22 +640,11 @@ object PooledDataSource: plugins = plugins ) - /** - * Sets whether to authentication plugin to be used for communication with the server. - * - * @param p1 - * The authentication plugin used for communication with the server - * @param pn - * List of authentication plugins used for communication with the server - * @return a new MySQLDataSource with the updated setting - */ - override def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): PooledDataSource[F] = - copy(plugins = p1 :: pn.toList) - private[connector] def create[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]], idGenerator: F[String], + plugins: List[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None )(using Tracer[F]): Resource[F, PooledDataSource[F]] = @@ -675,7 +653,7 @@ object PooledDataSource: Resource .eval(PoolConfigValidator.validate(config)) .flatMap { _ => - createValidatedPool(config, metricsTracker, idGenerator, before, after) + createValidatedPool(config, metricsTracker, idGenerator, plugins, before, after) } .handleErrorWith { error => Resource.eval(Async[F].raiseError(error)) @@ -685,6 +663,7 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]], idGenerator: F[String], + plugins: List[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]], after: Option[(A, Connection[F]) => F[Unit]] )(using Tracer[F]): Resource[F, PooledDataSource[F]] = @@ -739,6 +718,7 @@ object PooledDataSource: keepaliveTime = config.keepaliveTime, connectionTestQuery = config.connectionTestQuery, poolLogger = poolLogger, + plugins = plugins, before = before, after = after ) @@ -786,10 +766,11 @@ object PooledDataSource: def fromConfig[F[_]: Async: Network: Console: Hashing: UUIDGen]( config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, - tracer: Option[Tracer[F]] = None + tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Resource[F, PooledDataSource[F]] = given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) - create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString)) + create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString), plugins) /** * Creates a PooledDataSource with before/after hooks for each connection use. @@ -816,8 +797,9 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ): Resource[F, PooledDataSource[F]] = given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) - create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString), before, after) + create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString), plugins, before, after) From b00db31e999a6c3bd0e45dd8d45282445b6911ac Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 17:38:03 +0900 Subject: [PATCH 190/215] Create aws iam authentication plugin example project --- build.sbt | 17 +++- .../aws-iam-auth/src/main/scala/Main.scala | 96 +++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 examples/aws-iam-auth/src/main/scala/Main.scala diff --git a/build.sbt b/build.sbt index c3ecfdadd..45f143317 100644 --- a/build.sbt +++ b/build.sbt @@ -316,6 +316,20 @@ lazy val zioExample = crossProject(JVMPlatform) ) .dependsOn(connector, dsl, zioInterop) +lazy val awsIamAuthExample = crossProject(JVMPlatform) + .crossType(CrossType.Pure) + .withoutSuffixFor(JVMPlatform) + .example("aws-iam-auth", "Aws Iam Authentication example project") + .settings( + libraryDependencies ++= Seq( + "org.http4s" %% "http4s-dsl" % "0.23.33", + "org.http4s" %% "http4s-ember-server" % "0.23.33", + "org.http4s" %% "http4s-circe" % "0.23.33", + "io.circe" %% "circe-generic" % "0.14.10" + ) + ) + .dependsOn(connector, awsAuthenticationPlugin, dsl) + lazy val docs = (project in file("docs")) .settings( description := "Documentation for ldbc", @@ -417,7 +431,8 @@ lazy val examples = Seq( http4sExample, hikariCPExample, otelExample, - zioExample + zioExample, + awsIamAuthExample ) lazy val ldbc = tlCrossRootProject diff --git a/examples/aws-iam-auth/src/main/scala/Main.scala b/examples/aws-iam-auth/src/main/scala/Main.scala new file mode 100644 index 000000000..07179c698 --- /dev/null +++ b/examples/aws-iam-auth/src/main/scala/Main.scala @@ -0,0 +1,96 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +import com.comcast.ip4s.* + +import cats.effect.* + +import io.circe.* +import io.circe.syntax.* + +import cats.effect.std.{Env, Console} + +import ldbc.logging.* +import ldbc.connector.* +import ldbc.dsl.* + +import org.http4s.* +import org.http4s.circe.CirceEntityEncoder.* +import org.http4s.dsl.io.* +import org.http4s.ember.server.EmberServerBuilder + +import ldbc.amazon.plugin.AwsIamAuthenticationPlugin + +case class City( + id: Int, + name: String, + countryCode: String, + district: String, + population: Int +) + +object City: + + given Encoder[City] = Encoder.derived[City] + +object Main extends ResourceApp.Forever: + + private val logHandler: LogHandler[IO] = { + case LogEvent.Success(sql, args) => + Console[IO].println( + s"""Successful Statement Execution: + | $sql + | + | arguments = [${args.mkString(",")}] + |""".stripMargin + ) + case LogEvent.ProcessingFailure(sql, args, failure) => + Console[IO].errorln( + s"""Failed ResultSet Processing: + | $sql + | + | arguments = [${args.mkString(",")}] + |""".stripMargin + ) >> Console[IO].printStackTrace(failure) + case LogEvent.ExecFailure(sql, args, failure) => + Console[IO].errorln( + s"""Failed Statement Execution: + | $sql + | + | arguments = [${args.mkString(",")}] + |""".stripMargin + ) >> Console[IO].printStackTrace(failure) + } + + private def routes(connector: Connector[IO]): HttpRoutes[IO] = HttpRoutes.of[IO] { + case GET -> Root / "cities" => + for + cities <- sql"SELECT * FROM city".query[City].to[List].readOnly(connector) + result <- Ok(cities.asJson) + yield result + } + + override def run(args: List[String]): Resource[IO, Unit] = + for + hostname <- Resource.eval(Env[IO].get("AURORA_HOST").flatMap { + case Some(v) => IO.pure(v) + case None => IO.raiseError(new RuntimeException("AURORA_HOST is not set")) + }) + username <- Resource.eval(Env[IO].get("AURORA_USER").flatMap { + case Some(v) => IO.pure(v) + case None => IO.raiseError(new RuntimeException("AURORA_USER is not set")) + }) + config = MySQLConfig.default.setHost(hostname).setUser(username).setDatabase("world").setSSL(SSL.Trusted) + plugin = AwsIamAuthenticationPlugin.default[IO]("ap-northeast-1", hostname, username) + datasource <- MySQLDataSource.pooling[IO](config, plugins = List(plugin)) + connector = Connector.fromDataSource(datasource, Some(logHandler)) + _ <- EmberServerBuilder + .default[IO] + .withHost(host"0.0.0.0") + .withPort(port"9000") + .withHttpApp(routes(connector).orNotFound) + .build + yield () From 46c99fc99d54eb5080ddd962c8f2d340bb216cb4 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 17:39:38 +0900 Subject: [PATCH 191/215] Added healthcheck api --- examples/aws-iam-auth/src/main/scala/Main.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/aws-iam-auth/src/main/scala/Main.scala b/examples/aws-iam-auth/src/main/scala/Main.scala index 07179c698..cc3a3d293 100644 --- a/examples/aws-iam-auth/src/main/scala/Main.scala +++ b/examples/aws-iam-auth/src/main/scala/Main.scala @@ -66,6 +66,7 @@ object Main extends ResourceApp.Forever: } private def routes(connector: Connector[IO]): HttpRoutes[IO] = HttpRoutes.of[IO] { + case GET -> Root / "healthcheck" -> Ok("Healthcheck") case GET -> Root / "cities" => for cities <- sql"SELECT * FROM city".query[City].to[List].readOnly(connector) From d0a7fe505ec1f5c58eb10bb2324f257f683078ba Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 17:39:54 +0900 Subject: [PATCH 192/215] Action sbt scalafmtSbt --- build.sbt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/build.sbt b/build.sbt index 45f143317..e130edc58 100644 --- a/build.sbt +++ b/build.sbt @@ -322,10 +322,10 @@ lazy val awsIamAuthExample = crossProject(JVMPlatform) .example("aws-iam-auth", "Aws Iam Authentication example project") .settings( libraryDependencies ++= Seq( - "org.http4s" %% "http4s-dsl" % "0.23.33", - "org.http4s" %% "http4s-ember-server" % "0.23.33", - "org.http4s" %% "http4s-circe" % "0.23.33", - "io.circe" %% "circe-generic" % "0.14.10" + "org.http4s" %% "http4s-dsl" % "0.23.33", + "org.http4s" %% "http4s-ember-server" % "0.23.33", + "org.http4s" %% "http4s-circe" % "0.23.33", + "io.circe" %% "circe-generic" % "0.14.10" ) ) .dependsOn(connector, awsAuthenticationPlugin, dsl) From 7df9ebf0a2ab99770fbb3f1e51eb76d264678b54 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 17:41:33 +0900 Subject: [PATCH 193/215] Fixed compile error --- examples/aws-iam-auth/src/main/scala/Main.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/aws-iam-auth/src/main/scala/Main.scala b/examples/aws-iam-auth/src/main/scala/Main.scala index cc3a3d293..3d3a560e5 100644 --- a/examples/aws-iam-auth/src/main/scala/Main.scala +++ b/examples/aws-iam-auth/src/main/scala/Main.scala @@ -66,7 +66,7 @@ object Main extends ResourceApp.Forever: } private def routes(connector: Connector[IO]): HttpRoutes[IO] = HttpRoutes.of[IO] { - case GET -> Root / "healthcheck" -> Ok("Healthcheck") + case GET -> Root / "healthcheck" => Ok("Healthcheck") case GET -> Root / "cities" => for cities <- sql"SELECT * FROM city".query[City].to[List].readOnly(connector) From 6db67a933e066b0fea11fc9fcf86fb71423cc6d8 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 17:42:24 +0900 Subject: [PATCH 194/215] Action sbt scalafmtAll --- .../aws-iam-auth/src/main/scala/Main.scala | 40 +++++++++---------- .../plugin/AwsIamAuthenticationPlugin.scala | 2 +- .../ldbc/connector/MySQLDataSource.scala | 4 +- .../connector/pool/PooledDataSource.scala | 10 ++--- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/aws-iam-auth/src/main/scala/Main.scala b/examples/aws-iam-auth/src/main/scala/Main.scala index 3d3a560e5..94fc9056d 100644 --- a/examples/aws-iam-auth/src/main/scala/Main.scala +++ b/examples/aws-iam-auth/src/main/scala/Main.scala @@ -7,23 +7,23 @@ import com.comcast.ip4s.* import cats.effect.* +import cats.effect.std.{ Console, Env } import io.circe.* import io.circe.syntax.* -import cats.effect.std.{Env, Console} +import ldbc.dsl.* -import ldbc.logging.* import ldbc.connector.* -import ldbc.dsl.* + +import ldbc.amazon.plugin.AwsIamAuthenticationPlugin +import ldbc.logging.* import org.http4s.* import org.http4s.circe.CirceEntityEncoder.* import org.http4s.dsl.io.* import org.http4s.ember.server.EmberServerBuilder -import ldbc.amazon.plugin.AwsIamAuthenticationPlugin - case class City( id: Int, name: String, @@ -44,7 +44,7 @@ object Main extends ResourceApp.Forever: s"""Successful Statement Execution: | $sql | - | arguments = [${args.mkString(",")}] + | arguments = [${ args.mkString(",") }] |""".stripMargin ) case LogEvent.ProcessingFailure(sql, args, failure) => @@ -52,7 +52,7 @@ object Main extends ResourceApp.Forever: s"""Failed ResultSet Processing: | $sql | - | arguments = [${args.mkString(",")}] + | arguments = [${ args.mkString(",") }] |""".stripMargin ) >> Console[IO].printStackTrace(failure) case LogEvent.ExecFailure(sql, args, failure) => @@ -60,14 +60,14 @@ object Main extends ResourceApp.Forever: s"""Failed Statement Execution: | $sql | - | arguments = [${args.mkString(",")}] + | arguments = [${ args.mkString(",") }] |""".stripMargin ) >> Console[IO].printStackTrace(failure) } private def routes(connector: Connector[IO]): HttpRoutes[IO] = HttpRoutes.of[IO] { case GET -> Root / "healthcheck" => Ok("Healthcheck") - case GET -> Root / "cities" => + case GET -> Root / "cities" => for cities <- sql"SELECT * FROM city".query[City].to[List].readOnly(connector) result <- Ok(cities.asJson) @@ -77,21 +77,21 @@ object Main extends ResourceApp.Forever: override def run(args: List[String]): Resource[IO, Unit] = for hostname <- Resource.eval(Env[IO].get("AURORA_HOST").flatMap { - case Some(v) => IO.pure(v) - case None => IO.raiseError(new RuntimeException("AURORA_HOST is not set")) - }) + case Some(v) => IO.pure(v) + case None => IO.raiseError(new RuntimeException("AURORA_HOST is not set")) + }) username <- Resource.eval(Env[IO].get("AURORA_USER").flatMap { - case Some(v) => IO.pure(v) - case None => IO.raiseError(new RuntimeException("AURORA_USER is not set")) - }) + case Some(v) => IO.pure(v) + case None => IO.raiseError(new RuntimeException("AURORA_USER is not set")) + }) config = MySQLConfig.default.setHost(hostname).setUser(username).setDatabase("world").setSSL(SSL.Trusted) plugin = AwsIamAuthenticationPlugin.default[IO]("ap-northeast-1", hostname, username) datasource <- MySQLDataSource.pooling[IO](config, plugins = List(plugin)) connector = Connector.fromDataSource(datasource, Some(logHandler)) _ <- EmberServerBuilder - .default[IO] - .withHost(host"0.0.0.0") - .withPort(port"9000") - .withHttpApp(routes(connector).orNotFound) - .build + .default[IO] + .withHost(host"0.0.0.0") + .withPort(port"9000") + .withHttpApp(routes(connector).orNotFound) + .build yield () diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala index 5a9333b89..21c125509 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala @@ -16,9 +16,9 @@ import cats.Monad import cats.effect.* import cats.effect.std.{ Env, SystemProperties, UUIDGen } +import fs2.hashing.Hashing import fs2.io.file.Files import fs2.io.net.* -import fs2.hashing.Hashing import ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain import ldbc.amazon.auth.token.{ AuthTokenGenerator, RdsIamAuthTokenGenerator } diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 445236b9c..70b4876cd 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -484,7 +484,7 @@ object MySQLDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Resource[F, PooledDataSource[F]] = PooledDataSource.fromConfig(config, metricsTracker, tracer, plugins) /** @@ -540,7 +540,7 @@ object MySQLDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ): Resource[F, PooledDataSource[F]] = diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala index 91493df97..ae33e225a 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala @@ -644,7 +644,7 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]], idGenerator: F[String], - plugins: List[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None )(using Tracer[F]): Resource[F, PooledDataSource[F]] = @@ -663,7 +663,7 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]], idGenerator: F[String], - plugins: List[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]], after: Option[(A, Connection[F]) => F[Unit]] )(using Tracer[F]): Resource[F, PooledDataSource[F]] = @@ -718,7 +718,7 @@ object PooledDataSource: keepaliveTime = config.keepaliveTime, connectionTestQuery = config.connectionTestQuery, poolLogger = poolLogger, - plugins = plugins, + plugins = plugins, before = before, after = after ) @@ -767,7 +767,7 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Resource[F, PooledDataSource[F]] = given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString), plugins) @@ -797,7 +797,7 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, - plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ): Resource[F, PooledDataSource[F]] = From 4f1b09df813732496d461eb0267876ee4958dfed Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 17:42:39 +0900 Subject: [PATCH 195/215] Action sbt githubWorkflowGenerate --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1aa1b97bf..7255a536e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -394,7 +394,7 @@ jobs: - name: Submit Dependencies uses: scalacenter/sbt-dependency-submission@v2 with: - modules-ignore: ldbcjs_3 ldbcjs_3 otel_3 mcp-ldbc-document-server_sjs1_3 docs_3 docs_3 zio_3 ldbcnative_3 ldbcnative_3 ldbcjvm_3 ldbcjvm_3 hikaricp_3 tests_sjs1_3 tests_sjs1_3 http4s_3 tests_3 tests_3 benchmark_3 benchmark_3 tests_native0.4_3 tests_native0.4_3 + modules-ignore: ldbcjs_3 ldbcjs_3 otel_3 mcp-ldbc-document-server_sjs1_3 docs_3 docs_3 zio_3 ldbcnative_3 ldbcnative_3 ldbcjvm_3 ldbcjvm_3 hikaricp_3 tests_sjs1_3 tests_sjs1_3 http4s_3 aws-iam-auth_3 tests_3 tests_3 benchmark_3 benchmark_3 tests_native0.4_3 tests_native0.4_3 configs-ignore: test scala-tool scala-doc-tool test-internal validate-steward: From 00e9b91685795c06da82342debd7369efb89bdb3 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 18:27:19 +0900 Subject: [PATCH 196/215] Delete unused --- .../scala/ldbc/connector/pool/KeepaliveExecutorTest.scala | 3 --- .../scala/ldbc/connector/pool/PoolStatusReporterTest.scala | 5 ----- 2 files changed, 8 deletions(-) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala index 9c62b1612..2eb78302f 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/KeepaliveExecutorTest.scala @@ -16,8 +16,6 @@ import ldbc.sql.* import ldbc.connector.* -import ldbc.authentication.plugin.AuthenticationPlugin - class KeepaliveExecutorTest extends FTestPlatform: // Helper to create a mock Connection for testing @@ -150,7 +148,6 @@ class KeepaliveExecutorTest extends FTestPlatform: override def returnToPool(pooled: PooledConnection[F]) = Temporal[F].unit override def removeConnection(pooled: PooledConnection[F]) = Temporal[F].unit override def validateConnection(conn: Connection[F]) = validationFunc(conn) - override def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): PooledDataSource[F] = ??? test("KeepaliveExecutor should start and stop correctly") { for diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala index 3290457e8..bdff319a7 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala @@ -16,8 +16,6 @@ import munit.CatsEffectSuite import ldbc.connector.Connection -import ldbc.authentication.plugin.AuthenticationPlugin - class PoolStatusReporterTest extends CatsEffectSuite: class TestPoolLogger[F[_]: Applicative](var logCount: Int = 0) extends PoolLogger[F]: @@ -69,7 +67,6 @@ class PoolStatusReporterTest extends CatsEffectSuite: def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? def validateConnection(conn: Connection[IO]) = ??? - def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => @@ -112,7 +109,6 @@ class PoolStatusReporterTest extends CatsEffectSuite: def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? def validateConnection(conn: Connection[IO]) = ??? - def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => @@ -173,7 +169,6 @@ class PoolStatusReporterTest extends CatsEffectSuite: def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? def validateConnection(conn: Connection[IO]) = ??? - def setPlugins(p1: AuthenticationPlugin[IO], pn: AuthenticationPlugin[IO]*): PooledDataSource[IO] = ??? } reporter.start(pool, "test-pool").use { _ => From 127db31125891039707d12f84e278a7eaf87f40c Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 18:34:31 +0900 Subject: [PATCH 197/215] Action sbt scalafmtAll --- .../scala/ldbc/connector/pool/PoolStatusReporterTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala index bdff319a7..4bc5c4fde 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/pool/PoolStatusReporterTest.scala @@ -66,7 +66,7 @@ class PoolStatusReporterTest extends CatsEffectSuite: def createNewConnectionForPool() = ??? def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? - def validateConnection(conn: Connection[IO]) = ??? + def validateConnection(conn: Connection[IO]) = ??? } reporter.start(pool, "test-pool").use { _ => @@ -108,7 +108,7 @@ class PoolStatusReporterTest extends CatsEffectSuite: def createNewConnectionForPool() = ??? def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? - def validateConnection(conn: Connection[IO]) = ??? + def validateConnection(conn: Connection[IO]) = ??? } reporter.start(pool, "test-pool").use { _ => @@ -168,7 +168,7 @@ class PoolStatusReporterTest extends CatsEffectSuite: def createNewConnectionForPool() = ??? def returnToPool(pooled: PooledConnection[IO]) = ??? def removeConnection(pooled: PooledConnection[IO]) = ??? - def validateConnection(conn: Connection[IO]) = ??? + def validateConnection(conn: Connection[IO]) = ??? } reporter.start(pool, "test-pool").use { _ => From cffd1701cec1c298465f4c027bb2bc106e2d9bf9 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 21:49:24 +0900 Subject: [PATCH 198/215] Added check endpoint --- .../credentials/InstanceProfileCredentialsProvider.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala index d32cbaea4..a65d5f0fa 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -131,9 +131,12 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( * @return The IMDS endpoint URL without trailing slash */ private def getImdsEndpoint(): F[String] = - Env[F].get("AWS_EC2_METADATA_SERVICE_ENDPOINT").map { - case Some(endpoint) => endpoint.stripSuffix("/") - case None => DEFAULT_IMD_SEND_POINT + Env[F].get("AWS_EC2_METADATA_SERVICE_ENDPOINT").flatMap { + case Some(endpoint) => + Concurrent[F] + .raiseUnless(endpoint.matches("^https?://169\\.254\\.169\\.254.*"))(new SecurityException("Invalid IMDS endpoint")) + .map(_ => endpoint.stripSuffix("/")) + case None => Concurrent[F].pure(DEFAULT_IMD_SEND_POINT) } /** From 5d1be9edd3a18e935929731e5f6c7847aa0420d5 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 21:49:44 +0900 Subject: [PATCH 199/215] Action sbt scalafmtAll --- .../credentials/InstanceProfileCredentialsProvider.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala index a65d5f0fa..aa7275b97 100644 --- a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -134,9 +134,11 @@ final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent]( Env[F].get("AWS_EC2_METADATA_SERVICE_ENDPOINT").flatMap { case Some(endpoint) => Concurrent[F] - .raiseUnless(endpoint.matches("^https?://169\\.254\\.169\\.254.*"))(new SecurityException("Invalid IMDS endpoint")) + .raiseUnless(endpoint.matches("^https?://169\\.254\\.169\\.254.*"))( + new SecurityException("Invalid IMDS endpoint") + ) .map(_ => endpoint.stripSuffix("/")) - case None => Concurrent[F].pure(DEFAULT_IMD_SEND_POINT) + case None => Concurrent[F].pure(DEFAULT_IMD_SEND_POINT) } /** From 3d6b64814e074b5a0b5e13c7a239dc1f2a1b41a9 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 22:16:37 +0900 Subject: [PATCH 200/215] Added escape for string parameter --- .../shared/src/main/scala/ldbc/connector/data/Parameter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala index da9b365f8..6805379ed 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala @@ -97,7 +97,7 @@ object Parameter: def string(value: String): Parameter = new Parameter: override def columnDataType: ColumnDataType = ColumnDataType.MYSQL_TYPE_STRING - override def sql: String = ("'" + value + "'") + override def sql: String = "'" + value.replaceAll("'", "''").replaceAll("\\\\", "\\\\\\\\") + "'" override def encode: BitVector = val bytes = value.getBytes BitVector(bytes.length) |+| BitVector(copyOf(bytes, bytes.length)) From 887f22904d279bd660a45840fea5388ae8adedee Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 22:31:46 +0900 Subject: [PATCH 201/215] Fixed test --- .../src/test/scala/ldbc/connector/data/ParameterTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala index f1db74ccb..02fdb3dd7 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala @@ -164,7 +164,7 @@ class ParameterTest extends FTestPlatform: // Test string with quotes val quotedStringParam = Parameter.string("test'quotes") - assertEquals(quotedStringParam.toString, "'test'quotes'") + assertEquals(quotedStringParam.toString, "'test''quotes'") // Test zero values assertEquals(Parameter.byte(0).toString, "0") From f21cc6a66eae66a738720693f86f9d0377bc8969 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 22:53:54 +0900 Subject: [PATCH 202/215] Added force connection close in pool --- .../src/main/scala/ldbc/connector/pool/PooledDataSource.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala index ae33e225a..c98c8ef99 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala @@ -238,7 +238,8 @@ object PooledDataSource: val closeAll = state.connections.traverse_ { pooled => pooled.finalizer.attempt.flatMap { case Left(error) => - poolLogger.debug(s"Error closing connection ${ pooled.id }: ${ error.getMessage }") + poolLogger.debug(s"Error closing connection ${ pooled.id }: ${ error.getMessage }") >> + pooled.connection.close().attempt.void case Right(_) => Temporal[F].unit } From 470325f7b197ada65f26c59eba6827fbab3265b6 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 23:10:34 +0900 Subject: [PATCH 203/215] Added atomic status check --- .../ldbc/connector/pool/ConcurrentBag.scala | 99 ++++++++++++++----- 1 file changed, 73 insertions(+), 26 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala index 20ca1ad09..9fc6d9799 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala @@ -127,7 +127,6 @@ object ConcurrentBag: /** * Create a new ConcurrentBag instance. * - * @param maxFiberLocalSize maximum number of items to store in fiber-local storage * @tparam F the effect type * @tparam T the type of items stored in the bag * @return a new ConcurrentBag instance @@ -249,35 +248,39 @@ object ConcurrentBag: if state == BagEntry.STATE_REMOVED then Temporal[F].unit else // Reset state to not in use - item.setState(BagEntry.STATE_NOT_IN_USE) >> - // Check if item is in the shared list - sharedList.get.flatMap { list => - if list.exists(_ eq item) then - // Item is already in the list, handle waiters - waiters.get.flatMap { waiting => - if waiting > 0 then - // Try to hand off directly to a waiter - handoffQueue.tryOffer(item).void - else - // No waiters, item stays in shared list - Temporal[F].unit - } - else - // Item not in list, add it with distribution - sharedList.modify { list => - val newList = distributeItem(item, list) - (newList, ()) - } >> + item.compareAndSet(BagEntry.STATE_IN_USE, BagEntry.STATE_NOT_IN_USE).flatMap { + case true => // 成功時のみ処理継続 + // Check if item is in the shared list + sharedList.get.flatMap { list => + if list.exists(_ eq item) then + // Item is already in the list, handle waiters waiters.get.flatMap { waiting => if waiting > 0 then // Try to hand off directly to a waiter - handoffQueue.tryOffer(item).flatMap { - case true => Temporal[F].unit - case false => Temporal[F].unit - } - else Temporal[F].unit + handoffQueue.tryOffer(item).void + else + // No waiters, item stays in shared list + Temporal[F].unit } - } + else + // Item not in list, add it with distribution + sharedList.modify { list => + val newList = distributeItem(item, list) + (newList, ()) + } >> + waiters.get.flatMap { waiting => + if waiting > 0 then + // Try to hand off directly to a waiter + handoffQueue.tryOffer(item).flatMap { + case true => Temporal[F].unit + case false => Temporal[F].unit + } + else Temporal[F].unit + } + } + case false => // 他のファイバーが状態を変更済み + handleStateConflict(item) + } } } @@ -334,3 +337,47 @@ object ConcurrentBag: else val (front, back) = list.splitAt(idx - 1) front ++ (item :: back) + + private def handleStateConflict(item: T): F[Unit] = + item.getState.flatMap { + case BagEntry.STATE_REMOVED => + // Item deleted - No action required (normal termination) + Temporal[F].unit + + case BagEntry.STATE_NOT_IN_USE => + // Other fibers have already completed requite - Avoid duplicate processing + Temporal[F].unit + + case BagEntry.STATE_RESERVED => + // Temporary reservation status - Please wait a short while and try again + Temporal[F].sleep(1.milli) >> + item.compareAndSet(BagEntry.STATE_RESERVED, BagEntry.STATE_NOT_IN_USE).flatMap { + case true => continueRequiteProcess(item) + case false => handleStateConflict(item) + } + + case BagEntry.STATE_IN_USE => + // Anomaly: Still in IN_USE state - Detected state inconsistency + item.setState(BagEntry.STATE_NOT_IN_USE) >> + continueRequiteProcess(item) + + case unknownState => + Temporal[F].raiseError( + new IllegalStateException(s"Unknown bag entry state: $unknownState") + ) + } + + private def continueRequiteProcess(item: T): F[Unit] = + // Continue with the normal requite process after state conflict resolution + sharedList.get.flatMap { list => + if list.exists(_ eq item) then + waiters.get.flatMap { waiting => + if waiting > 0 then handoffQueue.tryOffer(item).void + else Temporal[F].unit + } + else + sharedList.modify { list => + val newList = distributeItem(item, list) + (newList, ()) + } + } From 1ab7761495d201339ffbb7461b6588b821d036b0 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 23:10:49 +0900 Subject: [PATCH 204/215] Action sbt scalafmtAll --- .../src/main/scala/ldbc/connector/pool/ConcurrentBag.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala index 9fc6d9799..aa603ad40 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala @@ -352,7 +352,7 @@ object ConcurrentBag: // Temporary reservation status - Please wait a short while and try again Temporal[F].sleep(1.milli) >> item.compareAndSet(BagEntry.STATE_RESERVED, BagEntry.STATE_NOT_IN_USE).flatMap { - case true => continueRequiteProcess(item) + case true => continueRequiteProcess(item) case false => handleStateConflict(item) } From ec77327e2c31d337f189da76089067b0e49233ba Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 23:43:20 +0900 Subject: [PATCH 205/215] Fixed concurrentBag statecheck bug --- .../ldbc/connector/pool/ConcurrentBag.scala | 48 +++++++------------ 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala index aa603ad40..157d07cca 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala @@ -247,39 +247,25 @@ object ConcurrentBag: item.getState.flatMap { state => if state == BagEntry.STATE_REMOVED then Temporal[F].unit else - // Reset state to not in use + // Attempt atomic transition from IN_USE to NOT_IN_USE item.compareAndSet(BagEntry.STATE_IN_USE, BagEntry.STATE_NOT_IN_USE).flatMap { - case true => // 成功時のみ処理継続 - // Check if item is in the shared list - sharedList.get.flatMap { list => - if list.exists(_ eq item) then - // Item is already in the list, handle waiters - waiters.get.flatMap { waiting => - if waiting > 0 then - // Try to hand off directly to a waiter - handoffQueue.tryOffer(item).void - else - // No waiters, item stays in shared list - Temporal[F].unit - } - else - // Item not in list, add it with distribution - sharedList.modify { list => - val newList = distributeItem(item, list) - (newList, ()) - } >> - waiters.get.flatMap { waiting => - if waiting > 0 then - // Try to hand off directly to a waiter - handoffQueue.tryOffer(item).flatMap { - case true => Temporal[F].unit - case false => Temporal[F].unit - } - else Temporal[F].unit - } + case true => // Successfully transitioned from IN_USE + continueRequiteProcess(item) + case false => + // Failed transition - could be NOT_IN_USE already or other state + item.getState.flatMap { currentState => + currentState match { + case BagEntry.STATE_NOT_IN_USE => + // Item already in correct state - continue with requite + continueRequiteProcess(item) + case BagEntry.STATE_REMOVED => + // Item was removed, nothing to do + Temporal[F].unit + case _ => + // Handle other state conflicts + handleStateConflict(item) + } } - case false => // 他のファイバーが状態を変更済み - handleStateConflict(item) } } } From aa1fa2d24022de5e53dae9f07293f071db3b3394 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Sun, 21 Dec 2025 23:43:30 +0900 Subject: [PATCH 206/215] Action sbt scalafmtAll --- .../src/main/scala/ldbc/connector/pool/ConcurrentBag.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala index 157d07cca..1b5806088 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala @@ -251,7 +251,7 @@ object ConcurrentBag: item.compareAndSet(BagEntry.STATE_IN_USE, BagEntry.STATE_NOT_IN_USE).flatMap { case true => // Successfully transitioned from IN_USE continueRequiteProcess(item) - case false => + case false => // Failed transition - could be NOT_IN_USE already or other state item.getState.flatMap { currentState => currentState match { From 8d3660b5086d3cec15732a48969834d8e0e69fa2 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 22 Dec 2025 00:39:06 +0900 Subject: [PATCH 207/215] Added maxPacketSize property --- .../scala/ldbc/connector/Connection.scala | 10 ++++++++ .../scala/ldbc/connector/MySQLConfig.scala | 8 +++++- .../ldbc/connector/MySQLDataSource.scala | 10 +++++++- .../exception/PacketTooBigException.scala | 15 +++++++++++ .../ldbc/connector/net/PacketSocket.scala | 25 ++++++++++++++++--- .../scala/ldbc/connector/net/Protocol.scala | 3 ++- 6 files changed, 65 insertions(+), 6 deletions(-) create mode 100644 module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index aca2f00fa..9ee86e0cd 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -77,6 +77,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] @@ -93,6 +94,7 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, defaultAuthenticationPlugin, plugins, @@ -115,6 +117,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] @@ -131,6 +134,7 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, defaultAuthenticationPlugin, plugins, @@ -151,6 +155,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], @@ -176,6 +181,7 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, defaultAuthenticationPlugin, plugins, @@ -197,6 +203,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: List[AuthenticationPlugin[F]], @@ -219,6 +226,7 @@ object Connection: allowPublicKeyRetrieval, readTimeout, capabilityFlags, + maxAllowedPacket, defaultAuthenticationPlugin, pluginMap ) @@ -260,6 +268,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: List[AuthenticationPlugin[F]], @@ -290,6 +299,7 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, defaultAuthenticationPlugin, plugins, diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala index 9cab703ae..1457ba885 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala @@ -415,6 +415,10 @@ trait MySQLConfig: * @return a new MySQLConfig with the updated setting */ def setPoolName(name: String): MySQLConfig + + def maxAllowedPacket: Int + + def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig /** * Companion object for MySQLConfig providing factory methods. @@ -455,7 +459,8 @@ object MySQLConfig: connectionTestQuery: Option[String] = None, logPoolState: Boolean = false, poolStateLogInterval: FiniteDuration = 30.seconds, - poolName: String = "ldbc-pool" + poolName: String = "ldbc-pool", + maxAllowedPacket: Int = 65535 ) extends MySQLConfig: override def setHost(host: String): MySQLConfig = copy(host = host) @@ -491,6 +496,7 @@ object MySQLConfig: override def setLogPoolState(enabled: Boolean): MySQLConfig = copy(logPoolState = enabled) override def setPoolStateLogInterval(interval: FiniteDuration): MySQLConfig = copy(poolStateLogInterval = interval) override def setPoolName(name: String): MySQLConfig = copy(poolName = name) + override def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig = copy(maxAllowedPacket = maxAllowedPacket) /** * Creates a default MySQLConfig with standard connection parameters. diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 70b4876cd..3f6101e61 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -84,6 +84,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen tracer: Option[Tracer[F]] = None, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = 65535, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, @@ -118,6 +119,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, + maxAllowedPacket = maxAllowedPacket, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, plugins = plugins @@ -138,6 +140,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, + maxAllowedPacket = maxAllowedPacket, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, plugins = plugins @@ -156,6 +159,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, + maxAllowedPacket = maxAllowedPacket, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, plugins = plugins @@ -255,6 +259,9 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen */ def setUseServerPrepStmts(newUseServerPrepStmts: Boolean): MySQLDataSource[F, A] = copy(useServerPrepStmts = newUseServerPrepStmts) + + def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLDataSource[F, A] = + copy(maxAllowedPacket = maxAllowedPacket) /** Sets whether to authentication plugin to be used first for communication with the server. * @param defaultAuthenticationPlugin @@ -389,7 +396,8 @@ object MySQLDataSource: allowPublicKeyRetrieval = config.allowPublicKeyRetrieval, databaseTerm = config.databaseTerm, useCursorFetch = config.useCursorFetch, - useServerPrepStmts = config.useServerPrepStmts + useServerPrepStmts = config.useServerPrepStmts, + maxAllowedPacket = config.maxAllowedPacket ) /** diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala new file mode 100644 index 000000000..c550f9fe2 --- /dev/null +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala @@ -0,0 +1,15 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.connector.exception + +class PacketTooBigException( + packetLength: Int, + maxAllowed: Int +) extends SQLException( + s"Packet for query is too large ($packetLength > $maxAllowed). " + + s"You can change the value by setting the 'maxAllowedPacket' configuration." +) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala index 4a4fcc2cc..850bed4c6 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala @@ -24,6 +24,7 @@ import ldbc.connector.data.CapabilitiesFlags import ldbc.connector.net.packet.* import ldbc.connector.net.packet.response.InitialPacket import ldbc.connector.net.protocol.parseHeader +import ldbc.connector.exception.PacketTooBigException /** * A higher-level `BitVectorSocket` that speaks in terms of `Packet`. @@ -41,10 +42,15 @@ trait PacketSocket[F[_]]: object PacketSocket: + val DEFAULT_MAX_PACKET_SIZE = 65535 // 64KB (JDBC Driver default) + val PROTOCOL_MAX_PACKET_SIZE = 16777215 // 16MB (MySQL protocol limit) + val MIN_PACKET_SIZE = 0 + def fromBitVectorSocket[F[_]: Concurrent: Console]( bvs: BitVectorSocket[F], debugEnabled: Boolean, - sequenceIdRef: Ref[F, Byte] + sequenceIdRef: Ref[F, Byte], + maxAllowedPacket: Int ): PacketSocket[F] = new PacketSocket[F]: private def debug(msg: => String): F[Unit] = @@ -56,6 +62,7 @@ object PacketSocket: (for header <- bvs.read(4) payloadSize = parseHeader(header.toByteArray) + _ <- validatePacketSize(payloadSize) payload <- bvs.read(payloadSize) response = decoder.decodeValue(payload).require _ <- @@ -94,6 +101,17 @@ object PacketSocket: _ <- sequenceIdRef.update(sequenceId => ((sequenceId + 1) % 256).toByte) yield () + private def validatePacketSize(size: Int): F[Unit] = + if size < MIN_PACKET_SIZE then + Concurrent[F].raiseError( + PacketTooBigException(size, maxAllowedPacket) + ) + else if size > maxAllowedPacket then + Concurrent[F].raiseError( + PacketTooBigException(size, maxAllowedPacket) + ) + else Concurrent[F].unit + def apply[F[_]: Console: Temporal]( debug: Boolean, sockets: Resource[F, Socket[F]], @@ -101,8 +119,9 @@ object PacketSocket: sequenceIdRef: Ref[F, Byte], initialPacketRef: Ref[F, Option[InitialPacket]], readTimeout: Duration, - capabilitiesFlags: Set[CapabilitiesFlags] + capabilitiesFlags: Set[CapabilitiesFlags], + maxAllowedPacket: Int ): Resource[F, PacketSocket[F]] = BitVectorSocket(sockets, sequenceIdRef, initialPacketRef, sslOptions, readTimeout, capabilitiesFlags).map( - fromBitVectorSocket(_, debug, sequenceIdRef) + fromBitVectorSocket(_, debug, sequenceIdRef, maxAllowedPacket) ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index 701e8346d..f71e16504 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -583,6 +583,7 @@ object Protocol: allowPublicKeyRetrieval: Boolean = false, readTimeout: Duration, capabilitiesFlags: Set[CapabilitiesFlags], + maxAllowedPacket: Int, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: Map[String, AuthenticationPlugin[F]] ): Resource[F, Protocol[F]] = @@ -590,7 +591,7 @@ object Protocol: sequenceIdRef <- Resource.eval(Ref[F].of[Byte](0x01)) initialPacketRef <- Resource.eval(Ref[F].of[Option[InitialPacket]](None)) packetSocket <- - PacketSocket[F](debug, sockets, sslOptions, sequenceIdRef, initialPacketRef, readTimeout, capabilitiesFlags) + PacketSocket[F](debug, sockets, sslOptions, sequenceIdRef, initialPacketRef, readTimeout, capabilitiesFlags, maxAllowedPacket) protocol <- Resource.eval( fromPacketSocket( packetSocket, From 0e678fb74f025e02542e2f3114d3e69ca4cd9d7b Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 22 Dec 2025 00:44:19 +0900 Subject: [PATCH 208/215] Added scaladoc comment --- .../scala/ldbc/connector/MySQLConfig.scala | 39 ++++++++++++++++++- .../ldbc/connector/MySQLDataSource.scala | 8 +++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala index 1457ba885..0fc24bbf9 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala @@ -415,9 +415,46 @@ trait MySQLConfig: * @return a new MySQLConfig with the updated setting */ def setPoolName(name: String): MySQLConfig - + + /** + * Gets the maximum allowed packet size for network communication with MySQL server. + * + * This setting controls the maximum size of packets that can be sent to or received from + * the MySQL server. It helps prevent memory exhaustion attacks and ensures compatibility + * with the MySQL protocol limits. + * + * The value corresponds to the MySQL server's `max_allowed_packet` system variable. + * + * @return the maximum packet size in bytes + */ def maxAllowedPacket: Int + /** + * Sets the maximum allowed packet size for network communication. + * + * This setting provides protection against: + * - Memory exhaustion attacks through oversized packets + * - Denial of Service (DoS) attacks via large data payloads + * - Accidental transmission of extremely large data sets + * + * @param maxAllowedPacket the maximum packet size in bytes + * @return a new MySQLConfig with the updated setting + * + * @example {{{ + * // Set conservative 64KB limit (default) + * config.setMaxAllowedPacket(65535) + * + * // Set practical 1MB limit for applications with moderate BLOB usage + * config.setMaxAllowedPacket(1048576) + * + * // Set maximum protocol limit for applications requiring large data transfers + * config.setMaxAllowedPacket(16777215) + * }}} + * + * @note The default value of 65,535 bytes (64KB) is compatible with MySQL JDBC Driver defaults + * and provides good security against packet-based attacks while accommodating most use cases. + * @see [[https://dev.mysql.com/doc/refman/en/packet-too-large.html MySQL Protocol Packet Limits]] + */ def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig /** diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 3f6101e61..2cbfa0b43 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -48,6 +48,7 @@ import ldbc.DataSource * @param tracer optional OpenTelemetry tracer for distributed tracing * @param useCursorFetch whether to use cursor-based fetching for result sets * @param useServerPrepStmts whether to use server-side prepared statements + * @param maxAllowedPacket Maximum allowed packet size for network communication in bytes. * @param defaultAuthenticationPlugin The authentication plugin used first for communication with the server * @param plugins Additional authentication plugins used for communication with the server * @param before optional hook to execute before a connection is acquired @@ -259,7 +260,12 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen */ def setUseServerPrepStmts(newUseServerPrepStmts: Boolean): MySQLDataSource[F, A] = copy(useServerPrepStmts = newUseServerPrepStmts) - + + /** Sets the maximum allowed packet size for network communication. + * + * @param maxAllowedPacket the maximum packet size in bytes (0 to 16,777,215) + * @return a new MySQLDataSource with the updated packet size limit + */ def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLDataSource[F, A] = copy(maxAllowedPacket = maxAllowedPacket) From 12b2ffa4b628efc025101be720584cb659823832 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 22 Dec 2025 00:44:44 +0900 Subject: [PATCH 209/215] Added scaladoc comment --- .../exception/PacketTooBigException.scala | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala index c550f9fe2..cb7ddd734 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala @@ -6,6 +6,56 @@ package ldbc.connector.exception +/** + * Exception thrown when a MySQL protocol packet exceeds the configured maximum packet size limit. + * + * This exception serves as a security mechanism to prevent memory exhaustion attacks and ensures + * compatibility with MySQL protocol constraints. It is thrown when the connector attempts to + * send or receive a packet that exceeds the `maxAllowedPacket` configuration setting. + * + * The `maxAllowedPacket` setting corresponds to MySQL's `max_allowed_packet` system variable + * and provides protection against: + * + * - **Memory exhaustion attacks**: Prevents attackers from sending oversized packets to consume server memory + * - **Denial of Service (DoS)**: Blocks attempts to overwhelm the system with large data payloads + * - **Accidental large data transfers**: Catches unintentional transmission of extremely large datasets + * - **Protocol violations**: Enforces MySQL protocol packet size constraints (max 16MB) + * + * @param packetLength the actual size of the packet that exceeded the limit, in bytes + * @param maxAllowed the configured maximum packet size limit, in bytes + * + * @example {{{ + * // Typical usage when packet size validation fails + * try { + * // Some operation that sends a large packet + * connection.executeUpdate(queryWithLargeData) + * } catch { + * case ex: PacketTooBigException => + * println(s"Packet too large: ${ex.getMessage}") + * // Consider increasing maxAllowedPacket or reducing data size + * } + * }}} + * + * @example {{{ + * // Configure larger packet size to handle bigger data + * val config = MySQLConfig.default + * .setMaxAllowedPacket(1048576) // 1MB limit + * + * // Or use MySQL protocol maximum + * val config = MySQLConfig.default + * .setMaxAllowedPacket(16777215) // 16MB - protocol maximum + * }}} + * + * @see [[ldbc.connector.MySQLConfig.maxAllowedPacket]] for configuration details + * @see [[ldbc.connector.MySQLConfig.setMaxAllowedPacket]] for setting the packet size limit + * @see [[https://dev.mysql.com/doc/refman/en/packet-too-large.html MySQL Protocol Packet Limits]] + * + * @note The default packet size limit is 65,535 bytes (64KB), which matches MySQL JDBC Driver defaults + * and provides good security against packet-based attacks while accommodating most use cases. + * + * @note This exception extends SQLException to maintain compatibility with JDBC error handling patterns. + * The error message includes both the actual packet size and the configured limit for debugging. + */ class PacketTooBigException( packetLength: Int, maxAllowed: Int From 6d161cbba3e1c057c20c6fbe0fd0ce317b783363 Mon Sep 17 00:00:00 2001 From: takapi327 Date: Mon, 22 Dec 2025 00:45:06 +0900 Subject: [PATCH 210/215] Action sbt scalafmtAll --- .../main/scala/ldbc/connector/Connection.scala | 12 ++++++------ .../main/scala/ldbc/connector/MySQLConfig.scala | 4 ++-- .../scala/ldbc/connector/MySQLDataSource.scala | 10 +++++----- .../exception/PacketTooBigException.scala | 8 ++++---- .../scala/ldbc/connector/net/PacketSocket.scala | 16 ++++++++-------- .../main/scala/ldbc/connector/net/Protocol.scala | 13 +++++++++++-- 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index 9ee86e0cd..ea81e4a10 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -77,7 +77,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] @@ -117,7 +117,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] @@ -155,7 +155,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], @@ -181,7 +181,7 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, - maxAllowedPacket, + maxAllowedPacket, databaseTerm, defaultAuthenticationPlugin, plugins, @@ -203,7 +203,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: List[AuthenticationPlugin[F]], @@ -268,7 +268,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = 65535, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: List[AuthenticationPlugin[F]], diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala index 0fc24bbf9..df91818a2 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala @@ -428,7 +428,7 @@ trait MySQLConfig: * @return the maximum packet size in bytes */ def maxAllowedPacket: Int - + /** * Sets the maximum allowed packet size for network communication. * @@ -497,7 +497,7 @@ object MySQLConfig: logPoolState: Boolean = false, poolStateLogInterval: FiniteDuration = 30.seconds, poolName: String = "ldbc-pool", - maxAllowedPacket: Int = 65535 + maxAllowedPacket: Int = 65535 ) extends MySQLConfig: override def setHost(host: String): MySQLConfig = copy(host = host) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 2cbfa0b43..61d70572c 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -85,7 +85,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen tracer: Option[Tracer[F]] = None, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = 65535, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, @@ -120,7 +120,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - maxAllowedPacket = maxAllowedPacket, + maxAllowedPacket = maxAllowedPacket, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, plugins = plugins @@ -141,7 +141,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - maxAllowedPacket = maxAllowedPacket, + maxAllowedPacket = maxAllowedPacket, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, plugins = plugins @@ -160,7 +160,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - maxAllowedPacket = maxAllowedPacket, + maxAllowedPacket = maxAllowedPacket, databaseTerm = databaseTerm, defaultAuthenticationPlugin = defaultAuthenticationPlugin, plugins = plugins @@ -403,7 +403,7 @@ object MySQLDataSource: databaseTerm = config.databaseTerm, useCursorFetch = config.useCursorFetch, useServerPrepStmts = config.useServerPrepStmts, - maxAllowedPacket = config.maxAllowedPacket + maxAllowedPacket = config.maxAllowedPacket ) /** diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala index cb7ddd734..7360e1c1f 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala @@ -58,8 +58,8 @@ package ldbc.connector.exception */ class PacketTooBigException( packetLength: Int, - maxAllowed: Int + maxAllowed: Int ) extends SQLException( - s"Packet for query is too large ($packetLength > $maxAllowed). " + - s"You can change the value by setting the 'maxAllowedPacket' configuration." -) + s"Packet for query is too large ($packetLength > $maxAllowed). " + + s"You can change the value by setting the 'maxAllowedPacket' configuration." + ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala index 850bed4c6..0714e97be 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala @@ -21,10 +21,10 @@ import fs2.io.net.Socket import fs2.Chunk import ldbc.connector.data.CapabilitiesFlags +import ldbc.connector.exception.PacketTooBigException import ldbc.connector.net.packet.* import ldbc.connector.net.packet.response.InitialPacket import ldbc.connector.net.protocol.parseHeader -import ldbc.connector.exception.PacketTooBigException /** * A higher-level `BitVectorSocket` that speaks in terms of `Packet`. @@ -42,14 +42,14 @@ trait PacketSocket[F[_]]: object PacketSocket: - val DEFAULT_MAX_PACKET_SIZE = 65535 // 64KB (JDBC Driver default) + val DEFAULT_MAX_PACKET_SIZE = 65535 // 64KB (JDBC Driver default) val PROTOCOL_MAX_PACKET_SIZE = 16777215 // 16MB (MySQL protocol limit) - val MIN_PACKET_SIZE = 0 + val MIN_PACKET_SIZE = 0 def fromBitVectorSocket[F[_]: Concurrent: Console]( - bvs: BitVectorSocket[F], - debugEnabled: Boolean, - sequenceIdRef: Ref[F, Byte], + bvs: BitVectorSocket[F], + debugEnabled: Boolean, + sequenceIdRef: Ref[F, Byte], maxAllowedPacket: Int ): PacketSocket[F] = new PacketSocket[F]: @@ -62,7 +62,7 @@ object PacketSocket: (for header <- bvs.read(4) payloadSize = parseHeader(header.toByteArray) - _ <- validatePacketSize(payloadSize) + _ <- validatePacketSize(payloadSize) payload <- bvs.read(payloadSize) response = decoder.decodeValue(payload).require _ <- @@ -120,7 +120,7 @@ object PacketSocket: initialPacketRef: Ref[F, Option[InitialPacket]], readTimeout: Duration, capabilitiesFlags: Set[CapabilitiesFlags], - maxAllowedPacket: Int + maxAllowedPacket: Int ): Resource[F, PacketSocket[F]] = BitVectorSocket(sockets, sequenceIdRef, initialPacketRef, sslOptions, readTimeout, capabilitiesFlags).map( fromBitVectorSocket(_, debug, sequenceIdRef, maxAllowedPacket) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index f71e16504..052aa59cb 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -583,7 +583,7 @@ object Protocol: allowPublicKeyRetrieval: Boolean = false, readTimeout: Duration, capabilitiesFlags: Set[CapabilitiesFlags], - maxAllowedPacket: Int, + maxAllowedPacket: Int, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: Map[String, AuthenticationPlugin[F]] ): Resource[F, Protocol[F]] = @@ -591,7 +591,16 @@ object Protocol: sequenceIdRef <- Resource.eval(Ref[F].of[Byte](0x01)) initialPacketRef <- Resource.eval(Ref[F].of[Option[InitialPacket]](None)) packetSocket <- - PacketSocket[F](debug, sockets, sslOptions, sequenceIdRef, initialPacketRef, readTimeout, capabilitiesFlags, maxAllowedPacket) + PacketSocket[F]( + debug, + sockets, + sslOptions, + sequenceIdRef, + initialPacketRef, + readTimeout, + capabilitiesFlags, + maxAllowedPacket + ) protocol <- Resource.eval( fromPacketSocket( packetSocket, From 14b22859bc4bc30e65434cc29c033fb58f92b05d Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 25 Dec 2025 19:28:23 +0900 Subject: [PATCH 211/215] Added max packet size range check --- .../scala/ldbc/connector/MySQLConfig.scala | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala index df91818a2..fd04508e7 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala @@ -439,6 +439,7 @@ trait MySQLConfig: * * @param maxAllowedPacket the maximum packet size in bytes * @return a new MySQLConfig with the updated setting + * @throws IllegalArgumentException if the value is outside the valid range (1024 to 16,777,215) * * @example {{{ * // Set conservative 64KB limit (default) @@ -453,6 +454,7 @@ trait MySQLConfig: * * @note The default value of 65,535 bytes (64KB) is compatible with MySQL JDBC Driver defaults * and provides good security against packet-based attacks while accommodating most use cases. + * @note Valid range: 1,024 bytes (1KB) minimum to 16,777,215 bytes (16MB) maximum (MySQL protocol limit) * @see [[https://dev.mysql.com/doc/refman/en/packet-too-large.html MySQL Protocol Packet Limits]] */ def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig @@ -462,6 +464,15 @@ trait MySQLConfig: */ object MySQLConfig: + /** Minimum allowed packet size in bytes (1KB) */ + val MIN_PACKET_SIZE: Int = 1024 + + /** Maximum allowed packet size in bytes (16MB - MySQL protocol limit) */ + val MAX_PACKET_SIZE: Int = 16777215 + + /** Default packet size in bytes (64KB - MySQL JDBC Driver compatible) */ + val DEFAULT_PACKET_SIZE: Int = 65535 + /** Default socket options applied to all connections. */ private[ldbc] val defaultSocketOptions: List[SocketOption] = List(SocketOption.noDelay(true)) @@ -497,7 +508,7 @@ object MySQLConfig: logPoolState: Boolean = false, poolStateLogInterval: FiniteDuration = 30.seconds, poolName: String = "ldbc-pool", - maxAllowedPacket: Int = 65535 + maxAllowedPacket: Int = DEFAULT_PACKET_SIZE ) extends MySQLConfig: override def setHost(host: String): MySQLConfig = copy(host = host) @@ -533,7 +544,11 @@ object MySQLConfig: override def setLogPoolState(enabled: Boolean): MySQLConfig = copy(logPoolState = enabled) override def setPoolStateLogInterval(interval: FiniteDuration): MySQLConfig = copy(poolStateLogInterval = interval) override def setPoolName(name: String): MySQLConfig = copy(poolName = name) - override def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig = copy(maxAllowedPacket = maxAllowedPacket) + override def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig = { + require(maxAllowedPacket >= MIN_PACKET_SIZE, s"maxAllowedPacket must be at least $MIN_PACKET_SIZE bytes, but got $maxAllowedPacket") + require(maxAllowedPacket <= MAX_PACKET_SIZE, s"maxAllowedPacket must not exceed $MAX_PACKET_SIZE bytes (MySQL protocol limit), but got $maxAllowedPacket") + copy(maxAllowedPacket = maxAllowedPacket) + } /** * Creates a default MySQLConfig with standard connection parameters. From 75f9549682f489273a7e1dc2db1dd00ccafde18a Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 25 Dec 2025 19:28:41 +0900 Subject: [PATCH 212/215] Added max packet size range check for datasource --- .../main/scala/ldbc/connector/MySQLDataSource.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 61d70572c..684887225 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -85,7 +85,7 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen tracer: Option[Tracer[F]] = None, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, @@ -263,11 +263,15 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen /** Sets the maximum allowed packet size for network communication. * - * @param maxAllowedPacket the maximum packet size in bytes (0 to 16,777,215) + * @param maxAllowedPacket the maximum packet size in bytes (1,024 to 16,777,215) * @return a new MySQLDataSource with the updated packet size limit + * @throws IllegalArgumentException if the value is outside the valid range */ - def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLDataSource[F, A] = + def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLDataSource[F, A] = { + require(maxAllowedPacket >= MySQLConfig.MIN_PACKET_SIZE, s"maxAllowedPacket must be at least ${MySQLConfig.MIN_PACKET_SIZE} bytes, but got $maxAllowedPacket") + require(maxAllowedPacket <= MySQLConfig.MAX_PACKET_SIZE, s"maxAllowedPacket must not exceed ${MySQLConfig.MAX_PACKET_SIZE} bytes (MySQL protocol limit), but got $maxAllowedPacket") copy(maxAllowedPacket = maxAllowedPacket) + } /** Sets whether to authentication plugin to be used first for communication with the server. * @param defaultAuthenticationPlugin From c9e3ec93418887c60c2b28c683efacee6f47266a Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 25 Dec 2025 19:29:09 +0900 Subject: [PATCH 213/215] Change use default max packet size --- .../src/main/scala/ldbc/connector/Connection.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index ea81e4a10..eb355d5a7 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -77,7 +77,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] @@ -117,7 +117,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] @@ -155,7 +155,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], @@ -203,7 +203,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: List[AuthenticationPlugin[F]], @@ -268,7 +268,7 @@ object Connection: allowPublicKeyRetrieval: Boolean = false, useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, - maxAllowedPacket: Int = 65535, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], plugins: List[AuthenticationPlugin[F]], From 7e4d4127a68b16919b160a222ed4cd5d5f0cb97f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 25 Dec 2025 19:39:51 +0900 Subject: [PATCH 214/215] Added max packet size test --- .../ldbc/connector/MySQLConfigTest.scala | 65 +++++++++++++++++++ .../ldbc/connector/MySQLDataSourceTest.scala | 61 +++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala index 7717485a4..acf6aafdb 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala @@ -30,6 +30,7 @@ class MySQLConfigTest extends FTestPlatform: assertEquals(config.databaseTerm, Some(DatabaseMetaData.DatabaseTerm.CATALOG)) assertEquals(config.useCursorFetch, false) assertEquals(config.useServerPrepStmts, false) + assertEquals(config.maxAllowedPacket, MySQLConfig.DEFAULT_PACKET_SIZE) } test("setHost should update host value") { @@ -130,6 +131,68 @@ class MySQLConfigTest extends FTestPlatform: assertEquals(updated.useServerPrepStmts, true) } + test("setMaxAllowedPacket should update maxAllowedPacket value") { + val config = MySQLConfig.default + val updated = config.setMaxAllowedPacket(1048576) // 1MB + + assertEquals(updated.maxAllowedPacket, 1048576) + // Ensure other values remain unchanged + assertEquals(updated.host, config.host) + assertEquals(updated.port, config.port) + } + + test("setMaxAllowedPacket should accept minimum valid value") { + val config = MySQLConfig.default + val updated = config.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MIN_PACKET_SIZE) + } + + test("setMaxAllowedPacket should accept maximum valid value") { + val config = MySQLConfig.default + val updated = config.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MAX_PACKET_SIZE) + } + + test("setMaxAllowedPacket should reject values below minimum") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE - 1) + } + } + + test("setMaxAllowedPacket should reject values above maximum") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE + 1) + } + } + + test("setMaxAllowedPacket should reject zero value") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(0) + } + } + + test("setMaxAllowedPacket should reject negative values") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(-1) + } + } + + test("MySQLConfig constants should have expected values") { + assertEquals(MySQLConfig.MIN_PACKET_SIZE, 1024) + assertEquals(MySQLConfig.MAX_PACKET_SIZE, 16777215) + assertEquals(MySQLConfig.DEFAULT_PACKET_SIZE, 65535) + } + test("MySQLConfig should be immutable - original config should not change") { val original = MySQLConfig.default val originalHost = original.host @@ -167,6 +230,7 @@ class MySQLConfigTest extends FTestPlatform: .setAllowPublicKeyRetrieval(true) .setUseCursorFetch(true) .setUseServerPrepStmts(true) + .setMaxAllowedPacket(1048576) assertEquals(config.host, "localhost") assertEquals(config.port, 3307) @@ -178,6 +242,7 @@ class MySQLConfigTest extends FTestPlatform: assertEquals(config.allowPublicKeyRetrieval, true) assertEquals(config.useCursorFetch, true) assertEquals(config.useServerPrepStmts, true) + assertEquals(config.maxAllowedPacket, 1048576) } test("MySQLConfig with custom socket options") { diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala index e343c0366..c729cfdfa 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala @@ -40,6 +40,7 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(dataSource.databaseTerm, Some(DatabaseMetaData.DatabaseTerm.CATALOG)) assertEquals(dataSource.useCursorFetch, false) assertEquals(dataSource.useServerPrepStmts, false) + assertEquals(dataSource.maxAllowedPacket, MySQLConfig.DEFAULT_PACKET_SIZE) assertEquals(dataSource.before, None) assertEquals(dataSource.after, None) } @@ -149,6 +150,62 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(updated.useServerPrepStmts, true) } + test("setMaxAllowedPacket should update maxAllowedPacket value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + val updated = dataSource.setMaxAllowedPacket(1048576) // 1MB + + assertEquals(updated.maxAllowedPacket, 1048576) + // Ensure other values remain unchanged + assertEquals(updated.host, dataSource.host) + assertEquals(updated.port, dataSource.port) + } + + test("setMaxAllowedPacket should accept minimum valid value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + val updated = dataSource.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MIN_PACKET_SIZE) + } + + test("setMaxAllowedPacket should accept maximum valid value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + val updated = dataSource.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MAX_PACKET_SIZE) + } + + test("setMaxAllowedPacket should reject values below minimum") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE - 1) + } + } + + test("setMaxAllowedPacket should reject values above maximum") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE + 1) + } + } + + test("setMaxAllowedPacket should reject zero value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(0) + } + } + + test("setMaxAllowedPacket should reject negative values") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(-1) + } + } + test("withBefore should add a before hook and change type parameter") { val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") val beforeHook: Connection[IO] => IO[String] = _ => IO.pure("before result") @@ -191,6 +248,7 @@ class MySQLDataSourceTest extends FTestPlatform: .setAllowPublicKeyRetrieval(true) .setUseCursorFetch(true) .setUseServerPrepStmts(true) + .setMaxAllowedPacket(1048576) assertEquals(dataSource.host, "127.0.0.1") assertEquals(dataSource.port, 3307) @@ -202,6 +260,7 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(dataSource.allowPublicKeyRetrieval, true) assertEquals(dataSource.useCursorFetch, true) assertEquals(dataSource.useServerPrepStmts, true) + assertEquals(dataSource.maxAllowedPacket, 1048576) } test("MySQLDataSource.fromConfig should create DataSource from MySQLConfig") { @@ -212,6 +271,7 @@ class MySQLDataSourceTest extends FTestPlatform: .setPassword("configpass") .setDatabase("configdb") .setDebug(true) + .setMaxAllowedPacket(2097152) // 2MB val dataSource = MySQLDataSource.fromConfig[IO](config) @@ -221,6 +281,7 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(dataSource.password, Some("configpass")) assertEquals(dataSource.database, Some("configdb")) assertEquals(dataSource.debug, true) + assertEquals(dataSource.maxAllowedPacket, 2097152) } test("MySQLDataSource.default should create DataSource with default config") { From 9e2689224eb6a3f84af2f0941de466619eef0e9f Mon Sep 17 00:00:00 2001 From: takapi327 Date: Thu, 25 Dec 2025 19:40:43 +0900 Subject: [PATCH 215/215] Action sbt scalafmtAll --- .../main/scala/ldbc/connector/MySQLConfig.scala | 16 +++++++++++----- .../scala/ldbc/connector/MySQLDataSource.scala | 10 ++++++++-- .../scala/ldbc/connector/MySQLConfigTest.scala | 8 ++++---- .../ldbc/connector/MySQLDataSourceTest.scala | 8 ++++---- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala index fd04508e7..486c2dd18 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala @@ -466,10 +466,10 @@ object MySQLConfig: /** Minimum allowed packet size in bytes (1KB) */ val MIN_PACKET_SIZE: Int = 1024 - + /** Maximum allowed packet size in bytes (16MB - MySQL protocol limit) */ val MAX_PACKET_SIZE: Int = 16777215 - + /** Default packet size in bytes (64KB - MySQL JDBC Driver compatible) */ val DEFAULT_PACKET_SIZE: Int = 65535 @@ -544,9 +544,15 @@ object MySQLConfig: override def setLogPoolState(enabled: Boolean): MySQLConfig = copy(logPoolState = enabled) override def setPoolStateLogInterval(interval: FiniteDuration): MySQLConfig = copy(poolStateLogInterval = interval) override def setPoolName(name: String): MySQLConfig = copy(poolName = name) - override def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig = { - require(maxAllowedPacket >= MIN_PACKET_SIZE, s"maxAllowedPacket must be at least $MIN_PACKET_SIZE bytes, but got $maxAllowedPacket") - require(maxAllowedPacket <= MAX_PACKET_SIZE, s"maxAllowedPacket must not exceed $MAX_PACKET_SIZE bytes (MySQL protocol limit), but got $maxAllowedPacket") + override def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig = { + require( + maxAllowedPacket >= MIN_PACKET_SIZE, + s"maxAllowedPacket must be at least $MIN_PACKET_SIZE bytes, but got $maxAllowedPacket" + ) + require( + maxAllowedPacket <= MAX_PACKET_SIZE, + s"maxAllowedPacket must not exceed $MAX_PACKET_SIZE bytes (MySQL protocol limit), but got $maxAllowedPacket" + ) copy(maxAllowedPacket = maxAllowedPacket) } diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index 684887225..22599d55c 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -268,8 +268,14 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen * @throws IllegalArgumentException if the value is outside the valid range */ def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLDataSource[F, A] = { - require(maxAllowedPacket >= MySQLConfig.MIN_PACKET_SIZE, s"maxAllowedPacket must be at least ${MySQLConfig.MIN_PACKET_SIZE} bytes, but got $maxAllowedPacket") - require(maxAllowedPacket <= MySQLConfig.MAX_PACKET_SIZE, s"maxAllowedPacket must not exceed ${MySQLConfig.MAX_PACKET_SIZE} bytes (MySQL protocol limit), but got $maxAllowedPacket") + require( + maxAllowedPacket >= MySQLConfig.MIN_PACKET_SIZE, + s"maxAllowedPacket must be at least ${ MySQLConfig.MIN_PACKET_SIZE } bytes, but got $maxAllowedPacket" + ) + require( + maxAllowedPacket <= MySQLConfig.MAX_PACKET_SIZE, + s"maxAllowedPacket must not exceed ${ MySQLConfig.MAX_PACKET_SIZE } bytes (MySQL protocol limit), but got $maxAllowedPacket" + ) copy(maxAllowedPacket = maxAllowedPacket) } diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala index acf6aafdb..e84f6a958 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala @@ -157,7 +157,7 @@ class MySQLConfigTest extends FTestPlatform: test("setMaxAllowedPacket should reject values below minimum") { val config = MySQLConfig.default - + intercept[IllegalArgumentException] { config.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE - 1) } @@ -165,7 +165,7 @@ class MySQLConfigTest extends FTestPlatform: test("setMaxAllowedPacket should reject values above maximum") { val config = MySQLConfig.default - + intercept[IllegalArgumentException] { config.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE + 1) } @@ -173,7 +173,7 @@ class MySQLConfigTest extends FTestPlatform: test("setMaxAllowedPacket should reject zero value") { val config = MySQLConfig.default - + intercept[IllegalArgumentException] { config.setMaxAllowedPacket(0) } @@ -181,7 +181,7 @@ class MySQLConfigTest extends FTestPlatform: test("setMaxAllowedPacket should reject negative values") { val config = MySQLConfig.default - + intercept[IllegalArgumentException] { config.setMaxAllowedPacket(-1) } diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala index c729cfdfa..e5f63e101 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala @@ -176,7 +176,7 @@ class MySQLDataSourceTest extends FTestPlatform: test("setMaxAllowedPacket should reject values below minimum") { val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") - + intercept[IllegalArgumentException] { dataSource.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE - 1) } @@ -184,7 +184,7 @@ class MySQLDataSourceTest extends FTestPlatform: test("setMaxAllowedPacket should reject values above maximum") { val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") - + intercept[IllegalArgumentException] { dataSource.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE + 1) } @@ -192,7 +192,7 @@ class MySQLDataSourceTest extends FTestPlatform: test("setMaxAllowedPacket should reject zero value") { val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") - + intercept[IllegalArgumentException] { dataSource.setMaxAllowedPacket(0) } @@ -200,7 +200,7 @@ class MySQLDataSourceTest extends FTestPlatform: test("setMaxAllowedPacket should reject negative values") { val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") - + intercept[IllegalArgumentException] { dataSource.setMaxAllowedPacket(-1) }